LLM Pruning and Distillation in Practice: The Minitron Approach



Original Paper: https://arxiv.org/pdf/2408.11796

By: Sharath Turuvekere Sreenivas, Saurav Muralidharan, Raviraj Joshi, Marcin Chochowski, Mostofa Patwary, Mohammad Shoeybi, Bryan Catanzaro, Jan Kautz, Pavlo Molchanov

Abstract

We present a comprehensive report on compressing the Llama 3.1 8B and Mistral NeMo 12B models to 4B and 8B parameters, respectively, using pruning and distillation.

We explore two distinct pruning strategies: (1) depth pruning and (2) joint hidden/attention/MLP (width) pruning, and evaluate the results on common benchmarks from the LM Evaluation Harness.

The models are then aligned with NeMo Aligner and tested in instruct-tuned versions. This approach produces a compelling 4B model from Llama 3.1 8B and a state-of-the-art Mistral-NeMo-Minitron-8B (MN-Minitron-8B for brevity) model from Mistral NeMo 12B.

We found that with no access to the original data, it is beneficial to slightly fine-tune teacher models on the distillation dataset.

Summary Notes


Figure 1: High-level overview of our proposed pruning and distillation approach. The total number of tokens used for each step is indicated in parentheses.


Figure 2: Pruning and distillation process outlined in the original paper [1]. We follow the same approach in this work.

Introduction

This paper presents a comprehensive approach to compressing large language models (LLMs) like Llama 3.1 8B and Mistral NeMo 12B into more efficient models with 4B and 8B parameters, respectively.

The methodology combines pruning and distillation techniques to reduce model size while maintaining high performance on common benchmarks.

The authors apply the Minitron compression strategy from their earlier work [1] to these models, producing compressed variants that compete with or outperform similarly-sized models.

Key Concepts

LLM Pruning and Distillation

  • The paper explores two pruning strategies: depth pruning (removing whole layers) and width pruning (jointly reducing the hidden, attention, and MLP dimensions).
  • Before pruning, the larger "teacher" model is lightly fine-tuned on the distillation dataset, a step referred to as "teacher correction"; this helps the smaller "student" model retain as much of the original performance as possible during distillation.
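To see how the two strategies shrink a model, the following sketch counts parameters for a Llama-style decoder. It assumes grouped-query attention, a SwiGLU MLP with three weight matrices, two RMSNorm weight vectors per layer, untied input/output embeddings, and no biases. The shapes are assumptions on our part: the teacher uses the published Llama 3.1 8B configuration, and the two pruned shapes (half the layers for depth; hidden size 3072 and MLP size 9216 for width) are what we understand the released Llama-3.1-Minitron-4B checkpoints to use, so verify them against the model configs before relying on them. `LlamaShape` and `count_params` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class LlamaShape:
    """Hypothetical container for the shape of a Llama-style decoder."""
    n_layers: int
    hidden: int
    ffn: int
    n_heads: int
    n_kv_heads: int
    head_dim: int
    vocab: int


def count_params(s: LlamaShape) -> int:
    """Count weights, assuming GQA, SwiGLU, RMSNorm, untied embeddings, no biases."""
    q = s.hidden * s.n_heads * s.head_dim          # query projection
    kv = 2 * s.hidden * s.n_kv_heads * s.head_dim  # key + value projections (GQA)
    o = s.n_heads * s.head_dim * s.hidden          # attention output projection
    mlp = 3 * s.hidden * s.ffn                     # gate, up, down matrices
    norms = 2 * s.hidden                           # two RMSNorm weight vectors
    per_layer = q + kv + o + mlp + norms
    embeddings = 2 * s.vocab * s.hidden            # input and output embeddings
    return s.n_layers * per_layer + embeddings + s.hidden  # + final norm


teacher = LlamaShape(32, 4096, 14336, 32, 8, 128, 128256)
depth = LlamaShape(16, 4096, 14336, 32, 8, 128, 128256)  # half the layers
width = LlamaShape(32, 3072, 9216, 32, 8, 128, 128256)   # narrower hidden/MLP

for name, shape in [("teacher", teacher), ("depth", depth), ("width", width)]:
    print(f"{name}: {count_params(shape) / 1e9:.2f}B")
# teacher: 8.03B
# depth: 4.54B
# width: 4.51B
```

Under these assumptions both pruned shapes land near 4.5B parameters. Depth pruning leaves the roughly 1.05B embedding parameters untouched because it keeps the full hidden size, while width pruning shrinks the embeddings along with every layer.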

Minitron Compression Strategy

  • The Minitron approach starts with a large pre-trained model, prunes it to create a smaller model, and then retrains the pruned model with knowledge distillation from the larger one.
  • The strategy was applied to Llama 3.1 8B and Mistral NeMo 12B models, resulting in Llama-3.1-Minitron-4B and MN-Minitron-8B models, which are significantly more efficient while retaining high accuracy.

Methodology

Pruning

  • Pruning was guided by importance scores computed for each layer, neuron, attention head, and embedding channel. Width pruning proved the most effective strategy: the attention heads were retained, and the other dimensions (hidden and MLP) were pruned.
  • Depth pruning was also explored, removing entire layers according to how much each contributed to the model's performance; however, it produced lower accuracy than width pruning in most cases.
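A minimal NumPy sketch of both ideas follows. For width pruning, it scores each MLP intermediate neuron by its mean absolute activation on a calibration batch and keeps the top-scoring ones. It uses a plain two-matrix MLP rather than Llama's gated MLP, and the aggregation (mean absolute value over tokens) is a simplification that may differ from the authors'. For depth pruning, it picks the contiguous block of layers whose removal hurts least, given a caller-supplied scoring function. All function names (`mlp_neuron_importance`, `prune_mlp_width`, `choose_layer_block`) are hypothetical.

```python
import numpy as np


def mlp_neuron_importance(x: np.ndarray, w_up: np.ndarray) -> np.ndarray:
    """Activation-based importance per MLP intermediate neuron.

    x:    calibration activations, shape (num_tokens, hidden)
    w_up: up-projection weight, shape (hidden, ffn)
    Returns one score per intermediate neuron, shape (ffn,).
    """
    acts = x @ w_up                   # (num_tokens, ffn)
    return np.abs(acts).mean(axis=0)  # simplified aggregation over tokens


def prune_mlp_width(w_up, w_down, scores, keep):
    """Keep the `keep` highest-scoring neurons and slice both projections."""
    idx = np.sort(np.argsort(scores)[::-1][:keep])  # top-k, original order
    return w_up[:, idx], w_down[idx, :]


def choose_layer_block(num_layers, block_len, score_without):
    """Depth pruning: pick the contiguous block whose removal hurts least.

    score_without(start, end) must return a quality metric (higher is better)
    for the model with layers [start, end) removed.
    """
    best_start = max(range(num_layers - block_len + 1),
                     key=lambda s: score_without(s, s + block_len))
    return best_start, best_start + block_len


rng = np.random.default_rng(0)
hidden, ffn = 8, 32
x = rng.normal(size=(64, hidden))  # stand-in for a calibration batch
w_up = rng.normal(size=(hidden, ffn))
w_down = rng.normal(size=(ffn, hidden))

scores = mlp_neuron_importance(x, w_up)
w_up_p, w_down_p = prune_mlp_width(w_up, w_down, scores, keep=16)
print(w_up_p.shape, w_down_p.shape)  # (8, 16) (16, 8)

# Hypothetical toy score that prefers dropping the block starting at layer 3.
print(choose_layer_block(8, 4, lambda start, end: -abs(start - 3)))  # (3, 7)
```

In a real model, `score_without` would evaluate something like validation loss or downstream accuracy with the candidate block removed, which is far more expensive than this toy lambda.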

Distillation

  • Distillation used the corrected teacher model to guide the student's retraining by minimizing the Kullback-Leibler (KL) divergence between the teacher's and the student's output distributions.
  • The paper demonstrates that teacher correction is crucial when the original training data is unavailable, as it significantly improves the student model's performance.
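The KL-based distillation loss described above can be written as a short NumPy function. This sketch computes the forward direction, KL(p_teacher || p_student), averaged over token positions; treat the direction and the per-token averaging as assumptions. It only evaluates the loss value: in real training the student's gradients would come from an autodiff framework, with the teacher's logits held fixed. The function names are hypothetical.

```python
import numpy as np


def log_softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def forward_kl_loss(teacher_logits: np.ndarray, student_logits: np.ndarray) -> float:
    """Mean KL(p_teacher || p_student) over all token positions.

    Both arrays have shape (..., vocab_size).
    """
    log_p = log_softmax(teacher_logits)  # teacher log-probabilities
    log_q = log_softmax(student_logits)  # student log-probabilities
    kl_per_token = (np.exp(log_p) * (log_p - log_q)).sum(axis=-1)
    return float(kl_per_token.mean())


# Toy check: a uniform teacher vs. a student that puts 3:1 odds on token 0.
teacher = np.array([[0.0, 0.0]])
student = np.array([[np.log(3.0), 0.0]])
print(forward_kl_loss(teacher, student))  # 0.5 * ln(4/3), about 0.1438
print(forward_kl_loss(teacher, teacher))  # 0.0
```

The loss is zero only when the student reproduces the teacher's distribution exactly, and because it is weighted by the teacher's probabilities, it penalizes the student most on tokens the teacher considers likely.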

Experiments and Results

Model Performance

  • The Llama-3.1-Minitron-4B and MN-Minitron-8B models were evaluated on several benchmarks, including MMLU, HumanEval, and Winogrande.
  • MN-Minitron-8B outperformed similarly-sized models and, on tasks such as GSM8k and HumanEval, even achieved higher accuracy than its much larger teacher, Mistral NeMo 12B.

Runtime Performance

  • The compressed models were optimized using NVIDIA TensorRT-LLM for faster inference.
  • The Llama-3.1-Minitron-4B models achieved significant speedups, with the depth-pruned variant being the fastest, offering a 2.7× improvement in throughput over the original Llama 3.1 8B model.

Insights

Importance of Teacher Correction

  • Teacher correction was found to be essential for optimal distillation, reducing LM validation loss by over 6% by adapting the teacher to the distillation dataset before it guides the student.
  • Teacher correction can also run in parallel with distillation rather than as a separate preceding step, which improves overall efficiency.

Effectiveness of Pruning Strategies

  • Width pruning generally outperformed depth pruning in accuracy, particularly in preserving reasoning ability, with higher scores on benchmarks such as MMLU and GSM8k.
  • Depth pruning, however, was more effective in boosting inference speed.

Conclusion

The Minitron approach demonstrates the effectiveness of combining pruning and distillation to create smaller, more efficient LLMs without significant loss of performance.

The Llama-3.1-Minitron-4B and MN-Minitron-8B models represent significant advances in model compression, offering strong accuracy (state-of-the-art in the case of MN-Minitron-8B) with a much smaller computational footprint.

This approach could serve as a blueprint for future efforts to compress and optimize large-scale language models.
