Self-Supervised Learning: An Introduction

One topic that has fascinated me, and has quickly become the focus of my current research, is self-supervised learning (SSL). I’m writing this post to serve as an introduction to this exciting subfield of machine learning, and as a stepping stone towards more discussion later on; I’m planning to cover a few papers in SSL, so if you’re interested in that, we’ll need to go through some background first! It’s also important to note that we’ll be talking mainly about how SSL is implemented in deep learning (DL), although many of the concepts should apply to other machine learning methodologies.

I ran out of labels

Supervised learning (SL) underlies many of our modern-day deep learning applications and is often the first learning paradigm we’re introduced to in DL. In SL, we are given a collection of labelled examples $\{(x^{(1)}, y^{(1)}), \dots, (x^{(m)}, y^{(m)})\}$, and our task is to predict the ground truth $y^{(i)}$ given the input $x^{(i)}$. The “supervision” comes from the fact that during the training phase, our model is guided, or “supervised,” by the correct answers present in the labelled data. SL has led to tremendous success across several domains including, but not limited to, image classification, speech recognition, medical devices, and gaming. However, we run into a huge problem when relying on SL: it requires large amounts of labelled data, which may not always be readily available, so model performance will always be capped by the number of labels. Further, acquiring labelled data can be incredibly difficult and expensive; medical imaging, for example, requires medical experts to carefully annotate each image by hand. As a general rule, the more expertise you require to annotate the data, the more expensive the data collection procedure will be. As models improve year after year and begin to reach the “average human” level of performance on certain tasks (e.g., written fluency with ChatGPT), we’ll require “above-average human” data, namely expert data, to fuel our models if we want to sustain this continual increase in performance.
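To make this concrete, here’s a minimal supervised training loop in PyTorch. The toy data, dimensions, and model are hypothetical stand-ins for illustration; a real task would load an actual labelled dataset:

```python
import torch
import torch.nn as nn

# Minimal supervised-learning sketch: learn to predict y from x on
# labelled pairs (toy random data standing in for a real dataset).
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))  # 3 classes
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(64, 10)           # inputs x^(i)
y = torch.randint(0, 3, (64,))    # ground-truth labels y^(i)

for _ in range(100):
    # The "supervision": the loss compares predictions to the labels.
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

The key point is that every gradient step depends on the labels $y^{(i)}$, which is exactly why the supply of labels caps performance.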

Given this significant bottleneck, we can’t depend completely on SL for long; we must find alternative ways of training our models on limited amounts of labelled data. If we think of our own psychology, the idea of learning with minimal labelled data intuitively makes sense. If someone shows us a picture of a car, even as a child, we might only need a couple more examples before our internal model generalizes, at which point we are able to accurately classify cars in real life. But how? One hypothesis is that there are other learning mechanisms operating besides a person telling us (i.e., supervising us on) whether something is a car or not. We can identify features of the car naturally, we can rotate the image of a car in our head (i.e., we recognize invariance under rotation), and we can update our model through experience by seeing more cars. This idea is captured well by Yann LeCun (Meta’s Chief AI Scientist):

As babies, we learn how the world works largely by observation. We form generalized predictive models about objects in the world by learning concepts such as object permanence and gravity. Later in life, we observe the world, act on it, observe again, and build hypotheses to explain how our actions change our environment by trial and error.

— Yann LeCun (Meta Chief AI Scientist)

Do you like cake?

One way to visualize how important SSL really is to training our models is the cake analogy (originally formulated by Yann LeCun). In this analogy, we can think of SSL as the cake (i.e., the base layer), SL as the icing, and reinforcement learning (RL) as the cherry on top. The relative sizes of these components reflect the amount of information each learning method extracts from the training data: SSL extracts the most information, SL the second most, and RL the least. LeCun also gives a rough estimate of the size of this gap, stating that SSL yields millions of bits per sample, SL yields 10 to 10,000 bits per sample, and RL yields only a few bits per sample! You might guess there was some pushback from the RL community. However, in practice, this is largely what we observe. Large language models, for example, are primarily trained with SSL (e.g., hiding certain words in the text and predicting them), then fine-tuned with supervised learning for certain tasks (e.g., conversation), and then tuned again (just a bit!) using reinforcement learning from human feedback (RLHF).
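Here’s a minimal sketch of that masked-word pretext task in PyTorch. The tiny vocabulary, model size, and 15% masking rate are assumptions for illustration, not what any production LLM uses; the point is that the labels are just the original tokens, so no human annotation is needed:

```python
import torch
import torch.nn as nn

# Toy masked-word prediction (the SSL "cake" layer behind many LLMs).
vocab_size, d_model, mask_id = 100, 32, 0   # assumed sizes; id 0 reserved for [MASK]

model = nn.Sequential(
    nn.Embedding(vocab_size, d_model),
    nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True),
        num_layers=2,
    ),
    nn.Linear(d_model, vocab_size),          # predict a token id at every position
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

tokens = torch.randint(1, vocab_size, (8, 16))  # a batch of "sentences": labels come free
mask = torch.rand(tokens.shape) < 0.15          # hide ~15% of the positions
inputs = tokens.masked_fill(mask, mask_id)

logits = model(inputs)                          # (batch, seq, vocab)
loss = nn.functional.cross_entropy(
    logits[mask], tokens[mask]                  # score only the hidden positions
)
loss.backward()
optimizer.step()
```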

The jigsaw puzzle

There are many, many SSL methods, but let’s start with a simple example: the jigsaw puzzle, introduced by Noroozi & Favaro (ECCV 2016). The jigsaw puzzle is an example of a pretext task: an auxiliary problem used to train the model initially, but one we have no intention of actually performing in practice. The pretext task is usually highly related to what we call the downstream task: the actual problem we are interested in solving. Once the model is trained on the pretext task, we can take components of it and insert them into a supervised model, which we then train on the downstream task; this process is an example of what is known as transfer learning.
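As a sketch of that transfer-learning step, the snippet below swaps a pretext-trained backbone’s head for a new downstream classifier. The ResNet-18 backbone and 10-class head are hypothetical stand-ins; in practice you would load the weights your pretext task actually produced:

```python
import torch.nn as nn
from torchvision import models

# Reuse a pretext-trained backbone as the encoder for the downstream task.
backbone = models.resnet18(weights=None)   # stand-in: pretend these are pretext-trained weights
backbone.fc = nn.Identity()                # drop the pretext-specific head

downstream_model = nn.Sequential(
    backbone,                              # transferred representation learner
    nn.Linear(512, 10),                    # new head for a 10-class downstream task
)
# ...then train (or fine-tune) downstream_model on the labelled downstream data.
```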

Taken from Noroozi & Favaro (ECCV 2016). Here (a) represents segmenting a patch from the image, (b) is the input to the network for the jigsaw puzzle task, and (c) is the ground truth.

In the jigsaw puzzle pretext task, we take an image patch and segment it into equally sized blocks, like a grid. We then shuffle the blocks and train the model to rearrange them back into their original configuration. In theory, our model must learn the important spatial structure of the image in order to perform this task. After training is complete, we take the components of this model that have learned representations relevant to the downstream task (e.g., image classification, semantic segmentation), insert them into a supervised model, and retrain it on the downstream task.
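Here’s a rough sketch of how a single jigsaw training example could be constructed. The 3×3 grid follows the paper, but the 225×225 crop, the randomly generated permutation set, and its size of 64 are my own simplifications (the paper selects a fixed permutation set much more carefully):

```python
import random
import torch

# Build one jigsaw example: cut a crop into a 3x3 grid of tiles, shuffle
# them with a permutation from a fixed set, and use that permutation's
# index as the classification label the network must predict.
permutations = [tuple(random.sample(range(9), 9)) for _ in range(64)]  # stand-in label set

def make_jigsaw_example(image):                 # image: (3, 225, 225) tensor
    tiles = [
        image[:, r * 75:(r + 1) * 75, c * 75:(c + 1) * 75]
        for r in range(3) for c in range(3)
    ]
    label = random.randrange(len(permutations))
    shuffled = torch.stack([tiles[i] for i in permutations[label]])
    return shuffled, label                      # model must recover permutation `label`

image = torch.randn(3, 225, 225)                # stand-in for a real image crop
tiles, label = make_jigsaw_example(image)       # tiles: (9, 3, 75, 75)
```

Notice again that the label is generated by the shuffling itself, so the entire training signal comes for free from unlabelled images.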

What’s next

The jigsaw puzzle is just one example from a large collection of pretext tasks. We also haven’t talked about the architectures used in SSL models, which can look quite different from the standard SL models we are used to, and which have a significant effect on how well the SSL model performs and how useful it will be to the downstream task. SSL is a fascinating area of research, and I hope you found this post interesting and maybe learned something new! If you enjoyed this article or want to talk more about SSL with me, please let me know in the comments or send an email!
