2.2 Distillation & Quantization

🎯 Learning Objectives

  • Understand knowledge distillation techniques for creating efficient SLMs
  • Master quantization methods (8-bit, 4-bit) for model compression
  • Explore QLoRA and other parameter-efficient fine-tuning approaches
  • Analyze performance trade-offs in model compression

🧠 Knowledge Distillation

Knowledge Distillation is a technique where a smaller "student" model learns to mimic the behavior of a larger "teacher" model, capturing its knowledge in a more compact form.

🎓 Teacher Model

Large, powerful model
(e.g., GPT-3, 175B params)

📚 Knowledge Transfer

Soft targets, attention maps,
intermediate representations

🎯 Student Model

Compact, efficient model
(e.g., 1-7B params)
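The "soft targets" in the diagram above are the teacher's output probabilities after temperature scaling. The following is a minimal, dependency-free sketch of that idea; the logits are made-up values chosen only for illustration, and `softmax_with_temperature` is a hypothetical helper name.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for three classes
teacher_logits = [4.0, 2.0, 0.5]

hard = softmax_with_temperature(teacher_logits, temperature=1.0)
soft = softmax_with_temperature(teacher_logits, temperature=3.0)

print([round(p, 3) for p in hard])
print([round(p, 3) for p in soft])
```

A higher temperature flattens the distribution, so the student also sees how the teacher ranks the less likely classes instead of only the top prediction.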

🔬 Distillation Process

Step 1: Teacher Model Preparation

Select a high-performing large model as the teacher. This model should excel at the target tasks.

    import torch
    import torch.nn.functional as F
    from transformers import AutoModelForCausalLM

    # Load an open-weight teacher model (hosted, API-only models cannot be
    # loaded with from_pretrained)
    teacher_model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
    teacher_model.eval()

    # Generate soft targets with temperature scaling
    def get_teacher_predictions(input_ids, temperature=3.0):
        with torch.no_grad():
            logits = teacher_model(input_ids).logits
        # Soften the distribution for better knowledge transfer
        soft_targets = F.softmax(logits / temperature, dim=-1)
        return soft_targets

Step 2: Student Architecture Design

Create a smaller model with fewer layers, smaller hidden dimensions, or different architecture optimizations.

    from transformers import GPT2Config, GPT2LMHeadModel

    # Student model configuration (much smaller than the teacher)
    student_config = GPT2Config(
        vocab_size=50257,   # must match the teacher's tokenizer vocabulary
        n_positions=2048,
        n_embd=768,         # smaller hidden size than the teacher
        n_layer=12,         # fewer transformer layers
        n_head=12,          # fewer attention heads
    )
    student_model = GPT2LMHeadModel(student_config)
    print(f"Student params: {student_model.num_parameters() / 1e6:.1f}M")

Step 3: Distillation Training

Train the student model to match both the teacher's outputs and the ground truth labels.

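    import torch.nn.functional as F

    # Distillation loss function
    # Assumes labels are already aligned with the logits (for next-token
    # prediction, shifted by one position before calling this function)
    def distillation_loss(student_logits, teacher_logits, labels,
                          alpha=0.7, temperature=3.0):
        # Flatten (batch, seq_len, vocab) to (tokens, vocab) so both terms
        # are averaged per token
        vocab_size = student_logits.size(-1)
        student_logits = student_logits.reshape(-1, vocab_size)
        teacher_logits = teacher_logits.reshape(-1, vocab_size)
        labels = labels.reshape(-1)

        # Soft target loss (student learns from teacher)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction='batchmean'
        ) * (temperature ** 2)

        # Hard target loss (student learns from ground truth)
        hard_loss = F.cross_entropy(student_logits, labels)

        # Weighted combination
        return alpha * soft_loss + (1 - alpha) * hard_loss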
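Written out as a formula, the loss function above combines the two terms as follows. The symbols are introduced here for illustration: z_s and z_t are the student and teacher logits, T is the temperature, α is the weighting factor `alpha`, y is the ground-truth label, and σ is the softmax.

```latex
\mathcal{L}
  = \alpha \, T^{2} \,
    \mathrm{KL}\!\left( \sigma(z_t / T) \,\middle\|\, \sigma(z_s / T) \right)
  + (1 - \alpha) \, \mathrm{CE}\!\left( \sigma(z_s),\, y \right)
```

The factor T² compensates for the gradients of the soft term shrinking roughly as 1/T², which keeps the two terms on a comparable scale when the temperature is changed.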

📊 Distillation Results

Model                   Parameters   GLUE Score   Inference Speed   Memory (GB)
Teacher (BERT-Large)    340M         84.3         1x                1.3
Student (DistilBERT)    66M          81.2         2x                0.3
Student vs. Teacher     19%          96%          200%              23%
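The percentages in the last row of the table are the student's value divided by the teacher's value in each column. A short sketch of that arithmetic, using the figures from the table as given (the dictionary keys are illustrative names):

```python
# Figures copied from the table above
teacher = {"params_m": 340, "glue": 84.3, "memory_gb": 1.3}
student = {"params_m": 66, "glue": 81.2, "memory_gb": 0.3}

# Student value as a fraction of the teacher value, per column
ratios = {key: student[key] / teacher[key] for key in teacher}

for name, ratio in ratios.items():
    print(f"{name}: {ratio:.0%}")
```

Read this way, the student keeps about 96% of the GLUE score while using roughly a fifth of the parameters and a quarter of the memory.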

⚖️ Quantization Techniques

Quantization reduces model size and inference time by using lower-precision number representations (e.g., INT8 instead of FP32).
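To make the idea concrete, here is a minimal sketch of symmetric (absmax) INT8 quantization in plain Python. It is one common scheme among several, not necessarily the one any particular library uses; the weight values and the helper names `quantize_int8` and `dequantize` are hypothetical.

```python
def quantize_int8(values):
    """Symmetric absmax quantization of a list of floats to INT8 codes."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the INT8 range
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale

def dequantize(codes, scale):
    """Map integer codes back to approximate float values."""
    return [c * scale for c in codes]

weights = [0.42, -1.27, 0.05, 0.9]  # hypothetical FP32 weights
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print(codes, scale, max_error)
```

Each weight is stored as one signed byte plus a shared scale, and the round-trip error per weight is at most half a quantization step.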

🔢 Precision Levels

FP32 (Full Precision)

Size: 4 bytes/param
Range: ±3.4 × 10³⁸
Use: Training, highest accuracy

FP16 (Half Precision)

Size: 2 bytes/param
Range: ±6.5 × 10⁴
Use: Modern GPUs, good balance

INT8 (8-bit Integer)

Size: 1 byte/param
Range: -128 to 127
Use: Edge devices, 4x compression

INT4 (4-bit Integer)

Size: 0.5 bytes/param
Range: -8 to 7
Use: Extreme compression, 8x smaller
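The bytes-per-parameter figures above translate directly into weight storage. The sketch below applies them to a hypothetical 7B-parameter model, assuming 1 GB = 10⁹ bytes and counting weights only (no activations, KV cache, or quantization metadata such as scales).

```python
BYTES_PER_PARAM = {"FP32": 4.0, "FP16": 2.0, "INT8": 1.0, "INT4": 0.5}

def weight_memory_gb(num_params, precision):
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9

num_params = 7_000_000_000  # hypothetical 7B-parameter model
footprint = {p: weight_memory_gb(num_params, p) for p in BYTES_PER_PARAM}

for precision, gb in footprint.items():
    print(f"{precision}: {gb:.1f} GB")
```

For this model size the weights take 28 GB in FP32, 14 GB in FP16, 7 GB in INT8, and 3.5 GB in INT4.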

🛠️ Quantization Methods

📐 Post-Training Quantization (PTQ)

Quantize an already-trained model without additional training.

    import torch

    # Dynamic quantization (runtime quantization): weights are stored as INT8,
    # activations are quantized on the fly
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear, torch.nn.LSTM}, dtype=torch.qint8
    )

    # Static quantization (calibration-based); the model must be in eval mode
    # and wrap its inputs/outputs in QuantStub/DeQuantStub
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)

    # Calibrate with sample data
    with torch.no_grad():
        for data in calibration_loader:
            model(data)

    torch.quantization.convert(model, inplace=True)

Pros: Fast, no retraining needed
Cons: Potential accuracy loss
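The calibration step in static quantization amounts to observing value ranges on sample data and deriving quantization parameters from them. Below is a simplified min/max sketch for unsigned 8-bit affine quantization; real observers are more elaborate, and the batches and function names here are hypothetical.

```python
def calibrate(batches, qmin=0, qmax=255):
    """Derive an affine scale and zero point from the min/max seen in calibration data."""
    lo = min(min(batch) for batch in batches)
    hi = max(max(batch) for batch in batches)
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # the range must include zero
    scale = (hi - lo) / (qmax - qmin)
    if scale == 0:
        scale = 1.0
    zero_point = round(qmin - lo / scale)
    return scale, zero_point

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to an integer code, saturating outside the calibrated range."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

# Hypothetical activation values collected from two calibration batches
calibration_batches = [[-1.0, 0.5, 2.0], [0.1, 3.0, -0.5]]
scale, zero_point = calibrate(calibration_batches)

in_range = dequantize(quantize(1.0, scale, zero_point), scale, zero_point)
saturated = quantize(5.0, scale, zero_point)  # outside the calibrated range
print(scale, zero_point, in_range, saturated)
```

Values outside the observed range saturate at the largest code, which is why calibration data should be representative of real inputs.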

🎯 Quantization-Aware Training (QAT)

Train the model with quantization simulation to maintain accuracy.

    import torch

    # QAT setup (the model must be in training mode)
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)

    # Train with fake quantization
    for epoch in range(num_epochs):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    # Convert to actual quantized model
    model.eval()
    torch.quantization.convert(model, inplace=True)

Pros: Better accuracy retention
Cons: Requires retraining time
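The "quantization simulation" in QAT is often called fake quantization: values are rounded to the quantization grid in the forward pass but kept as floats, and the backward pass treats the rounding as an identity function (a straight-through estimator). The scalar sketch below illustrates that idea under simplified assumptions; the step size and function names are hypothetical.

```python
def fake_quantize(x, scale, qmin=-128, qmax=127):
    """Quantize then immediately dequantize, so the value stays a float."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale

def ste_gradient(x, scale, qmin=-128, qmax=127):
    """Straight-through estimator: pass the gradient through inside the
    representable range, block it where the value was clipped."""
    return 1.0 if qmin * scale <= x <= qmax * scale else 0.0

scale = 0.1  # hypothetical step size
for x in [0.234, -3.07, 20.0]:
    print(x, fake_quantize(x, scale), ste_gradient(x, scale))
```

Because the model sees rounding and clipping during training, its weights can adapt to them, which is where the better accuracy retention comes from.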

📈 Quantization Performance Impact

Memory usage relative to FP32:
FP32 (100%) → FP16 (50%) → INT8 (25%) → INT4 (12.5%)

🔧 QLoRA: Quantized Low-Rank Adaptation

QLoRA combines quantization with Low-Rank Adaptation (LoRA) to enable efficient fine-tuning of large models on consumer hardware.

🏗️ QLoRA Architecture

Core Components

  1. 4-bit Quantization: Base model stored in 4-bit precision
  2. LoRA Adapters: Small trainable matrices added to key layers
  3. Gradient Checkpointing: Memory optimization during training
  4. Paged Optimizers: Handle memory spikes efficiently

    # QLoRA implementation example
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    # 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )

    # Load quantized base model
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        quantization_config=bnb_config,
        device_map="auto"
    )

    # Prepare the quantized model for training (enables gradient checkpointing)
    model = prepare_model_for_kbit_training(model)

    # LoRA configuration
    lora_config = LoraConfig(
        r=16,               # Low-rank dimension
        lora_alpha=32,      # LoRA scaling parameter
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.1,
        task_type="CAUSAL_LM",
    )

    # Add LoRA adapters
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
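To see why the adapters are small, the sketch below counts trainable parameters for a single LoRA adapter and shows the low-rank update on a toy example. It assumes a 4096 × 4096 projection matrix as an illustrative size and reuses the r=16 setting from the configuration above; the helper names are hypothetical.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters in one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r

def matvec(matrix, vector):
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x), with W frozen and only A, B trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + (alpha / r) * u for b, u in zip(base, update)]

# Assumed 4096 x 4096 projection with r=16
full = 4096 * 4096
adapter = lora_param_count(4096, 4096, r=16)
print(f"full: {full:,}  adapter: {adapter:,}  ratio: {adapter / full:.2%}")

# Toy check: with B initialised to zeros, the adapter leaves the output unchanged
W = [[1.0, 2.0], [3.0, 4.0]]
A = [[0.5, -0.5]]   # r = 1
B = [[0.0], [0.0]]
y = lora_forward(W, A, B, [1.0, 1.0], alpha=2, r=1)
print(y)
```

For this layer the adapter holds 131,072 trainable values against 16,777,216 frozen ones, under 1% of the layer, which is why optimizer state and gradients fit on far smaller hardware.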

💡 QLoRA Benefits

Method             Memory (GB)   Training Time   Performance Retention   Hardware Requirements
Full Fine-tuning   80-120        1x              100%                    Multiple A100s
LoRA               40-60         0.8x            95-98%                  Single A100
QLoRA              12-16         0.9x            90-95%                  RTX 3090/4090

🚀 Advanced Compression Techniques

✂️ Structured Pruning

Remove entire neurons, attention heads, or layers based on importance metrics.

    import torch
    import torch.nn.utils.prune as prune

    # Structured magnitude-based pruning: zero out the 30% of output neurons
    # (weight rows, dim=0) with the smallest L2 norm in each linear layer
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.ln_structured(module, name='weight', amount=0.3, n=2, dim=0)

    # Make pruning permanent
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.remove(module, 'weight')
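What removing "entire neurons" means can be shown without any framework. The sketch below scores each output neuron (one row of the weight matrix) by its L2 norm and zeroes the lowest-scoring rows. The L2 norm is only one possible importance metric, and the matrix values and the `prune_rows` helper are hypothetical.

```python
import math

def prune_rows(weight, amount):
    """Zero out the fraction `amount` of rows with the smallest L2 norm.
    Each row holds the incoming weights of one output neuron."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in weight]
    n_prune = int(round(amount * len(weight)))
    # Indices of the rows with the smallest norms
    to_prune = set(sorted(range(len(weight)), key=lambda i: norms[i])[:n_prune])
    return [[0.0] * len(row) if i in to_prune else list(row)
            for i, row in enumerate(weight)]

# Hypothetical 4 x 3 weight matrix (4 output neurons, 3 inputs)
weight = [
    [0.9, -1.2, 0.4],
    [0.01, 0.02, -0.01],
    [0.5, 0.5, 0.5],
    [-0.03, 0.0, 0.02],
]
pruned = prune_rows(weight, amount=0.5)
for row in pruned:
    print(row)
```

Because whole rows become zero, the corresponding neurons can later be dropped from the layer entirely, which is what yields real speedups on standard hardware.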

🎭 Progressive Distillation

Gradually reduce model size through multiple distillation stages.

175B → 13B → 7B → 1.3B
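The staged schedule above can be sketched as a loop in which each student becomes the teacher of the next stage. The `distill` callback is a hypothetical placeholder for a full training run, so this only illustrates the control flow and the per-stage size reduction.

```python
def progressive_distillation(sizes_b, distill):
    """Run one distillation stage per consecutive pair of model sizes
    (in billions of parameters) and record the reduction factor."""
    history = []
    teacher = sizes_b[0]
    for student in sizes_b[1:]:
        distill(teacher, student)  # hypothetical training step
        history.append((teacher, student, teacher / student))
        teacher = student          # the student is the next stage's teacher
    return history

def distill(teacher_size, student_size):
    """Placeholder: a real implementation would train the student here."""
    pass

stages = progressive_distillation([175, 13, 7, 1.3], distill)
for teacher, student, factor in stages:
    print(f"{teacher}B -> {student}B (about {factor:.1f}x smaller)")
```

Smaller steps between stages keep the capacity gap between each teacher and student manageable, at the cost of more training runs.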

🔄 Dynamic Quantization

Adapt quantization precision per layer or per channel, based on layer sensitivity and observed activation ranges, instead of using a single bit width for the whole model.

  • Mixed Precision: Different layers use different bit widths
  • Adaptive Quantization: Adjust precision based on activation ranges
  • Channel-wise Quantization: Per-channel scaling for better accuracy
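The benefit of channel-wise scaling shows up when channels have very different value ranges. The sketch below compares the round-trip error of one shared scale against one scale per channel, using symmetric INT8 quantization on a made-up weight matrix; all names and values are illustrative.

```python
def absmax_scale(values):
    """Symmetric INT8 scale so the largest magnitude maps to 127."""
    return max(abs(v) for v in values) / 127

def quant_error(values, scale):
    """Total absolute round-trip error of symmetric INT8 quantization."""
    total = 0.0
    for v in values:
        q = max(-127, min(127, round(v / scale)))
        total += abs(v - q * scale)
    return total

# Hypothetical weights: two output channels with very different ranges
channels = [
    [0.010, -0.020, 0.015],
    [8.0, -12.7, 5.0],
]

# Per-tensor: one scale shared by all channels
shared = absmax_scale([v for ch in channels for v in ch])
per_tensor_error = sum(quant_error(ch, shared) for ch in channels)

# Per-channel: each channel gets its own scale
per_channel_error = sum(quant_error(ch, absmax_scale(ch)) for ch in channels)

print(per_tensor_error, per_channel_error)
```

With a shared scale the small-valued channel collapses to zero, while per-channel scales preserve it, at the cost of storing one extra scale per channel.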

🎯 Best Practices for Model Compression

  • Start with Distillation: Often provides the best accuracy-size trade-off
  • Combine Techniques: Use distillation + quantization + pruning together
  • Calibration Data Quality: Use representative data for quantization calibration
  • Task-Specific Optimization: Tailor compression to your specific use case
  • Iterative Approach: Gradually increase compression while monitoring performance
  • Hardware Consideration: Choose techniques compatible with target deployment hardware

Key Insight: The goal isn't just smaller models, but maintaining the right balance between size, speed, and accuracy for your specific application requirements.