Neural Networks
Neural networks are among the most widely used machine learning algorithms because they work remarkably well on many different tasks. They typically learn by looking at lots of examples, each labeled with the correct answer (for example, a loan application labeled with whether or not the applicant should be awarded a loan). One of the challenges with neural networks is that we don't always understand what they learn. We just know that, at the end of the day, we can give the network a new example that doesn't have a label and it can predict with fairly high accuracy what the label should be.
Neural networks (as the name implies) are inspired by the way neurons in our brains work. Paths through neurons are strengthened when they fire often and weakened when they rarely fire. Over time, useful patterns get reinforced (this is how we learn) and noise fades away. The connection point between two neurons is called a synapse.
To get a feel for how this works, try out the following activity.
Example: Do you wanna be a neuron?
This activity is meant to be done with a group of people. Suppose you're given the following training examples, together with their labels:
Homework Completed | Good Exam Scores | Bias (always 1) | Passes the Class (Label)
1 | 0 | 1 | 1
1 | 1 | 1 | 1
0 | 0 | 1 | 0
0 | 1 | 1 | 1
This activity works best if participants arrange themselves in 4 rows as follows (these 4 rows of people have nothing to do with the 4 rows in the table):
- First row: 3 people ("input neurons") - each person represents a unique feature (e.g., person 1 always represents "Homework Completed")
- Second row: 3 people ("synapses") - each person is partnered with a unique input neuron person from the first row. Each synapse has a strength, which starts at a random value and is updated over time; each synapse must keep track of its current strength for the duration of the exercise. For the purposes of our example, use the following starting strengths:
- Synapse strength for "Homework Completed" = 1
- Synapse strength for "Good Exam Scores" = 1
- Synapse strength for "Bias" = -1.1
- Third row: 1 person ("output neuron")
- Fourth row: 1 person ("error computer")
For each example (i.e., row) in the table above, repeat the following steps:
- Input neurons assume their respective values (e.g., for the first row, the person representing "Homework Completed" assumes the value 1, the person representing "Good Exam Scores" assumes the value 0, and the person representing "Bias" assumes the value 1). The error computer assumes the value of the "Label".
- Each input neuron communicates its value to its partner synapse.
- Each synapse calculates its signal as follows: signal = value x strength, where "value" is the value received from its input neuron.
- Each synapse communicates its signal to the output neuron.
- The output neuron sums up the three signals. If the sum is greater than 1, the output neuron yells "FIRE" and assumes the value 1. Otherwise, the output neuron assumes the value 0.
- The output neuron communicates its value to the error computer.
- The error computer compares the output neuron's value with the "Label": if they match, yell "BINGO"; otherwise, yell "WRONG". The error computer then calculates the error as follows: error = 0.1 * ("Label" - output value).
- The error computer communicates the error to the output neuron, which in turn communicates it to each of the synapses (this is called "backpropagating the error").
- Each synapse updates its strength as follows: strength = error * value + old strength (consult with the input neurons if you forgot the "value").
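To check your arithmetic, here is a minimal Python sketch of one pass through the steps above for the first row of the table, using the starting strengths given earlier. The function name `train_on_row` and its parameter names are made up for this illustration; the threshold of 1 and the 0.1 scaling factor come straight from the steps.

```python
def train_on_row(strengths, values, label, threshold=1.0, rate=0.1):
    """Run one row of the activity and return (output, error, new_strengths)."""
    # Each synapse computes signal = value x strength.
    signals = [v * s for v, s in zip(values, strengths)]
    # The output neuron sums the signals and fires (1) only if the sum exceeds the threshold.
    output = 1 if sum(signals) > threshold else 0
    # The error computer compares the output with the label.
    error = rate * (label - output)
    # Each synapse updates: strength = error * value + old strength.
    new_strengths = [error * v + s for v, s in zip(values, strengths)]
    return output, error, new_strengths


# Starting strengths for Homework Completed, Good Exam Scores, and Bias.
strengths = [1.0, 1.0, -1.1]
# First row of the table: homework done, no good exam scores, bias 1; label 1.
output, error, strengths = train_on_row(strengths, [1, 0, 1], label=1)
print(output, error, [round(s, 2) for s in strengths])  # 0 0.1 [1.1, 1.0, -1.0]
```

The sum of the signals is 1 + 0 - 1.1 = -0.1, which is not greater than 1, so the output neuron stays at 0 and the row is "WRONG". Only the synapses whose input value was 1 (Homework Completed and Bias) change strength.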
Going once through the whole table above is called an epoch. Keep completing epochs until every row in an epoch gets a "BINGO". Synapses keep their updated strengths across epochs.
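The whole activity, repeated epoch by epoch until every row is a "BINGO", can be simulated with a short Python sketch. The rows are copied from the table above; the names `TABLE` and `run_activity` and the `max_epochs` safety cap are illustrative additions of this sketch, not part of the activity.

```python
# Rows of the table: (Homework Completed, Good Exam Scores, Bias) and the label.
TABLE = [
    ((1, 0, 1), 1),
    ((1, 1, 1), 1),
    ((0, 0, 1), 0),
    ((0, 1, 1), 1),
]


def run_activity(strengths, threshold=1.0, rate=0.1, max_epochs=100):
    """Repeat epochs until every row is a BINGO; return per-epoch accuracy and final strengths."""
    strengths = list(strengths)
    accuracies = []
    for _ in range(max_epochs):
        bingos = 0
        for values, label in TABLE:
            total = sum(v * s for v, s in zip(values, strengths))
            output = 1 if total > threshold else 0
            if output == label:
                bingos += 1  # "BINGO"
            error = rate * (label - output)  # zero when the output was right
            # Synapse strengths persist from row to row and from epoch to epoch.
            strengths = [error * v + s for v, s in zip(values, strengths)]
        accuracies.append(bingos / len(TABLE))
        if bingos == len(TABLE):
            break  # every row in this epoch was a BINGO
    return accuracies, strengths


accuracies, final = run_activity([1.0, 1.0, -1.1])
for epoch, acc in enumerate(accuracies, start=1):
    print(f"epoch {epoch}: accuracy {acc:.0%}")
print("final strengths:", [round(s, 2) for s in final])
```

The `max_epochs` cap matters only if no set of strengths can get every row right; without it, the loop would never end. Trying other starting strengths or a different threshold is an easy way to see how the number of epochs changes.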
Follow-up Questions
- What was the accuracy of the network on the first epoch? The last epoch?
- If the network starts with random values, how is it able to improve accuracy on each epoch?
- What part of the network persists across rows and epochs? Is there any other part of the network that persists? Where, then, is the learning taking place?
- By looking at the synapse strengths, can you summarize or explain what the network is learning? (In this example that may be possible, but it becomes much more difficult once multiple layers are added.)
- Looking at the table, what pattern do you see that the network is learning? How many unique rows are possible with 2 input features?
- What if we added another input feature (e.g., "Caught Cheating")? How would this change the layout of the network (called its topology)? How many unique rows are possible with 3 input features? Do you think more features would require more or less training? Why?
- Imagine a large dataset with many features. Some of the possible unique rows are missing (almost always the case in the real world). If you wanted to predict the Label for one of these missing rows of features, how could you use the network to predict the Label?
- What other problems (like the problem represented in the table above) might you want a neural network to learn?