You would think transfer learning is simple. Why wouldn't it be? You just initialize weights from a proven neural network (NN) that solved a different but similar enough problem. Simple enough, right?
- Yes, if you just want to get something working.
- No, if you want to understand what is really going on once you dig deeper, or you want to build the best possible transfer-learned model.
Let me elaborate.
While working on Thursday, our team came up with a fun game idea called Doodle Race. What is Doodle Race? You draw and the AI guesses! Whoever gets the AI to guess correctly the fastest wins!
We decided to use the quickdraw dataset, a collection of 50 million drawings across 345 doodle categories, to train a Convolutional Neural Network (CNN). CNNs are used to understand and interpret images and other visual data, and even time series and speech.
The small image size of the quickdraw dataset made it the best choice for experimenting while keeping computational costs low. Our experiments began by comparing two options: building a CNN from scratch, or applying transfer learning to a pretrained model such as efficientnet-b0.
Along the way, we also started experimenting to find the best way to do transfer learning. An upcoming blog post will cover those experiments. This one focuses on the theory behind transfer learning, what that theory says is the best way to build a transfer-learned model, and whether practice actually bears the theory out.
We started by building a CNN classifier that distinguishes 10 doodle categories, using transfer learning on an efficientnet-b0 base model.
efficientnet-b0 was trained on the much larger ImageNet dataset, which has 1000 categories of real-world RGB images at 224x224 resolution. Our experimental dataset, by contrast, has 10 categories of grayscale images at 28x28 resolution.
efficientnet-b0 was clearly overkill for our scenario, so we had our doubts about whether it would hurt performance on this relatively simple task. Sit tight: the answer is coming in our upcoming blog post.
A sneak peek at that post: to apply transfer learning to this problem, we modified the input layer of efficientnet-b0 to accept grayscale images, and the last layer to output 10 categories instead of 1000. This means only the modified layers were randomly initialized, while the rest of the network was initialized with weights from the pretrained model.
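The layer swap can be sketched without any framework. In the sketch below, a model is a hypothetical dict of named parameters. The names mimic torchvision's efficientnet_b0 naming, but the miniature shapes are made up. Every pretrained tensor whose name and shape match is copied over, and everything else keeps its random initialization. This illustrates the idea under those assumptions and is not our actual training code. In PyTorch, a common equivalent is to drop shape-mismatched entries from the pretrained state dict before calling `load_state_dict(..., strict=False)`.

```python
import math
import random

# Miniature, hypothetical layer lists using torchvision-style parameter names.
# The real efficientnet-b0 has many more layers; only the pattern matters here.
PRETRAINED_SHAPES = {
    "features.0.0.weight": (8, 3, 3, 3),   # stem conv: expects 3 (RGB) input channels
    "features.1.0.weight": (16, 8, 3, 3),  # stand-in for the rest of the backbone
    "classifier.1.weight": (1000, 16),     # ImageNet head: 1000 classes
    "classifier.1.bias": (1000,),
}
NEW_SHAPES = {
    "features.0.0.weight": (8, 1, 3, 3),   # modified: 1 (grayscale) input channel
    "features.1.0.weight": (16, 8, 3, 3),  # unchanged
    "classifier.1.weight": (10, 16),       # modified: 10 doodle classes
    "classifier.1.bias": (10,),
}


def random_init(shapes, seed):
    """Give every parameter small Gaussian values (a stand-in for a real initializer)."""
    rng = random.Random(seed)
    return {
        name: {"shape": shape, "values": [rng.gauss(0.0, 0.02) for _ in range(math.prod(shape))]}
        for name, shape in shapes.items()
    }


def transfer_weights(pretrained, model):
    """Copy each pretrained tensor whose name and shape match; keep random init otherwise."""
    copied, kept_random = [], []
    for name, param in model.items():
        source = pretrained.get(name)
        if source is not None and source["shape"] == param["shape"]:
            param["values"] = list(source["values"])
            copied.append(name)
        else:
            kept_random.append(name)
    return copied, kept_random


pretrained = random_init(PRETRAINED_SHAPES, seed=1)  # pretend these are ImageNet weights
model = random_init(NEW_SHAPES, seed=2)
copied, kept_random = transfer_weights(pretrained, model)
print("ported:", copied)       # ported: ['features.1.0.weight']
print("random:", kept_random)  # random: ['features.0.0.weight', 'classifier.1.weight', 'classifier.1.bias']
```

Only the unchanged backbone layer is ported. The grayscale stem and the 10-class head stay randomly initialized, exactly as described above.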
Fun fact:
Did you know that transfer learning with YOLO segmentation models works even if you keep the output size unchanged and your data uses only a small percentage of the output categories? The YOLO makers claim that if you train their model with more output neurons than there are distinct classes in your data, the network should automatically learn to output zeros for the unused neurons. You can read the comment here. It is really wild how much flexibility NNs give us.
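Here is a toy illustration of that mechanism. It assumes each output neuron gets an independent sigmoid score trained with binary cross-entropy. That is our understanding of how recent YOLO versions score classes, so treat it as an assumption. Five outputs are trained on hypothetical data that only ever uses classes 0 and 1. The unused neurons always see a target of 0, so gradient descent pushes their scores toward zero for every input. This is plain logistic regression, not YOLO itself.

```python
import math

NUM_OUTPUTS = 5  # the head keeps 5 output neurons...
# ...but this hypothetical toy dataset only ever uses classes 0 and 1.
DATA = [([1.0, 0.0], 0), ([0.0, 1.0], 1)]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(weights, biases, x):
    """Independent sigmoid score for each output neuron."""
    return [sigmoid(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(weights, biases)]


def train(steps=1000, lr=0.5):
    """Full-batch gradient descent on per-output binary cross-entropy."""
    weights = [[0.0] * 2 for _ in range(NUM_OUTPUTS)]
    biases = [0.0] * NUM_OUTPUTS
    for _ in range(steps):
        grad_w = [[0.0] * 2 for _ in range(NUM_OUTPUTS)]
        grad_b = [0.0] * NUM_OUTPUTS
        for x, label in DATA:
            probs = predict(weights, biases, x)
            for k in range(NUM_OUTPUTS):
                target = 1.0 if k == label else 0.0  # unused outputs always get target 0
                err = (probs[k] - target) / len(DATA)  # d(BCE)/d(logit), batch-averaged
                for i, xi in enumerate(x):
                    grad_w[k][i] += err * xi
                grad_b[k] += err
        for k in range(NUM_OUTPUTS):
            for i in range(2):
                weights[k][i] -= lr * grad_w[k][i]
            biases[k] -= lr * grad_b[k]
    return weights, biases


weights, biases = train()
for x, label in DATA:
    probs = predict(weights, biases, x)
    print(label, [round(p, 3) for p in probs])
```

The scores for outputs 2 to 4 end up close to zero on every input, while each used output scores high on its own class.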
Without going too deep into the implementation details, let's jump right into the theory.
Transfer learning is a very useful tool in AI. A model pretrained on a different but similar enough task is used as the starting point for training a model on the current task. To solve a specific dataset, you bring in the architecture and weights of an NN that has been proven to work well on a similar task. This way, what that model learned from a different but similar enough dataset can be adapted to your current task or dataset. This is especially useful when your dataset is small, or when your task is too complex for designing a new NN architecture and training it from scratch.
Broadly, there are two ways to go about transfer learning:
- Freeze the layers with weights from the pretrained model for the first few epochs and only train the new or the last few layers during this time.
- Don’t freeze any layers; train every parameter of the entire model right from the start.
The first way is recommended because it stops the very large gradients coming from the randomly initialized last layers from wiping out the knowledge ported from the pretrained model (i.e., its weights and biases). This matters most at the very beginning of training. Those layers start with random weights, so the loss is high during the first few steps and very large gradients are backpropagated through the entire network. These gradients shrink gradually as the network learns.
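A toy sketch of why the early gradients are so large follows. It uses a hypothetical two-weight model in which a "pretrained" backbone weight (1.5) feeds a randomly initialized head (-1.0), trained on the made-up task y = 3x. The sketch measures the gradient reaching the backbone before and after a frozen warm-up in which only the head is trained.

```python
# Toy two-layer model (hypothetical): feature = w_backbone * x, prediction = w_head * feature.
XS = [-1.0, -0.5, 0.5, 1.0]
YS = [3.0 * x for x in XS]  # hypothetical target task: y = 3x


def grads(w_backbone, w_head):
    """Mean-squared-error gradients w.r.t. (w_backbone, w_head)."""
    n = len(XS)
    g_backbone = g_head = 0.0
    for x, y in zip(XS, YS):
        err = w_head * w_backbone * x - y
        g_backbone += 2 * err * w_head * x / n
        g_head += 2 * err * w_backbone * x / n
    return g_backbone, g_head


w_backbone = 1.5  # "pretrained" weight, assumed to be already useful
w_head = -1.0     # randomly initialized new head

g_before, _ = grads(w_backbone, w_head)
for _ in range(200):  # warm up the head only; the backbone stays frozen
    _, g_head = grads(w_backbone, w_head)
    w_head -= 0.01 * g_head
g_after, _ = grads(w_backbone, w_head)

print(f"backbone gradient with a random head: {g_before:+.4f}")
print(f"backbone gradient after head warm-up: {g_after:+.4f}")
```

With these numbers, the backbone gradient starts at 5.625. After the warm-up it should be more than a hundred times smaller, which is the effect freezing relies on.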
Also, the last layers are the ones most often modified during transfer learning. They produce the final outputs and learn knowledge specific to the dataset's distribution, whereas the initial layers learn more general knowledge about the use case (here, images). For example, in a dogs-vs-cats CNN, the final layers are responsible for telling cats and dogs apart. The initial layers, in contrast, learn generic image features and perform feature extraction, i.e., they reduce an image from its original size to just a few numbers that the final layers can classify.
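To make "just a few numbers" concrete, the sketch below traces the spatial size of an image through the stride-2 layers of efficientnet-b0. It assumes torchvision-style padding of kernel_size // 2, five stride-2 convolutions (the stem plus the first block of four stages), and a 1280-channel final feature map. These details reflect our understanding of torchvision's implementation, so check them against your version.

```python
def conv_out(size, kernel, stride):
    """Output side length of a convolution padded with kernel // 2 on each side."""
    padding = kernel // 2
    return (size + 2 * padding - kernel) // stride + 1


# Assumptions: kernel sizes of efficientnet-b0's stride-2 convolutions (stem, then the
# first block of stages 2, 3, 4 and 6) and the channel count of its final feature map.
STRIDE2_KERNELS = [3, 3, 5, 3, 5]
FEATURE_CHANNELS = 1280


def trace(side):
    """Side lengths of the input and after each stride-2 layer."""
    sides = [side]
    for kernel in STRIDE2_KERNELS:
        sides.append(conv_out(sides[-1], kernel, stride=2))
    return sides


for side in (224, 28):
    sides = trace(side)
    print(" -> ".join(str(s) for s in sides),
          f"| final map {sides[-1]}x{sides[-1]}x{FEATURE_CHANNELS}, pooled to {FEATURE_CHANNELS} numbers")
```

A 224x224 input ends as a 7x7 map, while a 28x28 doodle is already down to 1x1. Global average pooling then turns either map into 1280 numbers for the classifier.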
Training a neural network is all about traveling through a very high-dimensional space, whose dimensionality equals the number of trainable parameters, to arrive at a position where the network is useful to humans. Freezing the initial layers for a few steps or epochs, and updating only the final, randomly initialized layers during this time, lets the network settle near the region where the pretrained model resided. We expect this region to be good because our current use case and the pretrained model's use case are similar.
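To put numbers on "dimensionality of the space", the arithmetic below counts the trainable parameters of our modified efficientnet-b0 while frozen and after unfreezing. It assumes torchvision's reported total of 5,288,548 parameters, a bias-free 3x3 stem convolution with 3 input and 32 output channels, and a Linear(1280, 1000) classifier. Adjust these values if your implementation differs.

```python
def conv_params(in_ch, out_ch, kernel, bias=False):
    """Weights (+ optional bias) of a square 2D convolution."""
    return in_ch * out_ch * kernel * kernel + (out_ch if bias else 0)


def linear_params(in_features, out_features):
    """Weights + bias of a fully connected layer."""
    return in_features * out_features + out_features


# Assumption: torchvision's efficientnet_b0 reports 5,288,548 parameters in total,
# with a bias-free 3x3 stem conv (3 -> 32 channels) and a Linear(1280, 1000) classifier.
PRETRAINED_TOTAL = 5_288_548

old_stem, new_stem = conv_params(3, 32, 3), conv_params(1, 32, 3)          # 864 -> 288
old_head, new_head = linear_params(1280, 1000), linear_params(1280, 10)    # 1,281,000 -> 12,810

full_dim = PRETRAINED_TOTAL - old_stem - old_head + new_stem + new_head
frozen_dim = new_stem + new_head  # only the modified layers train while the rest is frozen

print(f"trainable dimensions while frozen: {frozen_dim:,}")    # 13,098
print(f"trainable dimensions after unfreezing: {full_dim:,}")  # 4,019,782
```

Under these assumptions, the frozen phase searches a space about 300 times smaller than the full model.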
Once the partially frozen model reaches this good region after the first few steps or epochs, we unfreeze all layers and train all parameters together. The model can then move toward the best state for the current use case from an already good starting point.
All this makes sense in theory, which is why it is the recommended way to finetune a neural network. Freezing should never hurt model performance, right? After unfreezing all layers, the weights cannot be in a worse position than if the new layers had simply been randomly initialized.
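The two ways can be compared on a toy problem. The sketch below uses a hypothetical two-weight model: a "pretrained" backbone weight of 1.5 feeds a random head of -1.0, trained on the made-up task y = 3x. Setting `freeze_epochs=0` gives the no-freezing way. Setting `freeze_epochs=2` trains only the head for two epochs and then unfreezes everything. In PyTorch, the freeze itself would typically be done with `param.requires_grad_(False)`, whereas this plain-Python loop simply skips the backbone update.

```python
# Toy two-layer model (hypothetical): feature = w_backbone * x, prediction = w_head * feature.
XS = [-1.0, -0.5, 0.5, 1.0]
YS = [3.0 * x for x in XS]   # hypothetical target task: y = 3x
PRETRAINED_BACKBONE = 1.5    # hypothetical "ported" weight
HEAD_INIT = -1.0             # hypothetical random init of the new head


def grads(w_backbone, w_head):
    """Mean-squared-error gradients w.r.t. (w_backbone, w_head)."""
    n = len(XS)
    g_backbone = g_head = 0.0
    for x, y in zip(XS, YS):
        err = w_head * w_backbone * x - y
        g_backbone += 2 * err * w_head * x / n
        g_head += 2 * err * w_backbone * x / n
    return g_backbone, g_head


def train(freeze_epochs, epochs=6, steps_per_epoch=100, lr=0.01):
    """freeze_epochs=0: no freezing; freeze_epochs>0: head-only warm-up, then unfreeze."""
    w_backbone, w_head = PRETRAINED_BACKBONE, HEAD_INIT
    for epoch in range(epochs):
        backbone_trainable = epoch >= freeze_epochs  # unfreeze after the warm-up epochs
        for _ in range(steps_per_epoch):
            g_backbone, g_head = grads(w_backbone, w_head)
            w_head -= lr * g_head
            if backbone_trainable:
                w_backbone -= lr * g_backbone
    return w_backbone, w_head


for freeze_epochs in (2, 0):
    w_backbone, w_head = train(freeze_epochs)
    print(f"freeze_epochs={freeze_epochs}: backbone={w_backbone:.3f}, "
          f"head={w_head:.3f}, product={w_backbone * w_head:.3f}")
```

With these settings, both runs should fit y = 3x. The no-freezing run, however, moves the backbone weight much further from its pretrained value, which is the "evaporating knowledge" effect in miniature.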
Say we freeze the unmodified layers for two epochs and then unfreeze all model parameters at the start of the third epoch. The “epoch-2-freezing-way” state is then equivalent to the “epoch-0-no-freezing-way” state, with just one exception: the “epoch-2-freezing-way” has trained weights in the modified layers, while the “epoch-0-no-freezing-way” has random weights there. The unmodified layer weights are exactly the same in both.
So, we can say that any training is at least as good as random initialization, because the trained weights are just one particular, very unlikely, outcome of random initialization. In other words, if you randomly initialized the weights enough times, you would eventually land arbitrarily close to the trained values.
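The equivalence argument can be checked on a toy model. This sketch uses hypothetical data and weights, with a three-weight linear "backbone" and a one-weight "head". It takes the state right before unfreezing, after two frozen epochs, and compares it with the untrained state. The backbone weights come out identical, and only the head differs.

```python
PRETRAINED_BACKBONE = [0.8, -0.4, 1.1]  # hypothetical ported (unmodified-layer) weights
HEAD_INIT = 0.3                         # hypothetical random init of the modified layer
DATA = [([1.0, 0.0, 0.5], 2.0), ([0.0, 1.0, -1.0], -1.0), ([0.5, 0.5, 0.5], 1.0)]


def train_epoch(backbone, head, freeze_backbone, lr=0.05):
    """One SGD pass over DATA for prediction = head * (backbone . x) with squared error."""
    backbone = list(backbone)
    for x, y in DATA:
        feature = sum(w * xi for w, xi in zip(backbone, x))
        err = head * feature - y
        grad_head = err * feature                      # d(0.5 * err**2) / d(head)
        grad_backbone = [err * head * xi for xi in x]  # d(0.5 * err**2) / d(backbone)
        head -= lr * grad_head
        if not freeze_backbone:
            backbone = [w - lr * g for w, g in zip(backbone, grad_backbone)]
    return backbone, head


# "epoch-0-no-freezing-way": the state before any training at all.
epoch0_backbone, epoch0_head = list(PRETRAINED_BACKBONE), HEAD_INIT

# "epoch-2-freezing-way": two frozen epochs, stopping right before unfreezing.
epoch2_backbone, epoch2_head = list(PRETRAINED_BACKBONE), HEAD_INIT
for _ in range(2):
    epoch2_backbone, epoch2_head = train_epoch(epoch2_backbone, epoch2_head, freeze_backbone=True)

print("backbone identical:", epoch2_backbone == epoch0_backbone)  # True
print(f"head: random {epoch0_head:.3f} vs. trained {epoch2_head:.3f}")
```

From this point on, both runs apply the same update rule to every parameter. The only difference is where the head starts.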
But do real-world experiments support the conclusion that freezing is at least as good as not freezing, even if it takes more epochs? What do you think?
If you HAVE an answer, you are wrong. Let me explain.
If AI were anything like physics, you would HAVE one answer, at least in most cases, unless you ventured into the quantum world, where all your answers are just probabilities. But AI is not like physics: we do not fully understand it. A lot depends on the dataset type, its size, and numerous other factors. In a few use cases, the theoretically recommended way of transfer learning has been shown to worsen performance rather than improve it (this and this).
So does that mean the recommended way of transfer learning, and the theory we have discussed so far, are wrong?
No. We will soon publish another blog post on our real experiments, showing the merits and demerits of a wide variety of transfer learning methods. We will also describe the process you should follow to solve a problem with transfer learning, taking practical performance needs and available computing resources into account.