ELI5: what is backpropagation?

Imagine a game of Jenga. As you play, you change the tower piece by piece, with the goal of creating the tallest tower you can.

Now, imagine you could see the winning tower (the last one before it topples) before you start the game. You would know which bricks need to move, and you would only have to work out when and how to move each one.

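Backpropagation gives machine learning something like that view when training neural networks. By comparing the answer a model gives with the answer it should give, it works out how each part of the model should change. Here, we look at this training method, and why it’s useful.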

What is backpropagation?

Backpropagation is a method most commonly used in supervised machine learning. It relies on some fairly heavy mathematics, such as linear algebra and partial derivatives.

But that’s all a bit confusing.

In simpler terms, backpropagation is a way for machine learning engineers to train and improve their models. It uses two things: the answer they want the machine to give, and the answer the machine actually gives.

These two answers are fed into something called a ‘cost function’ or ‘loss function’. This is a way to represent the gap between the result you want and the result you get: the bigger the gap, the bigger the loss.
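To make the ‘gap’ concrete, here’s a minimal sketch in Python. It assumes squared error as the loss (one common choice; the article doesn’t commit to a particular one), and the function names are our own. The second function is the gradient of the loss with respect to the prediction: it says which way, and how steeply, the loss changes as the answer changes.

```python
def squared_error(prediction, target):
    """Loss: the squared gap between the answer we got and the answer we wanted."""
    return (prediction - target) ** 2


def squared_error_gradient(prediction, target):
    """Derivative of the loss with respect to the prediction: d/dp (p - t)^2 = 2(p - t)."""
    return 2 * (prediction - target)


wanted = 1.0  # the answer we want (the label)
got = 0.6     # the answer the network gave

print(squared_error(got, wanted))           # size of the gap
print(squared_error_gradient(got, wanted))  # negative: a bigger prediction would shrink the loss
```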

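Backpropagation then calculates the gradient of this loss function with respect to each of the network’s adjustable settings (its weights). In other words, it works out how a small change to each setting would affect the size of the gap.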
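Here’s that idea for the simplest possible ‘network’: a single weight w and bias b, with output w * x + b. This is a hypothetical toy, not a real model. It shows the chain rule, which backpropagation applies over and over, linking a change in the weight to a change in the output, and a change in the output to a change in the loss.

```python
def forward(w, b, x):
    """A one-weight 'network': the output is w * x + b."""
    return w * x + b


def loss(w, b, x, target):
    """Squared-error loss for a single example."""
    return (forward(w, b, x) - target) ** 2


def backward(w, b, x, target):
    """Chain rule: dL/dw = dL/dy * dy/dw, and dL/db = dL/dy * dy/db."""
    y = forward(w, b, x)
    dL_dy = 2 * (y - target)  # how the loss changes as the output changes
    dL_dw = dL_dy * x         # dy/dw = x
    dL_db = dL_dy * 1.0       # dy/db = 1
    return dL_dw, dL_db


# Hypothetical values, for illustration only.
w, b, x, target = 0.5, 0.1, 2.0, 2.0
dL_dw, dL_db = backward(w, b, x, target)
print(dL_dw, dL_db)
```

Both gradients come out negative here, which says that nudging w or b upwards would shrink the loss.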

So, what is backpropagation? The shortest answer is that it’s a way to train AI to continually improve its performance.

Artificial neural networks

Looking deeper into the ‘what is backpropagation’ question means understanding a little more about what it’s used to improve. That is, artificial neural networks and their nodes.

Backpropagation is used to train artificial neural networks (ANNs), including the large networks used in deep learning.

An ANN consists of layers of nodes. They act rather like a filter. So, you feed your input into one end, it filters through the layers of nodes, and then you get the final output, or answer. Each node processes the information it gets and passes its result on, and each connection between nodes has a weight. These weights determine how much one node’s output counts towards the next node’s calculation, and so, ultimately, towards the final answer your ANN provides.
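As a rough sketch of that filtering, here’s a forward pass through a tiny network in Python: two inputs, one hidden layer of two nodes, and one output node. The sigmoid activation and every weight value are assumptions for illustration. As in most ANNs, the weights here sit on the connections between nodes.

```python
import math


def sigmoid(z):
    """Squash any number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def forward(inputs, hidden_weights, hidden_biases, output_weights, output_bias):
    """Pass the inputs through one hidden layer, then one output node."""
    hidden = []
    for weights, bias in zip(hidden_weights, hidden_biases):
        # Each hidden node takes a weighted sum of every input, then applies sigmoid.
        z = sum(w * x for w, x in zip(weights, inputs)) + bias
        hidden.append(sigmoid(z))
    # The output node takes a weighted sum of the hidden nodes' outputs.
    z_out = sum(w * h for w, h in zip(output_weights, hidden)) + output_bias
    return sigmoid(z_out), hidden


# Hypothetical weights, chosen only for illustration.
hidden_weights = [[0.2, -0.4], [0.7, 0.1]]  # one row per hidden node
hidden_biases = [0.0, 0.0]
output_weights = [0.6, -0.3]                # how much each hidden node counts
output_bias = 0.05

output, hidden = forward([1.0, 0.5], hidden_weights, hidden_biases, output_weights, output_bias)
print(output)
```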

This means that a more specific answer to “what is backpropagation” is that it’s a way to work out how much each weight, including those in the ‘hidden’ layers between the input and the output, contributed to the network’s error. This, in turn, shows ML engineers what needs to change in those hidden layers.
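Continuing with the same kind of hypothetical toy network, this sketch shows the backward pass. The error at the output is passed back through each connecting weight to the hidden nodes, which gives a gradient for every weight, hidden layer included. It assumes squared-error loss and sigmoid nodes; the names W1, b1, W2 and b2 are our own.

```python
import math


def sigmoid(z):
    """Squash any number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def forward(x, W1, b1, W2, b2):
    """One hidden layer of sigmoid nodes feeding one sigmoid output node."""
    h = [sigmoid(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    y = sigmoid(sum(w * hj for w, hj in zip(W2, h)) + b2)
    return y, h


def backward(x, target, W1, b1, W2, b2):
    """Return the squared-error gradient for every weight and bias."""
    y, h = forward(x, W1, b1, W2, b2)
    # Output node: dL/dy = 2(y - target), times the sigmoid slope y(1 - y).
    delta_out = 2 * (y - target) * y * (1 - y)
    dW2 = [delta_out * hj for hj in h]  # each output weight's input is h[j]
    db2 = delta_out
    # Hidden nodes: send the output error back through the weight connecting them.
    delta_hidden = [delta_out * W2[j] * h[j] * (1 - h[j]) for j in range(len(h))]
    dW1 = [[d * xi for xi in x] for d in delta_hidden]
    db1 = list(delta_hidden)
    return dW1, db1, dW2, db2


# Hypothetical weights and training example, for illustration only.
W1 = [[0.2, -0.4], [0.7, 0.1]]
b1 = [0.0, 0.0]
W2 = [0.6, -0.3]
b2 = 0.05
x, target = [1.0, 0.5], 1.0

dW1, db1, dW2, db2 = backward(x, target, W1, b1, W2, b2)
print("output-layer gradients:", dW2, db2)
print("hidden-layer gradients:", dW1, db1)
```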

Creating a map

A good way to picture backpropagation is with a map. Imagine a hilly landscape where every possible combination of weights is a location, and the height at each location is the loss: how far the network’s answer is from the correct one. Training means finding a low point on this map, where the AI’s answers best match the correct answers.

The map is far too large to survey in full, so backpropagation doesn’t draw all of it. Instead, for the network’s current weights, it works out which way is downhill (the gradient), starting from the error at the output and working backwards, layer by layer, through the network. An algorithm such as gradient descent then nudges the weights a small step in that direction, and the process repeats.

In this way, backpropagation lets machine learning engineers work backwards from the error to train their system.
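Strictly speaking, backpropagation only calculates the gradients; a separate step, commonly plain gradient descent, uses them to update the weights. This minimal sketch puts the two together for a one-weight model fitted to hypothetical data drawn from y = 2x + 1. The learning rate and number of steps are arbitrary choices.

```python
def train(data, learning_rate=0.05, steps=500):
    """Fit y = w * x + b by repeatedly stepping downhill on the squared-error loss."""
    w, b = 0.0, 0.0  # start from an arbitrary guess
    for _ in range(steps):
        dw, db = 0.0, 0.0
        for x, target in data:
            y = w * x + b             # forward pass
            dL_dy = 2 * (y - target)  # how the loss responds to the output
            dw += dL_dy * x           # chain rule back to the weight
            db += dL_dy               # ...and back to the bias
        n = len(data)
        # Move each parameter a small step against its average gradient.
        w -= learning_rate * dw / n
        b -= learning_rate * db / n
    return w, b


# Hypothetical data generated from y = 2x + 1.
data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
w, b = train(data)
print(round(w, 3), round(b, 3))
```

After training, w and b should land very close to 2 and 1. A learning rate that’s too large can overshoot the low point and diverge, which is why the steps are kept small.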

Why use backpropagation?

Let’s go back to the game of Jenga. With each piece you remove or place, you change the possible outcomes of the game. Removing one of the pieces renders others integral, while adding a piece creates new moves.

And changing the wrong piece makes the tower topple, putting you further from your goal.

It’s similar in machine learning. When a weight changes, it changes how the whole system behaves. Change one weight in an early layer and you set off a chain reaction: the node it feeds produces a different output, which changes the inputs to every node after it, all the way to the final answer.

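So, adjusting weights one by one and rerunning the network to see what happens is a herculean task, especially when modern networks have millions of weights. Backpropagation, meanwhile, uses the chain rule from calculus to trace that chain reaction in reverse. In a single backward pass through the network, it calculates the effect each weight has on the final error.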
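To see why this matters, the sketch below compares the two approaches on a hypothetical chain of four sigmoid nodes. The slow function nudges each weight up and down in turn and reruns the whole network (a technique called finite differences), so it needs two forward passes per weight. The backprop function gets every gradient from one forward pass and one backward pass. All values are illustrative assumptions.

```python
import math


def sigmoid(z):
    """Squash any number into the range (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def forward(weights, x):
    """A chain of single sigmoid nodes: each node's output feeds the next one."""
    activations = [x]
    for w in weights:
        activations.append(sigmoid(w * activations[-1]))
    return activations


def loss(weights, x, target):
    """Squared error between the last node's output and the target."""
    return (forward(weights, x)[-1] - target) ** 2


def backprop(weights, x, target):
    """One forward pass and one backward pass give the gradient for every weight."""
    a = forward(weights, x)
    grads = [0.0] * len(weights)
    upstream = 2 * (a[-1] - target)  # dL/d(output)
    for i in reversed(range(len(weights))):
        local = a[i + 1] * (1 - a[i + 1])         # sigmoid slope, written using its output
        grads[i] = upstream * local * a[i]        # this weight's input is a[i]
        upstream = upstream * local * weights[i]  # pass the error back to the previous node
    return grads


def nudge_each_weight(weights, x, target, eps=1e-6):
    """The slow way: change one weight at a time and rerun the whole network."""
    grads = []
    for i in range(len(weights)):
        up = list(weights)
        up[i] += eps
        down = list(weights)
        down[i] -= eps
        grads.append((loss(up, x, target) - loss(down, x, target)) / (2 * eps))
    return grads


weights = [0.5, -1.2, 0.8, 1.5]  # hypothetical weights
x, target = 1.0, 0.9
fast = backprop(weights, x, target)
slow = nudge_each_weight(weights, x, target)
for f, s in zip(fast, slow):
    print(f"{f:.8f}  {s:.8f}")
```

The two columns should agree to several decimal places. With four weights the saving is small, but the nudging approach grows with the number of weights, while backpropagation stays at roughly one forward and one backward pass.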

In short, it’s a consistent and far more efficient way to work out how to improve an ANN.

Maths and machine learning

What is backpropagation? A deep understanding involves linear algebra and calculus. But if it ever comes up in casual conversation, now you know how to give a simplified answer.

Backpropagation is a way for ML engineers to work out how each weight in a neural network contributes to its errors. (And so, which way to adjust those weights to get the outputs they want.)

Useful links

What is machine learning? A beginner’s guide

ELI5: what is an artificial neural network?

ELI5: what is deep learning?