🧠 Model Distillation: Using the CIFAR-10 Dataset

🔍 What is Model Distillation?

Think of model distillation like file compression, but for AI models. Just like you can zip a big folder to save space, we shrink a large neural network (called the “Teacher”) into a much smaller one (called the “Student”) while trying to keep it smart.

This is especially useful when you want to run models on phones, edge devices, or any low-resource environments.


📋 Abstract

This repository presents a comprehensive implementation and analysis of knowledge distillation for image classification on CIFAR-10. We demonstrate how a lightweight student network can learn from a larger teacher network, achieving significant computational efficiency gains while maintaining competitive performance. Our implementation includes detailed performance metrics, model analysis, and visualization tools for understanding the distillation process.

🎯 Research Objectives

  1. Investigate the effectiveness of knowledge distillation for model compression
  2. Analyze the trade-offs between model size, inference speed, and accuracy
  3. Provide a reproducible framework for knowledge distillation experiments
  4. Demonstrate practical applications of model compression techniques

🏗️ Architecture Overview

Teacher Model (ResNet18)

  • Architecture: Deep convolutional network with residual connections
  • Parameters: 4,664,970 (~4.7M)
  • Size: 17.81 MB
  • Design Philosophy: Maximizes representational capacity for knowledge extraction

Student Model (Lightweight CNN)

  • Architecture: Compact 4-block convolutional network
  • Parameters: 106,826 (~107K)
  • Size: 0.41 MB
  • Compression Ratio: 43.7× parameter reduction, 43.4× size reduction
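
For illustration, a compact 4-block CNN at roughly this scale can be sketched as below. The channel widths here are assumptions chosen for readability, so the parameter count will not exactly match the 106,826 reported above; the actual student architecture is defined in the repository code.

import torch
import torch.nn as nn

class StudentCNN(nn.Module):
    """Illustrative 4-block compact CNN for 32x32 CIFAR-10 inputs (not the exact repo model)."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        def block(c_in, c_out):
            # Conv -> BatchNorm -> ReLU -> 2x2 downsampling
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            )
        self.features = nn.Sequential(block(3, 16), block(16, 32), block(32, 64), block(64, 128))
        self.classifier = nn.Linear(128 * 2 * 2, num_classes)  # 32x32 input -> 2x2 after four poolings

    def forward(self, x):
        x = self.features(x)
        return self.classifier(torch.flatten(x, 1))

# Parameter count of this sketch (will differ slightly from the 106,826 reported above):
student = StudentCNN()
print(sum(p.numel() for p in student.parameters() if p.requires_grad))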

📊 Experimental Results

Performance Metrics

Model             | Accuracy (%) | Size (MB) | Parameters | Inference Time (ms) | Efficiency Score*
Teacher           | 83.13        | 17.81     | 4,664,970  | 111.29              | 0.75
Student Baseline  | 72.70        | 0.41      | 106,826    | 27.76               | 2.62
Student Distilled | 70.86        | 0.41      | 106,826    | 25.82               | 2.74

*Efficiency Score = Accuracy / (Size × Inference Time)

Key Findings

🔍 Surprising Result: Distillation Performance Gap

  • Expected: Distilled student outperforms baseline student
  • Observed: Distilled student (70.86%) scores 1.84 percentage points lower than the baseline (72.70%)
  • Hypothesis: Potential optimization challenges or hyperparameter sensitivity

Computational Efficiency Gains

  • Size Reduction: 97.7% smaller model (17.81MB → 0.41MB)
  • Speed Improvement: 4.3× faster inference (111ms → 26ms)
  • Parameter Efficiency: 43.7× fewer parameters with only a 14.8% relative drop in accuracy

📈 Scaling Analysis

  • Teacher-to-student knowledge gap: 12.27 percentage point accuracy difference
  • Size-to-performance ratio: Excellent for resource-constrained environments
  • Inference efficiency: Suitable for real-time applications

🧪 Methodology

Knowledge Distillation Framework

Our implementation follows the seminal work by Hinton et al. (2015) with the following components:

Loss Function

L_total = α × L_soft + (1-α) × L_hard

Where:

  • L_soft: KL divergence between teacher and student soft predictions
  • L_hard: Cross-entropy loss with ground truth labels
  • α = 0.7: Weighting factor (soft knowledge emphasis)
  • T = 4.0: Temperature parameter for probability smoothing
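
A minimal PyTorch sketch of this combined objective, assuming the standard Hinton-style formulation (the T² factor is the conventional gradient rescaling and is an assumption about the repo's implementation, not confirmed by it):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # L_soft: KL divergence between temperature-scaled teacher and student distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # conventional T^2 scaling from Hinton et al. (2015)
    # L_hard: standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard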

Training Configuration

  • Optimizer: Adam with weight decay (1e-4)
  • Learning Rate: 0.001 with Step LR scheduler
  • Batch Size: 128
  • Epochs: 15 for all models
  • Data Augmentation: Random crop, horizontal flip, normalization
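
Wired together, the configuration above might look roughly like the following sketch. The StepLR step size, gamma, and normalization statistics are assumptions (the exact values live in the repository code), and student, teacher, and distillation_loss refer to the earlier sketches and the pretrained teacher.

import torch
import torchvision
from torchvision import transforms

# Augmentations listed above; normalization stats are the commonly used CIFAR-10 values (assumed).
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
train_set = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

optimizer = torch.optim.Adam(student.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # step size/gamma assumed

teacher.eval()  # teacher is frozen during distillation
for epoch in range(15):
    student.train()
    for images, labels in train_loader:
        with torch.no_grad():
            teacher_logits = teacher(images)
        loss = distillation_loss(student(images), teacher_logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()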

Evaluation Metrics

  1. Accuracy: Top-1 classification accuracy on CIFAR-10 test set
  2. Model Size: Memory footprint including parameters and buffers
  3. Inference Time: Average forward pass time over 100 batches
  4. Parameter Count: Total trainable parameters
  5. Compression Ratio: Teacher size / Student size
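
Metrics 2-4 can be measured with short helpers along these lines (a sketch; the exact measurement code is in the repository, and timings depend on hardware):

import time
import torch

def count_parameters(model):
    # Metric 4: total trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def model_size_mb(model):
    # Metric 2: memory footprint of parameters and buffers, in MB.
    size = sum(p.numel() * p.element_size() for p in model.parameters())
    size += sum(b.numel() * b.element_size() for b in model.buffers())
    return size / (1024 ** 2)

@torch.no_grad()
def avg_inference_time_ms(model, n_batches=100, batch_size=128):
    # Metric 3: average forward-pass time over n_batches of CIFAR-10-shaped inputs.
    model.eval()
    dummy = torch.randn(batch_size, 3, 32, 32)
    start = time.perf_counter()
    for _ in range(n_batches):
        model(dummy)
    return (time.perf_counter() - start) / n_batches * 1000.0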

🔬 Analysis and Discussion

Distillation Challenges Observed

1. Negative Transfer Phenomenon

The distilled student's underperformance compared to the baseline suggests:

  • Capacity Mismatch: Student network may be too small to effectively capture teacher knowledge
  • Temperature Sensitivity: T = 4.0 might be suboptimal for this architecture pair
  • Loss Balance: α = 0.7 may overemphasize soft targets

2. Potential Improvements

  • Hyperparameter Tuning: Systematic search for optimal T and α
  • Progressive Distillation: Multi-step knowledge transfer
  • Feature-level Distillation: Intermediate layer knowledge transfer
  • Attention Transfer: Focus on important spatial regions

Practical Implications

Success Metrics

  • Deployment Viability: 97.7% size reduction enables mobile deployment
  • Real-time Processing: 4.3× speed improvement for latency-critical applications
  • Resource Efficiency: Significant reduction in computational requirements

⚠️ Limitations

  • Accuracy Trade-off: A 12.27 percentage point accuracy drop may be significant for some applications
  • Distillation Effectiveness: Need for methodology refinement
  • Architecture Dependency: Results may vary with different model architectures

📈 Recommendations for Future Work

Immediate Improvements

  1. Hyperparameter Optimization: Grid search for T ∈ [1, 2, 3, 4, 5] and α ∈ [0.1, 0.3, 0.5, 0.7, 0.9]; a simple search loop is sketched after this list
  2. Student Architecture: Experiment with slightly larger student networks
  3. Training Duration: Extended training with learning rate scheduling
  4. Ensemble Distillation: Multiple teacher models for robust knowledge transfer
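
A simple loop for the grid search in item 1 could look like this; train_and_evaluate is a hypothetical helper that wraps the training loop above and returns test accuracy, and does not exist in the current repo:

results = {}
for T in [1, 2, 3, 4, 5]:
    for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
        # hypothetical helper: trains a fresh student with these settings and returns test accuracy
        results[(T, alpha)] = train_and_evaluate(T=T, alpha=alpha)

best = max(results, key=results.get)
print(f"Best: T={best[0]}, alpha={best[1]}, accuracy={results[best]:.2f}%")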

Advanced Techniques

  1. Attention-based Distillation: Transfer spatial attention maps (a minimal loss sketch follows this list)
  2. Progressive Knowledge Transfer: Curriculum learning approach
  3. Multi-level Feature Distillation: Intermediate layer supervision
  4. Adversarial Distillation: GAN-based knowledge transfer
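
As an illustration of item 1, an activation-based attention-transfer loss in the spirit of Zagoruyko & Komodakis (2017) might be sketched as follows. Exposing the intermediate feature maps (e.g., via forward hooks) is left out, and none of this exists in the current code:

import torch
import torch.nn.functional as F

def attention_map(feature_map):
    # Collapse channels of a (B, C, H, W) activation into an L2-normalized spatial map.
    attn = feature_map.pow(2).mean(dim=1)          # (B, H, W)
    return F.normalize(attn.flatten(1), dim=1)     # (B, H*W)

def attention_transfer_loss(student_feats, teacher_feats):
    # Match attention maps of paired intermediate layers; teacher maps are pooled
    # down to the student's spatial resolution when the sizes differ.
    loss = 0.0
    for s, t in zip(student_feats, teacher_feats):
        t = F.adaptive_avg_pool2d(t, s.shape[-2:])
        loss = loss + (attention_map(s) - attention_map(t)).pow(2).mean()
    return loss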

🚀 Usage Instructions

Prerequisites

pip install torch torchvision matplotlib pandas numpy

📊 Visualization and Analysis

The implementation includes comprehensive visualization:

  • Accuracy Comparison: Bar charts showing model performance
  • Size Analysis: Model compression visualization
  • Training Curves: Learning progression over epochs
  • Efficiency Metrics: Speed vs. accuracy trade-offs
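
For example, the accuracy comparison can be recreated from the results table with a few lines of matplotlib (a sketch, not the repository's plotting code):

import matplotlib.pyplot as plt

models = ["Teacher", "Student Baseline", "Student Distilled"]
accuracy = [83.13, 72.70, 70.86]  # values from the results table above

plt.bar(models, accuracy, color=["tab:blue", "tab:orange", "tab:green"])
plt.ylabel("Top-1 Accuracy (%)")
plt.ylim(0, 100)
plt.title("CIFAR-10 Accuracy Comparison")
plt.tight_layout()
plt.show()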

🔗 References

  1. Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv:1503.02531
  2. Distillation in PyTorch (Tutorial)

🤝 Contributing

Contributions are welcome! Please feel free to submit pull requests or open issues for:

  • Bug fixes
  • Performance improvements
  • New distillation techniques
  • Additional evaluation metrics
  • Documentation improvements

For questions, suggestions, or collaborations, please open an issue or reach out through the repository's discussion section.


This research demonstrates the practical viability of knowledge distillation for model compression while highlighting areas for methodological improvement. The significant computational gains achieved make this approach valuable for deployment in resource-constrained environments, despite the observed accuracy trade-offs.
