Transfer Learning

What is Transfer Learning?

Transfer learning is an approach in which knowledge learned on one machine learning task is reused on another. Practically speaking, a model pre-trained on one task is repurposed as the starting point for a new, related task, which can save substantial time and compute.

Creating complex models from scratch requires vast amounts of compute, data, and time. Transfer learning accelerates the process by leveraging commonalities between tasks (such as detecting edges in images) and applying those learned representations to a new task. Training time for a model can drop from weeks to hours, making machine learning more commercially viable for many businesses.

Transfer learning is especially popular in domains like computer vision and natural language processing (NLP), where large amounts of data are needed to produce accurate models.

An Example

Let's say we have only 1,000 images of an object we want to classify. If we take a pre-trained CNN such as ResNet-50, which was trained on millions of images, we can re-train it on our small dataset and build an accurate model with minimal effort. In neural networks, this is achieved by removing the final output layer of the existing model and replacing it with a new layer sized for the intended prediction, typically keeping the earlier layers frozen while the new layer is trained.
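A minimal sketch of this idea, assuming PyTorch and torchvision are available; the number of classes and the training loop are hypothetical placeholders, not part of the original example.

import torch
import torch.nn as nn
from torchvision import models

NUM_CLASSES = 5  # hypothetical number of classes in our small dataset

# Load a ResNet-50 pre-trained on ImageNet.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Freeze the pre-trained layers so only the new head is updated.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with one sized for our task.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# Only the new layer's parameters are passed to the optimizer.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Training loop sketch: `train_loader` would yield batches from the
# ~1,000-image dataset (assumed to exist).
# for images, labels in train_loader:
#     optimizer.zero_grad()
#     loss = criterion(model(images), labels)
#     loss.backward()
#     optimizer.step()

Because only the small replacement layer is trained, this fine-tuning step typically runs in minutes to hours rather than the weeks required to train the full network from scratch.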
