diff --git a/chapters/en/unit3/vision-transformers/knowledge-distillation.mdx b/chapters/en/unit3/vision-transformers/knowledge-distillation.mdx
index 00ec7e23b..cb1331798 100644
--- a/chapters/en/unit3/vision-transformers/knowledge-distillation.mdx
+++ b/chapters/en/unit3/vision-transformers/knowledge-distillation.mdx
@@ -1,10 +1,8 @@
 # Knowledge Distillation with Vision Transformers
 
-We are going to learn about Knowledge Distillation, the method behind [distilGPT](https://huggingface.co/distilgpt2) and [distilbert](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english), two of *the most downloaded models on the Hugging Face Hub!*
+We are going to learn about Knowledge Distillation, the method behind [distilGPT](https://huggingface.co/distilgpt2) and [distilbert](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english), two of *the most downloaded models on the Hugging Face Hub!*
 
-Presumably, we've all had teachers who "teach" by simply providing us the correct answers and then testing us on questions we haven't seen before, analogous to supervised learning
-of machine learning models where we provide a labeled dataset to train on. Instead of having a model train on labels, however,
-we can pursue [Knowledge Distillation](https://arxiv.org/abs/1503.02531) as an alternative to arriving at a much smaller model that can perform comparably to the larger model and much faster to boot.
+Presumably, we've all had teachers who "teach" by simply providing us the correct answers and then testing us on questions we haven't seen before, analogous to supervised learning of machine learning models, where we provide a labeled dataset to train on. Instead of having a model train only on labels, however, we can pursue [Knowledge Distillation](https://arxiv.org/abs/1503.02531) as an alternative, arriving at a much smaller model that performs comparably to the larger model and runs much faster to boot.
 
 ## Intuition Behind Knowledge Distillation
 
@@ -14,25 +12,19 @@ Imagine you were given this multiple-choice question:
 If you had someone just tell you, "The answer is Draco Malfoy," that doesn't teach you a whole lot about each of the characters' relative relationships with Harry Potter.
 
-On the other hand, if someone tells you, "I am very confident it is not Ron Weasley, I am somewhat confident it is not Neville Longbottom, and
-I am very confident that it *is* Draco Malfoy", this gives you some information about these characters' relationships to Harry Potter!
-This is precisely the kind of information that gets passed down to our student model under the Knowledge Distillation paradigm.
+On the other hand, if someone tells you, "I am very confident it is not Ron Weasley, I am somewhat confident it is not Neville Longbottom, and I am very confident that it *is* Draco Malfoy," this gives you some information about these characters' relationships to Harry Potter! This is precisely the kind of information that gets passed down to our student model under the Knowledge Distillation paradigm.
 
 ## Distilling the Knowledge in a Neural Network
 
-In the paper [*Distilling the Knowledge in a Neural Network*](https://arxiv.org/abs/1503.02531), Hinton et al. introduced the training methodology known as knowledge distillation,
-taking inspiration from *insects*, of all things.
-Just as insects transition from larval to adult forms that are optimized for different tasks, large-scale machine learning models can
-initially be cumbersome, like larvae, for extracting structure from data but can distill their knowledge into smaller, more efficient models for deployment.
+In the paper [*Distilling the Knowledge in a Neural Network*](https://arxiv.org/abs/1503.02531), Hinton et al. introduced the training methodology known as knowledge distillation, taking inspiration from *insects*, of all things. Just as insects transition from larval to adult forms that are optimized for different tasks, large-scale machine learning models can initially be cumbersome, like larvae, for extracting structure from data but can distill their knowledge into smaller, more efficient models for deployment.
 
-The essence of Knowledge Distillation is using the predicted logits from a teacher network to pass information to a smaller, more efficient student model. We do this
-by re-writing the loss function to contain a *distillation loss*, which encourages the student model's distribution over the output space to approximate the teacher's.
+The essence of Knowledge Distillation is using the predicted logits from a teacher network to pass information to a smaller, more efficient student model. We do this by rewriting the loss function to contain a *distillation loss*, which encourages the student model's distribution over the output space to approximate the teacher's.
 
 The distillation loss is formulated as:
 
 ![Distillation Loss](https://huggingface.co/datasets/hf-vision/course-assets/resolve/main/KL-Loss.png)
 
-The KL loss refers to the [Kullback-Leibler Divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence) between the teacher and the student's output distributions.
-The overall loss for the student model is then formulated as the sum of this distillation loss with the standard cross-entropy loss over the ground-truth labels.
+The KL loss refers to the [Kullback-Leibler Divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence) between the teacher's and the student's output distributions. The overall loss for the student model is then formulated as the sum of this distillation loss and the standard cross-entropy loss over the ground-truth labels.
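+
+As a rough, hedged sketch of how that combined objective can look in code (the helper name `distillation_loss`, the temperature `T`, and the weight `alpha` below are illustrative choices, not something prescribed by the paper or the course notebook), the loss might be written as:
+
+```python
+import torch.nn.functional as F
+
+def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
+    # Soften both output distributions with temperature T, then measure how far
+    # the student's distribution is from the teacher's (the KL-divergence term).
+    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures.
+    kl = F.kl_div(
+        F.log_softmax(student_logits / T, dim=-1),
+        F.softmax(teacher_logits / T, dim=-1),
+        reduction="batchmean",
+    ) * (T * T)
+    # Standard cross-entropy against the ground-truth labels.
+    ce = F.cross_entropy(student_logits, labels)
+    # The student's overall loss is a weighted sum of the two terms.
+    return alpha * kl + (1 - alpha) * ce
+```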
 
 To see this loss function implemented in Python and a fully worked out example in Python, let's check out the [notebook for this section](https://github.com/johko/computer-vision-course/blob/main/notebooks/Unit%203%20-%20Vision%20Transformers/KnowledgeDistillation.ipynb).
 
@@ -40,3 +32,26 @@ To see this loss function implemented in Python and a fully worked out example i
 Open In Colab
 
+## Leveraging Knowledge Distillation for Edge Devices
+
+Knowledge distillation has become increasingly important as AI models are deployed on edge devices. Deploying a large-scale model, say one that is 1 GB in size with a latency of 1 second, is impractical for real-time applications because of its high computational and memory requirements, and those requirements are driven primarily by the model's size. As a result, the field has embraced knowledge distillation, a technique that can cut a model's parameter count by over 90% with minimal performance degradation.
+
+## The Consequences (Good and Bad) of Knowledge Distillation
+
+### 1. Entropy Gain
+
+In the context of information theory, entropy is analogous to its counterpart in physics, where it measures the "chaos" or disorder within a system. In our scenario, it quantifies the amount of information a distribution contains. Consider the following example:
+
+- Which is harder to remember: `[0, 1, 0, 0]` or `[0.2, 0.5, 0.2, 0.1]`?
+
+The first vector, `[0, 1, 0, 0]`, is easier to remember and compress, as it contains less information: it can be summarized as "a 1 in the second position." The vector `[0.2, 0.5, 0.2, 0.1]`, on the other hand, contains more information.
+
+Building on that, suppose we trained an 80M-parameter network on ImageNet and then distilled it (as discussed earlier) into a 5M-parameter student model. We would find that the entropy of the teacher model's outputs is much lower than that of the student model's. In other words, the student's outputs, even when correct, are more chaotic than the teacher's. This comes down to a simple fact: the teacher's additional parameters let it extract more features and therefore discern between classes more easily. This perspective on knowledge distillation is very interesting and is being actively researched as a way to reduce the student's entropy, either by using the entropy itself as a loss term or by applying related physics-inspired metrics (such as energy).
+
+### 2. Coherent Gradient Updates
+
+Models learn iteratively by minimizing a loss function and updating their parameters through gradient descent. Consider a set of parameters `P = {w1, w2, w3, ..., wn}` whose role in the teacher model is to activate when a sample of class A is detected. If an ambiguous sample resembles class A but actually belongs to class B, the gradient update after the misclassification will be aggressive, leading to instability. In contrast, the teacher model's soft targets promote more stable and coherent gradient updates during distillation, resulting in a smoother learning process for the student model.
+
+### 3. Ability to Train on Unlabeled Data
+
+The presence of a teacher model allows the student model to train on unlabeled data: the teacher generates pseudo-labels for the unlabeled samples, which the student can then use for training. This approach significantly increases the amount of usable training data.
+
+### 4. A Shift in Perspective
+
+Deep learning models are typically trained under the assumption that, given enough data, they will approximate a function `F` that accurately represents the underlying phenomenon. In many cases, however, data scarcity makes this assumption unrealistic, and the traditional answer is to build ever larger models and fine-tune them iteratively until the results are good enough. Knowledge distillation shifts this perspective: given that we already have a well-trained teacher model `F`, the goal becomes approximating `F` with a smaller model `f`, as in the sketch below.
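+
+To make points 3 and 4 concrete, here is a minimal, hedged sketch in PyTorch of that idea: a frozen teacher (playing the role of `F`) produces soft pseudo-labels for unlabeled images, and the student (`f`) is updated to match them. The names `pseudo_targets` and `student_step` and the temperature `T` are illustrative, not taken from the course notebook; `teacher` and `student` are assumed to be classification models returning logits.
+
+```python
+import torch
+import torch.nn.functional as F
+
+@torch.no_grad()
+def pseudo_targets(teacher, images, T=2.0):
+    # The frozen teacher plays the role of F: it turns unlabeled images
+    # into soft probability distributions instead of hard class labels.
+    return F.softmax(teacher(images) / T, dim=-1)
+
+def student_step(student, teacher, images, optimizer, T=2.0):
+    # The smaller student f is trained to approximate the teacher F
+    # on data that carries no ground-truth annotations at all.
+    targets = pseudo_targets(teacher, images, T)
+    loss = F.kl_div(
+        F.log_softmax(student(images) / T, dim=-1),
+        targets,
+        reduction="batchmean",
+    ) * (T * T)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    return loss.item()
+```
+
+Because the targets come entirely from the teacher, this training step can run on images without any labels, which is exactly the data-efficiency benefit described in point 3.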