Why Batch Norm works
This article was written after listening to Coursera's batch norm lecture. Rather than summarizing the lecture as-is, it expresses the ideas in my own words.
In the last article, I wrote about how to implement Batch Norm and which learnable parameters it introduces. This article looks at why Batch Norm works at a lower level.
Covariate shift
Let's assume a scenario where you build an image classification model to classify cats, and the train dataset contains only pictures of black cats.
However, the test dataset contains colorful cat pictures.
Even if the model works well on the train set, it will likely not do as well on the test set: classification becomes difficult because the distribution of the inputs is different, even though the ground truth rule (is this a cat or not?) has not changed. In such a case, the model may need to be retrained from scratch. We call this kind of difference between distributions a "covariate shift".
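As a toy sketch of what that means numerically (made-up feature distributions, not real image data), the same labeling rule can sit on top of very different input distributions:

```python
import numpy as np

rng = np.random.default_rng(0)

# The labeling rule p(y|x) never changes; only the input distribution p(x)
# shifts between train and test. Feature values here are invented.
x_train = rng.normal(loc=[0.0, -2.0], scale=[1.0, 0.3], size=(1000, 2))  # "black cats"
x_test = rng.normal(loc=[0.0, 2.0], scale=[1.0, 1.5], size=(1000, 2))    # "colorful cats"

print("train feature means:", x_train.mean(axis=0))  # roughly [0, -2]
print("test  feature means:", x_test.mean(axis=0))   # roughly [0,  2]
# The gap between these input distributions, with the ground truth unchanged,
# is the covariate shift a model trained on x_train has to survive.
```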
Batch Norm in a Neural Net
Consider the output of the second hidden layer, \(a^{[2]}\). From the perspective of the 3rd, 4th, and 5th hidden layers, \(a^{[2]}\) is their input, and their parameters are updated so that they generalize well to it.
However, \(a^{[2]}\) itself depends on the parameters of the first and second hidden layers, which are also being updated. Since the distribution of \(a^{[2]}\) changes after every update, from the perspective of the later layers the network is asked to generalize to a new input distribution every time: a covariate shift occurs inside the network. Batch Norm therefore unifies the mean and variance of \(a^{[2]}\), enabling faster convergence. Note that after Batch Norm the activations do not have to stay at \(N(0,1)\): through the learnable parameters \(\alpha\) and \(\beta\), the mean and variance are themselves learned to values that generalize well.
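To make this concrete, here is a minimal NumPy sketch of the normalize-then-rescale step for one layer (my own illustration, not the lecture's code; the scale parameter is written \(\alpha\) to match this article's notation, though most references call it \(\gamma\)):

```python
import numpy as np

def batch_norm_forward(z, alpha, beta, eps=1e-5):
    """Normalize a mini-batch of pre-activations, then rescale and shift.

    z:     (batch_size, n_units) pre-activations of one layer
    alpha: (n_units,) learnable scale (gamma in most references)
    beta:  (n_units,) learnable shift
    """
    mu = z.mean(axis=0)                    # per-unit mean over the mini-batch
    var = z.var(axis=0)                    # per-unit variance over the mini-batch
    z_hat = (z - mu) / np.sqrt(var + eps)  # now mean 0, variance 1
    return alpha * z_hat + beta            # learned mean beta, learned std alpha

# Toy check: after the affine step the statistics follow alpha and beta,
# not N(0, 1).
z = np.random.randn(64, 4) * 3.0 + 7.0
out = batch_norm_forward(z, alpha=np.full(4, 2.0), beta=np.full(4, 0.5))
print(out.mean(axis=0))  # ~0.5 per unit
print(out.std(axis=0))   # ~2.0 per unit
```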
What Batch Norm Essentially Does
Batch Norm essentially reduces how much the distribution of \(a^{[2]}\) can be pushed around by updates to the earlier weights and biases. By specifying \(\alpha^{[2]}\) and \(\beta^{[2]}\), we limit the range that \(a^{[2]}\) can take, and by limiting that range we can train the neural net more reliably. In other words, Batch Norm reduces the amount by which the distribution shifts, so the earlier hidden layers force less re-learning onto the later ones. \(w^{[3]}\) and \(b^{[3]}\) can learn on \(a^{[2]}\) somewhat independently of what the other layers are doing; each layer becomes more independent from the perspective of the others. As a result, the network converges faster. Because convergence often fails without Batch Norm, you might think it is a knob that improves the model's final performance, but that is not quite right: its main effect is to stabilize and speed up training.
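To see this effect, here is a toy NumPy sketch (my own, with made-up layer sizes): even a large update to the first layer's weights barely moves the statistics of the batch-normalized \(a^{[2]}\), while the raw pre-activations shift a lot:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(256, 8))

alpha2 = np.ones(4)   # learnable scale for layer 2 (gamma in most texts)
beta2 = np.zeros(4)   # learnable shift for layer 2

def layer2_output(w1, w2):
    a1 = np.maximum(0, x @ w1)          # hidden layer 1 (ReLU)
    z2 = a1 @ w2                        # raw pre-activations of layer 2
    z2_hat = (z2 - z2.mean(axis=0)) / np.sqrt(z2.var(axis=0) + 1e-5)
    a2 = alpha2 * z2_hat + beta2        # batch-normalized a^{[2]}
    return z2, a2

w1, w2 = rng.normal(size=(8, 16)), rng.normal(size=(16, 4))
z2_a, a2_a = layer2_output(w1, w2)
# Simulate a drastic earlier-layer update: w1 changes by a factor of 5.
z2_b, a2_b = layer2_output(w1 * 5.0, w2)

print("raw z2 std before/after update:", z2_a.std(), z2_b.std())  # shifts a lot
print("BN  a2 std before/after update:", a2_a.std(), a2_b.std())  # pinned by alpha2
```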
Unexpected regularization
When Batch Norm is applied in a neural net, an unintended regularization effect appears. Batch Norm computes the mean and variance on each mini-batch rather than on the entire train dataset, so these statistics are noisy estimates of the full-dataset mean and variance. Because a little of this noise is injected into each hidden layer's output through the noisy statistics, a regularization effect occurs. It is similar to how dropout regularizes by adding noise to the hidden layer outputs each time. However, since this is literally an unexpected side effect, Batch Norm should not be relied on as a regularization method.
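As a rough illustration (a made-up NumPy sketch, treating one unit's pre-activations as a fixed population), the mini-batch statistics jitter around the full-dataset statistics, and that jitter is the noise source:

```python
import numpy as np

rng = np.random.default_rng(2)

# Pretend these are one unit's pre-activations over the whole train set.
z_all = rng.normal(loc=1.0, scale=2.0, size=50_000)
print("full-dataset mean/var:", z_all.mean(), z_all.var())

# Each mini-batch normalizes with its own noisy estimates of those stats.
for _ in range(3):
    batch = rng.choice(z_all, size=64, replace=False)
    print("mini-batch   mean/var:", batch.mean(), batch.var())
# The batch-to-batch jitter in (mean, var) is the noise that gives
# Batch Norm its mild, dropout-like regularization side effect.
```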