Efficient Deep Learning: A Practical Guide - Part 1
Knowledge Distillation
Why Knowledge Distillation?
To kick off our journey into model compression, we start with one of the most powerful and widely used techniques: Knowledge Distillation. In this article, we will explore:
- What is Knowledge Distillation: The fundamental principles behind this method.
- Revealing the Dark Knowledge: How soft probabilities help transfer hidden insights.
- Training the Student Model: The loss function that balances hard labels with teacher guidance.
- How to Implement Knowledge Distillation in Practice: A hands-on example using FasterAI.
We will detail each step below and provide a practical example you can apply yourself. Other variations of Knowledge Distillation exist and may be covered in a future series of articles, so stay tuned!
1. What is Knowledge Distillation?
Knowledge Distillation is an elegant model compression technique that transfers knowledge from a large, complex model (the "teacher") to a smaller, more efficient model (the "student"). First introduced by Hinton et al. in 2015, this approach has become one of the most effective methods for creating lightweight but powerful models.
The Core Concept
The fundamental insight behind Knowledge Distillation is that large neural networks learn rich internal representations that go beyond simply mapping inputs to correct outputs. These networks develop nuanced understandings of the relationships between classes, the importance of different features, and the structure of the data space.
Traditional training compares the network's outputs only to hard labels ("this image is a dog, not a cat"), discarding the valuable information contained in the confidence levels and secondary predictions. Knowledge Distillation preserves this rich information during the compression process.
Benefits of Knowledge Distillation
Knowledge Distillation offers several advantages:
- Efficiency: Smaller models require less memory and computational power at inference time.
- Speed: Distilled models execute faster, making them suitable for real-time applications.
- Deployment flexibility: Compressed models can run on resource-constrained devices like smartphones or IoT devices.
- Improved generalization: Student models often show better performance on unseen data than similarly-sized models trained from scratch.
When to Use Knowledge Distillation
Knowledge Distillation is particularly effective when:
- You have a high-performing but overly large model that needs to be deployed in resource-constrained environments
- You need to maintain high accuracy while reducing inference time
- You want to transfer knowledge across different architectures or even different modalities
- You're working with limited labeled data but can leverage a pre-trained teacher
In the following sections, we'll explore exactly how Knowledge Distillation works, diving into the concept of "dark knowledge" and how we extract and transfer it from teacher to student.
2. Revealing the Dark Knowledge
But first, what is this Dark Knowledge?
Dark Knowledge refers to subtle information contained in the probability outputs of a deep neural network. This information reveals how the model perceives the relationship between the different classes.
Let's illustrate this with an example:
Let's say that we have trained a model to discriminate between the following four classes: [cow, dog, cat, car], and that we now show it the following image:

We expect the model to confidently predict that the image is a dog.
However, the model's uncertainty also carries useful information: it should assign a higher probability to "cat" than to "car", even though both are incorrect. This is because dogs and cats share more visual features (e.g., tail, eyes, body structure) than dogs and cars do. This subtle relationship between class probabilities is the so-called Dark Knowledge.
Let's say that the output predictions of our network, the so-called logits, corresponding to the classes [cow, dog, cat, car] are respectively $z$ = [1.2, 5.1, 2.9, 0.3]. These predictions usually go through a softmax function, allowing them to be interpreted as probabilities:
$$ p_i = \frac{\text{exp}(z_i)}{\sum_j \text{exp}(z_j)} $$
This softmax function produces the probabilities $p$ = [0.02, 0.88, 0.09, 0.01].
As can be seen, the softmax activation function emphasizes the most confident prediction while squashing the others, thus discarding most of the useful inter-class information. This comes from the way classification networks are trained: the output probabilities are compared to the hard labels [0., 1., 0., 0.], penalizing the model for not being fully confident in its predictions.
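To make these numbers concrete, here is a minimal sketch in PyTorch (with the example logits simply hard-coded) reproducing the standard softmax probabilities:

import torch
import torch.nn.functional as F

# Example logits for the classes [cow, dog, cat, car]
z = torch.tensor([1.2, 5.1, 2.9, 0.3])

# Standard softmax: the top class dominates, the others are squashed
p = F.softmax(z, dim=0)
print(p)  # ≈ [0.0178, 0.8777, 0.0973, 0.0072], i.e. roughly the [0.02, 0.88, 0.09, 0.01] quoted above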
To overcome this issue, we can use a modified version of the softmax, where the probabilities are softened, thus still carrying the dark knowledge. To do so, we introduce a temperature parameter $T$ in the softmax function as:
$$ p_i^T = \frac{\text{exp}(z_i/T)}{\sum_j \text{exp}(z_j/T)} $$
In the case of our example, with a temperature $T=3$, we get the probabilities $p^T$ = [0.14, 0.51, 0.25, 0.10]. As can be observed, the desired behaviour of the softmax is still present, but the output now carries more subtle information about how similar or dissimilar the classes are.
Remark: if the temperature parameter is set to $T=1$, the modified softmax reduces to the regular softmax.
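The same sketch can be extended to the temperature-scaled softmax (again with the example logits hard-coded), showing both the softening effect and the $T=1$ special case:

import torch
import torch.nn.functional as F

z = torch.tensor([1.2, 5.1, 2.9, 0.3])  # logits for [cow, dog, cat, car]

def softmax_t(logits, T):
    # Temperature-scaled softmax: divide the logits by T before normalizing
    return F.softmax(logits / T, dim=0)

print(softmax_t(z, T=3))  # ≈ [0.14, 0.51, 0.25, 0.10]: softened, the dark knowledge is visible
print(softmax_t(z, T=1))  # identical to the regular softmax computed earlier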
Now that the Dark Knowledge has been revealed by our network, let's tackle the second step of Knowledge Distillation: making the student model learn from it!
3. Training the Student Model
Once we have extracted the Dark Knowledge from our teacher model, we need to effectively transfer it to the student model. This is where the training strategy becomes crucial.
The Dual-Objective Training Process
Training a student model with Knowledge Distillation differs from standard neural network training in a fundamental way: the student must simultaneously optimize for two different objectives:
- Match the ground truth: Like any supervised learning model, the student needs to learn the correct classifications from the labeled data.
- Mimic the teacher: The student also needs to capture the rich inter-class relationships that the teacher has learned.
These dual objectives create a more nuanced learning process that helps the student develop more sophisticated decision boundaries despite its smaller capacity.
Balancing Teacher Guidance and Ground Truth
To implement this dual-objective training, we design a specialized loss function that combines two components:
- Classification Loss: This is the standard cross-entropy loss between the student's predictions and the ground truth labels. It ensures the student learns to make correct predictions.
- Distillation Loss: This is the cross-entropy between the student's softened predictions and the teacher's softened predictions (both using the same temperature T). This component transfers the dark knowledge.
The overall Knowledge Distillation loss function combines these components with a weighting parameter α:
$$ L_{KD} = \overbrace{\alpha \cdot \text{CE}(p^T_t, p^T_s)}^{\text{Distillation Loss}} + \overbrace{(1-\alpha) \cdot \text{CE}(y_{true}, p_s)}^{\text{Classification Loss}} $$
Where:
- $p^T_t$ and $p^T_s$ are the softened probability distributions from teacher and student
- $y_{true}$ represents the ground truth labels
- $p_s$ is the standard (T=1) probability output from the student
- $\alpha$ controls the balance between mimicking the teacher and learning from ground truth
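As an illustration, here is a minimal sketch of this combined loss in plain PyTorch; the function name kd_loss and the dummy tensors are ours for demonstration, not part of any library:

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, T=3.0, alpha=0.5):
    # Softened distributions from teacher and student, both at temperature T
    p_t = F.softmax(teacher_logits / T, dim=-1)
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    # Distillation term: cross-entropy between the softened distributions
    distill = -(p_t * log_p_s).sum(dim=-1).mean()
    # Classification term: standard cross-entropy against the hard labels
    classif = F.cross_entropy(student_logits, targets)
    return alpha * distill + (1 - alpha) * classif

# Dummy batch: 8 examples, 4 classes
student_out = torch.randn(8, 4)
teacher_out = torch.randn(8, 4)
labels = torch.randint(0, 4, (8,))
print(kd_loss(student_out, teacher_out, labels))

Note that some implementations additionally scale the distillation term by $T^2$ so that its gradient magnitude stays comparable across temperatures; here we follow the formula exactly as written above.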
Choosing Hyperparameters
Two critical hyperparameters affect the distillation process:
- Temperature ($T$): Higher values (typically 2-20) produce softer probability distributions that better reveal the dark knowledge. The optimal temperature depends on the dataset and model architectures.
- Distillation Weight ($\alpha$): This controls how much the student should focus on mimicking the teacher versus learning directly from labels. Values between 0.5 and 0.7 are common starting points.
These parameters often require tuning for optimal results. In practice, you might start with T=3 and α=0.5, then adjust based on validation performance.
Beyond Basic Distillation
While the described approach focuses on distilling knowledge from the final layer outputs, advanced variations of Knowledge Distillation also transfer:
- Feature knowledge: Matching intermediate representations between teacher and student (a minimal sketch is shown after this list)
- Attention maps: Transferring where the teacher model focuses its attention
- Relational knowledge: Preserving relationships between different examples
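To give a flavour of the first variation, here is a minimal, illustrative sketch of feature-level distillation (the toy models and the adapter are ours, not FasterAI code), where the student's intermediate activations are pushed towards the teacher's:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy backbones, for illustration only
teacher_body = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 64, 3, padding=1))
student_body = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32, 32, 3, padding=1))

# A 1x1 convolution maps the student's features to the teacher's channel count
adapter = nn.Conv2d(32, 64, 1)

x = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    f_teacher = teacher_body(x)        # intermediate representation of the teacher
f_student = adapter(student_body(x))   # student features projected to the teacher's space

# Feature distillation term, added to the usual KD loss during training
feature_loss = F.mse_loss(f_student, f_teacher)
print(feature_loss)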
In the next section, we'll see how to implement this Knowledge Distillation process in practice using the FasterAI library.
4. How to Implement Knowledge Distillation in Practice
Introducing FasterAI
We have created FasterAI, an open-source library built on PyTorch and fastai, to help deep learning practitioners apply such compression techniques. It is free to use and contains many state-of-the-art compression techniques that will help you create smaller, faster and greener AI models.
Performing Knowledge Distillation with FasterAI requires only one additional line of code compared to regular fastai code.
We first need to train our teacher on the desired use-case (for more details about how to train a model with fastai, please refer to the documentation). In our example, we will take a pre-trained ResNet34 as the teacher and an untrained ResNet18 as the student. Both will be trained on the PETS dataset.
from fastai.vision.all import *
from fasterai.distill.all import *

# Download the PETS dataset and collect the image files
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

# Cat images have filenames starting with an uppercase letter in this dataset
def label_func(f): return f[0].isupper()

dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

# The teacher: a pre-trained ResNet34, fine-tuned on our task
teacher = vision_learner(dls, resnet34, metrics=accuracy)
teacher.unfreeze()
teacher.fit_one_cycle(5, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.712487 | 0.780329 | 0.826116 | 00:04 |
1 | 0.426159 | 0.454067 | 0.895129 | 00:04 |
2 | 0.131154 | 0.193276 | 0.926252 | 00:04 |
3 | 0.039612 | 0.191167 | 0.937754 | 00:04 |
4 | 0.024194 | 0.194976 | 0.937077 | 00:04 |
Once the teacher has been trained, it is ready to help a smaller model in its learning. To apply Knowledge Distillation in FasterAI, we use the Callback system, which allows us to tweak the training loop of a model. In our case, we want to replace the regular loss function with the KD loss function defined earlier.
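To give an idea of what such a callback does under the hood, here is a simplified, illustrative sketch of our own (not FasterAI's actual implementation; the class name and arguments are ours, and it assumes a recent fastai version where the loss used for the backward pass is stored in learn.loss_grad):

import torch
import torch.nn.functional as F
from fastai.vision.all import *

class SimpleKDCallback(Callback):
    "Illustration only: blend the distillation loss into the training loss."
    def __init__(self, teacher_model, T=3.0, alpha=0.5): store_attr()
    def before_fit(self): self.teacher_model.to(self.dls.device).eval()
    def after_loss(self):
        # Teacher predictions on the current batch (no gradients needed)
        with torch.no_grad(): t_logits = self.teacher_model(*self.xb)
        p_t = F.softmax(t_logits / self.T, dim=-1)
        log_p_s = F.log_softmax(self.pred / self.T, dim=-1)
        distill = -(p_t * log_p_s).sum(dim=-1).mean()
        # learn.loss_grad already holds the standard cross-entropy against the labels
        self.learn.loss_grad = self.alpha * distill + (1 - self.alpha) * self.loss_grad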
Before doing so, let's first train a student model without Knowledge Distillation, so that we have a baseline to compare against:
# Baseline: a ResNet18 trained from scratch, without teacher guidance
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
student.fit_one_cycle(10, 1e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.602736 | 0.784080 | 0.682003 | 00:04 |
1 | 0.582019 | 0.629800 | 0.644790 | 00:04 |
2 | 0.547411 | 0.521493 | 0.725981 | 00:04 |
3 | 0.490268 | 0.669058 | 0.740189 | 00:04 |
4 | 0.448316 | 0.446682 | 0.778078 | 00:03 |
5 | 0.403792 | 0.668784 | 0.759811 | 00:03 |
6 | 0.350714 | 0.409201 | 0.815291 | 00:04 |
7 | 0.279282 | 0.392315 | 0.815968 | 00:04 |
8 | 0.197490 | 0.415861 | 0.837618 | 00:03 |
9 | 0.157046 | 0.403317 | 0.834235 | 00:04 |
We find that our student ResNet18 achieves around 83% accuracy. Let's see how this changes when the teacher helps it during training:
student = Learner(dls, resnet18(num_classes=2), metrics=accuracy)
# The only addition compared to the baseline: the Knowledge Distillation callback
kd_cb = KnowledgeDistillationCallback(teacher=teacher, weight=0.5)
student.fit_one_cycle(10, 1e-3, cbs=kd_cb)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.874970 | 2.434021 | 0.709066 | 00:04 |
1 | 2.619885 | 2.321189 | 0.737483 | 00:04 |
2 | 2.381633 | 2.690866 | 0.730041 | 00:04 |
3 | 2.101448 | 1.772370 | 0.771313 | 00:04 |
4 | 1.824600 | 1.707633 | 0.793640 | 00:04 |
5 | 1.588555 | 1.433752 | 0.814614 | 00:04 |
6 | 1.273060 | 1.264489 | 0.843708 | 00:04 |
7 | 0.979666 | 1.169676 | 0.849120 | 00:04 |
8 | 0.768508 | 1.047257 | 0.862652 | 00:04 |
9 | 0.630613 | 1.043255 | 0.861976 | 00:04 |
We can observe that the student model now achieves around 86% accuracy! This shows that Knowledge Distillation indeed helped the student during training, bringing its performance closer to the teacher's.
By default, the callback applies the distillation loss detailed in this blog post, which is a good default for classification scenarios. However, other distillation losses may perform better in other cases. Take a look at the losses available in FasterAI and compare them to find the one best adapted to your use-case!
Conclusion
In the next blog post, we will present another compression technique called Sparsification, which consists in finding unimportant connections in a neural network and replacing them with zeros, effectively removing parameters without impacting performance.
Join us on Discord to stay tuned! 🚀