How to train a neural network
We now know how to model a neural network. Our next step is to find the weights that make our model give us the output we want. For this page, I'm going to assume you know a little calculus: what a derivative is and how to use the chain rule. It would also help to know a little linear algebra. I highly recommend checking out 3Blue1Brown on YouTube if your calculus and linear algebra are rusty.
Mapping inputs to outputs
Say we have a list of inputs and their outputs. These could be anything: stock features and price, property features and sale price, student behaviors and average test scores. Let's keep things general and just say we have some inputs and outputs.
Our goal is to find the output for that last input. To do this, we must find the relationship (function) that maps these inputs to their outputs. For the sake of learning, I'm going to cheat a little and tell you that this relationship is linear. That means we can use the identity activation function instead of the sigmoid, and we need only one layer (the output layer).
If you don't see why a linear relationship means a single layer with the identity activation function, just bear with me. Let me give you a visual of our neural network and show you the function it represents before I explain.
Now, if we used the sigmoid instead of the identity, the above function would not be linear. But we know it is linear because I cheated and told you. (After you learn how NNs work, you will see that you can solve this without knowing it is linear, but for the sake of learning we should start small, with an easy linear example.) Also, if we added more layers that used the identity function, we could always simplify the network back down to the above equation, because a linear function of linear functions is still linear. Try it yourself: add a middle layer with 2 neurons, each with the identity activation function, then write the equation that corresponds to the network and simplify it.
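As a minimal sketch, assume the output neuron has three inputs x1, x2, x3, one for each of the weights W1, W2, W3 used later on this page (an assumption), so the network computes Output = W1·x1 + W2·x2 + W3·x3. The hypothetical helpers below also carry out the exercise numerically: a middle layer of 2 identity neurons followed by an identity output neuron gives exactly the same answer as a single layer whose weights are the "simplified" ones. The hidden weights U and output weights v are made up.

```python
# Assumed model: one output neuron, identity activation, three inputs.
def predict(weights, inputs):
    """Output = W1*x1 + W2*x2 + W3*x3 (the identity activation changes nothing)."""
    return sum(w * x for w, x in zip(weights, inputs))


def two_layer_predict(hidden_weights, output_weights, inputs):
    """A middle layer of identity neurons feeding one identity output neuron."""
    hidden = [predict(row, inputs) for row in hidden_weights]
    return predict(output_weights, hidden)


def collapse(hidden_weights, output_weights):
    """Simplify the two layers into one: W_i = sum over j of v_j * U[j][i]."""
    n_inputs = len(hidden_weights[0])
    return [
        sum(v * row[i] for v, row in zip(output_weights, hidden_weights))
        for i in range(n_inputs)
    ]


# Hypothetical weights for a 3 -> 2 -> 1 network.
U = [[0.5, -1.0, 2.0],   # middle neuron 1
     [1.5, 0.25, -3.0]]  # middle neuron 2
v = [2.0, -4.0]          # output neuron

x = [1.0, 2.0, 3.0]
print(two_layer_predict(U, v, x))  # → 37.0
print(collapse(U, v))              # → [-5.0, -3.0, 16.0]
print(predict(collapse(U, v), x))  # → 37.0, same as the two-layer network
```

Swapping the identity for the sigmoid in the middle layer would break `collapse`: there would no longer be a single set of three weights that reproduces the network.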
You could easily solve for the weights here using many different methods. After all, it is just a system of linear equations. But let's try to solve this as if we were in the general case, where we do NOT have a linear relationship.
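For comparison, here is one of those direct methods as a sketch: with a linear model and as many examples as weights, the weights are the solution of a square system, which plain Gaussian elimination can solve. The data below is made up (generated from the weights 1, 2, 3), not taken from our table, and `solve_linear_system` is a hypothetical helper name.

```python
def solve_linear_system(A, b):
    """Solve A w = b using Gaussian elimination with partial pivoting."""
    n = len(A)
    # Augmented matrix [A | b], copied so the caller's lists are untouched.
    M = [list(map(float, row)) + [float(rhs)] for row, rhs in zip(A, b)]
    for col in range(n):
        # Swap in the row with the largest pivot for numerical stability.
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        # Eliminate this column from every row below the pivot.
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    # Back substitution, from the last weight to the first.
    w = [0.0] * n
    for r in range(n - 1, -1, -1):
        known = sum(M[r][c] * w[c] for c in range(r + 1, n))
        w[r] = (M[r][n] - known) / M[r][r]
    return w


# Hypothetical data: each row is (x1, x2, x3), each output is 1*x1 + 2*x2 + 3*x3.
A = [[1, 0, 2],
     [0, 1, 1],
     [3, 1, 0]]
b = [7, 5, 5]
print(solve_linear_system(A, b))  # approximately [1.0, 2.0, 3.0]
```

In practice a library routine such as NumPy's `numpy.linalg.solve` (or `numpy.linalg.lstsq` when there are more examples than weights) would do this job. None of this carries over once the relationship stops being linear, which is why the rest of the page takes a different route.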
Measuring how wrong we are, using an error function
So, let's have some fun. How could we solve for the weights without treating this as a system of linear equations? Well, you have to start somewhere, so let's first just take a blind guess. I'm going to guess that W1 = 0.2, W2 = 1.4, and W3 = 30.
Now let's see how good our guess was and test it on one of the inputs.
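Testing the guess just means running the model forward. In this sketch the input x is hypothetical (not one of the rows of our data), and the model again assumes three inputs, one per weight.

```python
# Weights guessed in the text.
guess = [0.2, 1.4, 30.0]


def predict(weights, inputs):
    # Assumed model: identity activation, so Output = W1*x1 + W2*x2 + W3*x3.
    return sum(w * x for w, x in zip(weights, inputs))


x = [2.0, 5.0, 1.0]  # hypothetical input
print(round(predict(guess, x), 6))  # → 37.4  (0.2*2 + 1.4*5 + 30*1)
```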
Recall that the correct output (or expected output) should be 13.5; we are way off! Let's quantify how wrong we are using an error function. When the NN outputs the correct value, we want the error to be 0. Let's start with $E = \text{Output}_{\text{expected}} - \text{Output}_{\text{NN}}$ (e.g., $0 = 13.5 - 13.5$ if our weights were correct).
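We are almost there. Think practically: our error is something we want to minimize. Well, what's the minimum of our current error function? There is no minimum! As it is, the error can go to -∞ (just make the NN's output huge). We fix this by squaring the function, so it can never be negative. (We square instead of taking the absolute value because the square is differentiable everywhere, while the absolute value has a sharp corner at 0.) So, we have:

$$E = (\text{Output}_{\text{expected}} - \text{Output}_{\text{NN}})^2$$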
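A small sketch of the two error functions, using the expected output 13.5 from above and a few hypothetical predictions, shows the problem: the raw difference keeps falling as the prediction overshoots, while the squared version never drops below 0 and is 0 only for a correct prediction.

```python
def raw_error(expected, predicted):
    """First attempt: can be driven toward minus infinity."""
    return expected - predicted


def squared_error(expected, predicted):
    """Squared version: never negative, and 0 only when the prediction is exact."""
    return (expected - predicted) ** 2


expected = 13.5
for predicted in (13.5, 10.0, 1000.0):
    print(predicted, raw_error(expected, predicted), squared_error(expected, predicted))
# 13.5 0.0 0.0
# 10.0 3.5 12.25
# 1000.0 -986.5 973182.25
```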
Minimizing how wrong we are (minimizing our error)
Try changing one of the weights by a very small amount and watch what happens to the error. When you change the weight by a small amount (ΔW), you get a change in the error (ΔE). So, if we want to know how much a change in this weight affects the error, we can take the ratio of the two, ΔE/ΔW. As the change in the weight shrinks toward zero, this ratio becomes the partial derivative of the error function with respect to this weight, ∂E/∂W: how a small change in this weight changes the error.
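To make the ratio concrete, here is a sketch that nudges one weight by a tiny dw and compares ΔE/ΔW with the derivative worked out by the chain rule, ∂E/∂Wi = -2(Output_expected - Output_NN)·xi. It assumes the three-input linear model and the squared error from above; the example (x, y) and the step dw = 1e-6 are hypothetical choices.

```python
def predict(weights, inputs):
    # Assumed model: Output = W1*x1 + W2*x2 + W3*x3 (identity activation).
    return sum(w * x for w, x in zip(weights, inputs))


def error(weights, inputs, expected):
    # Squared error for a single example.
    return (expected - predict(weights, inputs)) ** 2


def numeric_derivative(weights, inputs, expected, i, dw=1e-6):
    """Nudge weight i by a tiny dw and measure the change in error: dE/dW."""
    nudged = list(weights)
    nudged[i] += dw
    return (error(nudged, inputs, expected) - error(weights, inputs, expected)) / dw


def analytic_derivative(weights, inputs, expected, i):
    """Chain rule: dE/dWi = -2 * (expected - output) * x_i."""
    return -2.0 * (expected - predict(weights, inputs)) * inputs[i]


weights = [0.2, 1.4, 30.0]      # the guess from the text
x, y = [2.0, 5.0, 1.0], 13.5    # hypothetical example
for i in range(3):
    print(i, numeric_derivative(weights, x, y, i), analytic_derivative(weights, x, y, i))
```

The two columns agree to several decimal places; the small leftover gap comes from dw not being infinitely small.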
How does this help us? Well, if you are good at calculus you may know where this is headed. But, if you are not, take a look at the graph below.
This is just an arbitrary function. Pick any point on the curve and put your finger on it. Now, say out loud whether the slope at this point is positive or negative. If the slope is positive, move your finger a little along the curve in the +X direction. If it is negative, move it in the -X direction. If the slope is 0... pick another point. Do this for a couple of points. Notice anything?
No matter where you start, as long as each move is small, you will always be moving your finger toward a higher value on the curve. So we can say that the sign of the derivative always points in the direction of ascent: if the derivative is negative, a small decrease in X increases Y; if the derivative is positive, a small increase in X increases Y.
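The finger experiment can be replayed in code. This sketch uses the arbitrary function f(x) = x³ - 3x (our own choice, not the function in the graph), takes a small step in the direction of the slope's sign at a few points, and checks that f went up every time.

```python
def f(x):
    # An arbitrary smooth function standing in for the graph.
    return x ** 3 - 3 * x


def df(x):
    # Its derivative, worked out by hand.
    return 3 * x ** 2 - 3


step = 1e-3
for x in (-2.0, -0.5, 0.5, 2.0):  # points where the slope is not 0
    slope = df(x)
    # Move a little in the direction the slope's sign points to.
    moved = x + step * (1 if slope > 0 else -1)
    print(x, slope, f(moved) > f(x))  # last column is True: the move went uphill
```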
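Back to our NN and ∂E/∂W. We want to decrease our error, so we want to move our weight in the direction of descent. We can do this by computing the derivative and negating it, then moving our weight a tiny step in the direction of this negated derivative:

$$W \leftarrow W - \eta \frac{\partial E}{\partial W}$$

where η is a small positive number that controls the size of the step (usually called the learning rate).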
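Here is the single-weight update as a hypothetical helper, tried on a toy one-weight model Output = w·x with squared error. The values of x, y, the starting w, and the learning rate 0.01 are made up for illustration.

```python
def update_weight(w, dE_dw, learning_rate=0.01):
    """One gradient descent step: move w a small step against dE/dw."""
    return w - learning_rate * dE_dw


# Hypothetical one-weight model: Output = w * x, error E = (y - w * x)^2.
x, y = 2.0, 13.5


def error(w):
    return (y - w * x) ** 2


def derivative(w):
    # Chain rule: dE/dw = 2 * (y - w*x) * (-x)
    return -2.0 * (y - w * x) * x


w = 0.2
w_new = update_weight(w, derivative(w))
print(w, error(w))          # before the step
print(w_new, error(w_new))  # after the step: the error is smaller
```

The derivative here is negative (the output is too small), so negating it pushes w up, toward the value 6.75 that makes the error 0.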
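This is just for one weight in our network. We can write the same update for all the weights at once in vector form, remembering that the gradient ∇E is just a vector containing all the derivatives ∂E/∂Wi:

$$\mathbf{W} \leftarrow \mathbf{W} - \eta \nabla E, \qquad \nabla E = \left[\frac{\partial E}{\partial W_1}, \frac{\partial E}{\partial W_2}, \dots, \frac{\partial E}{\partial W_n}\right]^T$$

where η is the small step size (the learning rate).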
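In code, the vector form is just the single-weight update applied to every weight at the same time. This sketch assumes the three-input linear model and squared error from above; the example (x, y) is hypothetical and the learning rate 0.001 is an arbitrary small value.

```python
def predict(weights, inputs):
    # Assumed model: identity activation, Output = sum of W_i * x_i.
    return sum(w * x for w, x in zip(weights, inputs))


def gradient(weights, inputs, expected):
    """Gradient vector [dE/dW1, dE/dW2, ...] of the squared error."""
    residual = expected - predict(weights, inputs)
    return [-2.0 * residual * x for x in inputs]


def gradient_step(weights, grad, learning_rate):
    """Vector update W <- W - eta * grad E, every weight at once."""
    return [w - learning_rate * g for w, g in zip(weights, grad)]


weights = [0.2, 1.4, 30.0]      # the guess from the text
x, y = [2.0, 5.0, 1.0], 13.5    # hypothetical example
new_weights = gradient_step(weights, gradient(weights, x, y), learning_rate=0.001)
print(new_weights)
print((y - predict(weights, x)) ** 2, (y - predict(new_weights, x)) ** 2)  # error drops
```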
Now the above update can be applied to the weights over and over until they converge to values that minimize the error (for our linear example, the correct weights). Finding this gradient is a bit of a pain if our network is larger, but it is always possible as long as we use differentiable activation functions. Repeatedly applying this update is called the gradient descent algorithm.
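Putting it all together, here is a minimal batch gradient descent sketch. It assumes the three-input linear model, averages the squared error over the whole data set (a common choice, assumed here), and learns from synthetic data generated from made-up "true" weights [1.5, -2.0, 0.5], starting from the blind guess 0.2, 1.4, 30 from earlier. The function names, learning rate, and step count are all hypothetical choices.

```python
def predict(weights, inputs):
    # Assumed model: identity activation, Output = sum of W_i * x_i.
    return sum(w * x for w, x in zip(weights, inputs))


def mean_squared_error(weights, data):
    return sum((y - predict(weights, x)) ** 2 for x, y in data) / len(data)


def gradient(weights, data):
    """dE/dW_i averaged over the data set (chain rule on the squared error)."""
    grad = [0.0] * len(weights)
    for x, y in data:
        residual = y - predict(weights, x)
        for i, xi in enumerate(x):
            grad[i] += -2.0 * residual * xi / len(data)
    return grad


def gradient_descent(weights, data, learning_rate=0.01, steps=3000):
    """Repeatedly apply W <- W - learning_rate * grad E."""
    weights = list(weights)
    for _ in range(steps):
        grad = gradient(weights, data)
        weights = [w - learning_rate * g for w, g in zip(weights, grad)]
    return weights


# Synthetic, hypothetical data generated from weights we pretend not to know.
true_weights = [1.5, -2.0, 0.5]
inputs = [[1, 2, 0], [0, 1, 3], [2, 0, 1], [1, 1, 1], [3, 1, 2], [0, 2, 1]]
data = [(x, predict(true_weights, x)) for x in inputs]

start = [0.2, 1.4, 30.0]  # the blind guess from the text
learned = gradient_descent(start, data)
print(mean_squared_error(start, data), mean_squared_error(learned, data))
print([round(w, 4) for w in learned])  # close to true_weights
```

Because the synthetic data really is linear, the learned weights end up very close to [1.5, -2.0, 0.5]. The learning rate matters: too large and the updates overshoot and diverge, too small and it takes many more steps to get there.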
Let me summarize the past 2 pages. We have data and some outputs. We know there is some relationship between this data and the outputs. Inspired by the human brain, we can construct an artificial neural network. We initialize this neural network with random weights, then start testing it on the data and outputs we have. As we test, we quantify how wrong our network is using the error function, then minimize this error function using gradient descent.
But hold on, we still haven't found what that last output was in our example! I just told you a whole bunch of magic without proving it actually works!
I would never do such a thing. Read on, and I will prove to you that this isn't magic, just calculus...