Meta-Policy Knowledge Distillation for Training Compact Student Models
MPDistil is a teacher-student collaborative knowledge distillation framework that enables compact student models to outperform larger teacher models through meta-learning and curriculum learning.
Based on the ICLR 2024 paper: A Good Learner can Teach Better: Teacher-Student Collaborative Knowledge Distillation
- Superior Performance: A 6-layer BERT student matches or outperforms the 12-layer BERT teacher on 5/6 SuperGLUE tasks
- 4-Phase Training: Teacher fine-tuning → PKD → Meta-teacher → Curriculum learning
- Simple API: Easy-to-use `.train()` method with full control over all phases
- Flexible Metrics: Built-in support for accuracy, F1, MCC, and correlation via HuggingFace `evaluate`
- Customizable: Works with any HuggingFace model and custom datasets
- Colab-Ready: GPU-optimized for Google Colab environments
- Easy Installation: Single pip command to get started
MPDistil consists of 4 training phases:
- Teacher Fine-tuning: Fine-tune teacher model on the target task
- Student PKD: Distill the teacher into the student with Patient Knowledge Distillation (PKD)
- Meta-Teacher Learning: Collaborative or competitive loss for meta-learning
- Curriculum Learning: Reinforcement learning-based task selection
Install directly from GitHub:

pip install git+https://github.com/parmanu-lcs2/mpdistil.git

Or clone the repository and install in editable mode:

git clone https://github.com/parmanu-lcs2/mpdistil.git
cd mpdistil
pip install -e .

Or install from PyPI:

pip install mpdistil

Quick start:

from mpdistil import MPDistil, load_superglue_dataset
# Load data
loaders, num_labels = load_superglue_dataset('CB')
# Initialize MPDistil
model = MPDistil(
task_name='CB',
num_labels=num_labels,
metric='f1', # Options: 'accuracy', 'f1', 'mcc', 'correlation', 'auto'
teacher_model='bert-base-uncased',
student_model='bert-base-uncased',
student_layers=6
)
# Train with all 4 phases
history = model.train(
train_loader=loaders['train'],
val_loader=loaders['val'],
teacher_epochs=10, # Phase 1
student_epochs=10, # Phase 2
meta_epochs=1 # Phase 3 (NEW!)
)
# Save trained student
model.save_student('./my_student_model')
# Make predictions
predictions = model.predict(loaders['test'])

Training on a custom dataset:

from mpdistil import MPDistil, create_simple_dataloader
# Prepare your data
texts = [("This is text A", "This is text B"), ...]
labels = [0, 1, 0, ...]
# Create DataLoader
train_loader = create_simple_dataloader(
texts=texts,
labels=labels,
tokenizer_name='bert-base-uncased',
max_length=128,
batch_size=8
)
# Train model
model = MPDistil(task_name='MyTask', num_labels=2, metric='accuracy')
history = model.train(train_loader, val_loader)  # build val_loader the same way as train_loader

# Load multiple tasks for curriculum learning
cb_loaders, _ = load_superglue_dataset('CB')
rte_loaders, _ = load_superglue_dataset('RTE')
boolq_loaders, _ = load_superglue_dataset('BoolQ')
# Train with curriculum learning
history = model.train(
train_loader=cb_loaders['train'],
val_loader=cb_loaders['val'],
meta_loaders={
'RTE': rte_loaders['val'],
'BoolQ': boolq_loaders['val']
},
teacher_epochs=10, # Phase 1
student_epochs=10, # Phase 2
meta_epochs=3, # Phase 3 - can train for multiple epochs!
num_episodes=200 # Phase 4 - curriculum learning episodes
)

MPDistil(
task_name: str, # Name of the main task
num_labels: int, # Number of output classes
metric: str = 'accuracy', # Metric: 'accuracy', 'f1', 'mcc', 'correlation', 'auto'
teacher_model: str = 'bert-base-uncased', # HuggingFace model name
student_model: str = 'bert-base-uncased', # HuggingFace model name
student_layers: int = 6, # Number of layers for student
device: str = 'auto', # 'auto', 'cuda', or 'cpu'
output_dir: str = './mpdistil_outputs'
)

| Method | Description |
|---|---|
| `train(train_loader, val_loader, **kwargs)` | Train the model (all 4 phases) |
| `predict(test_loader)` | Generate predictions |
| `save_student(path)` | Save student model in HuggingFace format |
| `load_student(path)` | Load a saved student model |
| `save_predictions(predictions, path, label_mapping)` | Save predictions to TSV |
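A minimal end-to-end sketch of these methods, reusing the `model` and `loaders` objects from the quick start above; the label names passed to `save_predictions` are illustrative, not fixed by the library:

```python
# Generate predictions on the test split
predictions = model.predict(loaders['test'])

# Write predictions to TSV; the index-to-label mapping below is illustrative
model.save_predictions(predictions, './predictions.tsv',
                       {0: 'entailment', 1: 'contradiction', 2: 'neutral'})

# Persist the distilled student in HuggingFace format and reload it later
model.save_student('./my_student_model')
model.load_student('./my_student_model')
```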
Configure training hyperparameters:
from mpdistil import TrainingConfig
config = TrainingConfig(
# Phase 1: Teacher
teacher_epochs=10,
teacher_lr=2e-5,
# Phase 2: Student PKD
student_epochs=10,
student_lr=3e-5,
alpha=0.5, # Soft loss weight
beta=100.0, # PKD loss weight
temperature=5.0, # Distillation temperature
# Phase 3: Meta-Teacher
meta_epochs=3, # NEW! Meta-teacher can train for multiple epochs
meta_lr=1e-3,
use_competitive_loss=False, # Use collaborative loss
# Phase 4: Curriculum
num_episodes=200,
reward_type='binary', # or 'real'
# General
batch_size=8,
seed=42,
report_to=None # Options: 'wandb', 'tensorboard', None
)
history = model.train(train_loader, val_loader, config=config)

| Parameter | Type | Default | Description |
|---|---|---|---|
| `teacher_epochs` | int | 10 | Phase 1: Teacher training epochs |
| `student_epochs` | int | 10 | Phase 2: Student training epochs |
| `meta_epochs` | int | 1 | Phase 3: Meta-teacher training epochs (NEW!) |
| `num_episodes` | int | 200 | Phase 4: Curriculum learning episodes |
| `teacher_lr` | float | 2e-5 | Teacher learning rate |
| `student_lr` | float | 3e-5 | Student learning rate |
| `meta_lr` | float | 1e-3 | Meta-learning rate |
| `alpha` | float | 0.5 | Soft vs hard loss weight |
| `beta` | float | 100.0 | PKD loss weight |
| `temperature` | float | 5.0 | Distillation temperature |
| `use_competitive_loss` | bool | False | Competitive vs collaborative meta-teacher loss |
| `reward_type` | str | 'binary' | 'binary' or 'real' |
| `batch_size` | int | 8 | Batch size |
| `seed` | int | 42 | Random seed |
Performance on SuperGLUE tasks (BERT-base teacher → BERT-6L student):
| Model | BoolQ | CB | COPA | RTE | WiC | WSC |
|---|---|---|---|---|---|---|
| Teacher | 75.3 | 83.9 | 63.0 | 67.1 | 57.1 | 64.4 |
| Student (Undistilled) | 71.6 | 75.0 | 53.0 | 64.6 | 56.0 | 63.5 |
| MPDistil (Ours) | 73.4 | 83.9 | 70.0 | 67.5 | 59.6 | 65.4 |
The student matches or outperforms the teacher on 5/6 tasks!
See the examples/ directory for Jupyter notebooks:
- GLUE.ipynb: Usage with SuperGLUE
MPDistil supports multiple evaluation metrics via the HuggingFace `evaluate` library:

| Metric | Use Case | Returns |
|---|---|---|
| `'accuracy'` | Standard classification (default) | `{'acc': float}` |
| `'f1'` | Imbalanced datasets, multi-class | `{'acc': float, 'f1': float, 'acc_and_f1': float}` |
| `'mcc'` | Binary classification, imbalanced | `{'mcc': float}` |
| `'correlation'` | Regression tasks (STS-B) | `{'pearson': float, 'spearmanr': float}` |
| `'auto'` | Auto-detect based on task | Task-specific metric |
# Accuracy (default)
model = MPDistil(task_name='BoolQ', num_labels=2, metric='accuracy')
# F1 score (recommended for CB, MultiRC)
model = MPDistil(task_name='CB', num_labels=3, metric='f1')
# Matthews Correlation (for binary with imbalance)
model = MPDistil(task_name='CoLA', num_labels=2, metric='mcc')
# Correlation (for regression)
model = MPDistil(task_name='STS-B', num_labels=1, metric='correlation', is_regression=True)
# Auto-detect
model = MPDistil(task_name='CB', num_labels=3, metric='auto')  # Uses F1 for CB

Phase 1: Teacher Fine-tuning

Fine-tune a large teacher model (e.g., BERT-base) on your target task.
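MPDistil runs this phase internally when `teacher_epochs > 0`. As a rough sketch of what it corresponds to, assuming a `train_loader` like the ones built above that yields `input_ids`, `attention_mask`, and `labels` (the batch keys and the loop itself are illustrative, not the library's internals):

```python
import torch
from transformers import AutoModelForSequenceClassification

# Standard supervised fine-tuning of the teacher on the target task
teacher = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=3)  # e.g. 3 classes for CB
optimizer = torch.optim.AdamW(teacher.parameters(), lr=2e-5)  # teacher_lr

teacher.train()
for epoch in range(10):  # teacher_epochs
    for batch in train_loader:
        outputs = teacher(input_ids=batch['input_ids'],
                          attention_mask=batch['attention_mask'],
                          labels=batch['labels'])
        outputs.loss.backward()  # cross-entropy on the hard labels
        optimizer.step()
        optimizer.zero_grad()
```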
Phase 2: Student PKD

Train a smaller student model using:
- Soft targets from teacher (KL divergence)
- Hard labels (cross-entropy)
- Patient KD (intermediate feature matching)
Loss: α * soft_loss + (1 - α) * hard_loss + β * pkd_loss
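A minimal sketch of this combined objective, assuming access to the teacher and student logits plus one pair of intermediate [CLS] representations; the function and argument names are illustrative rather than the library's internals:

```python
import torch.nn.functional as F

def student_loss(student_logits, teacher_logits, labels,
                 student_hidden, teacher_hidden,
                 alpha=0.5, beta=100.0, temperature=5.0):
    # Soft targets: KL divergence between temperature-scaled distributions
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean',
    ) * (temperature ** 2)

    # Hard labels: cross-entropy against the ground truth
    hard_loss = F.cross_entropy(student_logits, labels)

    # Patient KD: match normalized intermediate representations
    pkd_loss = F.mse_loss(F.normalize(student_hidden, dim=-1),
                          F.normalize(teacher_hidden, dim=-1))

    return alpha * soft_loss + (1 - alpha) * hard_loss + beta * pkd_loss
```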
Phase 3: Meta-Teacher Learning

Create a meta-teacher (T') that learns from both teacher and student representations:
Collaborative loss (default):
L = 0.5 * CE(T'(h_teacher), y) + 0.5 * CE(T'(h_student), y)
Competitive loss:
L = -mean(P_teacher) + mean(P_student) + CE_loss
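The sketch below spells out both objectives, with `meta_head` standing in for the meta-teacher classifier T' applied to pooled teacher/student representations; it reads P_teacher / P_student as the probabilities assigned to the gold label and takes the CE term on the student branch, both of which are assumptions rather than the exact library implementation:

```python
import torch.nn.functional as F

def collaborative_loss(meta_head, h_teacher, h_student, labels):
    # Meta-teacher T' scores both representations; the two cross-entropy
    # terms are weighted equally, as in the formula above
    return (0.5 * F.cross_entropy(meta_head(h_teacher), labels)
            + 0.5 * F.cross_entropy(meta_head(h_student), labels))

def competitive_loss(meta_head, h_teacher, h_student, labels):
    t_logits, s_logits = meta_head(h_teacher), meta_head(h_student)
    # Probability the meta-teacher assigns to the gold label for each branch
    p_teacher = F.softmax(t_logits, dim=-1).gather(1, labels.unsqueeze(1))
    p_student = F.softmax(s_logits, dim=-1).gather(1, labels.unsqueeze(1))
    ce = F.cross_entropy(s_logits, labels)  # CE term (student branch here)
    return -p_teacher.mean() + p_student.mean() + ce
```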
Phase 4: Curriculum Learning

Use reinforcement learning to select which auxiliary tasks help the student learn:
- Action model selects next task
- Reward based on student improvement over teacher
- REINFORCE algorithm updates policy
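A compact sketch of the episode loop, with a hypothetical `evaluate_on` helper standing in for training and evaluating the student and teacher on the selected task; the policy parameterization, reward, and helper are illustrative, not the library's action model:

```python
import torch
import torch.nn.functional as F

task_names = ['RTE', 'BoolQ']                     # auxiliary tasks
policy_logits = torch.zeros(len(task_names), requires_grad=True)
optimizer = torch.optim.Adam([policy_logits], lr=1e-3)

def evaluate_on(task):
    # Placeholder: fine-tune/evaluate the student and teacher on `task`
    # and return their validation scores
    return 0.70, 0.68

for episode in range(200):                        # num_episodes
    dist = torch.distributions.Categorical(F.softmax(policy_logits, dim=-1))
    action = dist.sample()                        # action model picks a task

    student_score, teacher_score = evaluate_on(task_names[action.item()])
    # Binary reward: +1 if the student improves over the teacher, else -1
    reward = 1.0 if student_score > teacher_score else -1.0

    # REINFORCE: scale the negative log-probability of the action by the reward
    loss = -dist.log_prob(action) * reward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```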
Using a different teacher and student architecture:

model = MPDistil(
task_name='MyTask',
num_labels=3,
teacher_model='roberta-large',
student_model='distilbert-base-uncased',
student_layers=6
)

Logging to Weights & Biases or TensorBoard:

history = model.train(
train_loader=train_loader,
val_loader=val_loader,
report_to='wandb', # Options: 'wandb', 'tensorboard', None
wandb_project='my-project',
wandb_run_name='experiment-1'
)

# Access student model
student = model.student
# Access teacher model
teacher = model.teacher
# Use with HuggingFace
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('./my_student_model')
model = AutoModel.from_pretrained('./my_student_model')

If you use MPDistil in your research, please cite:
@inproceedings{sengupta2024mpdistil,
title={A Good Learner can Teach Better: Teacher-Student Collaborative Knowledge Distillation},
author={Sengupta, Ayan and Dixit, Shantanu and Akhtar, Md Shad and Chakraborty, Tanmoy},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=Ixi4j6LtdX}
}

This project is licensed under the MIT License - see the LICENSE file for details.
Contributions are welcome! Please feel free to submit a Pull Request.
- Original paper: ICLR 2024
- Built with PyTorch and HuggingFace Transformers
For questions or issues, please open an issue on GitHub or contact the authors.
Made with ❤️ for the research community
