A PyTorch implementation of an image classification system based on the DINOv3 (self-DIstillation with NO labels) vision transformer. This project provides a complete training pipeline with distributed data parallel (DDP) support, advanced data augmentation, and multiple loss functions including supervised contrastive learning.
- DINOv3 Backbone: Leverages pre-trained DINOv3 ViT-B/16 for powerful feature extraction
- Distributed Training: Full DDP support for multi-GPU training
- Advanced Loss Functions:
  - Combined Cross-Entropy + Supervised Contrastive Loss (SDC)
  - EM-based Supervised Contrastive Loss with learnable class centers
- Smart Data Sampling: Weighted sampling for handling class imbalance
- Medical Image Augmentation: Specialized augmentation pipeline for medical imaging
- Comprehensive Metrics: Top-K accuracy, F1 score, recall, and more
- Gradient Accumulation: Memory-efficient training for large models
```
argument.py                # Data augmentation configurations
Classification_Metrics.py  # Evaluation metrics
classifier_dataset.py      # Dataset class with weighted sampling
data_sampler.py            # Distributed weighted sampler
dense_features_PCA.py      # Feature extraction and PCA visualization
LinearClassifier.py        # Linear classifier implementation
train_linear.py            # Main training script
dinov3/                    # DINOv3 model implementation
image/                     # Dataset directory
    train/                 # Training images
    val/                   # Validation images
    test/                  # Test images
pre_weight/                # Pre-trained model weights
```
```
python >= 3.8
torch >= 2.0.0
torchvision >= 0.15.0
torchmetrics
Pillow
tqdm
numpy
matplotlib
scikit-learn
```

Install dependencies:

```bash
pip install torch torchvision torchmetrics pillow tqdm numpy matplotlib scikit-learn
```

Organize your images in the following structure:
```
image/
    train/
        class1/
        class2/
        class3/
    val/
        class1/
        class2/
        class3/
    test/
        class1/
        class2/
        class3/
```
Download the DINOv3 ViT-B/16 pre-trained weights and place them in the `pre_weight/` directory.

Edit the `get_default_config()` function in `train_linear.py` to set your hyperparameters.
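As a rough illustration of what that config might contain, here is a hypothetical sketch; the field names and defaults below are assumptions, and the real ones are defined in `train_linear.py`.

```python
def get_default_config():
    # Hypothetical field names -- check train_linear.py for the actual keys and defaults.
    return {
        "data_root": "image",                 # dataset root containing train/ val/ test/
        "pretrained_weights": "pre_weight/dinov3_vitb16.pth",  # assumed filename
        "img_size": 224,
        "batch_size": 96,
        "lr": 0.01,
        "epochs": 100,
        "n_last_blocks": 4,                   # how many final ViT blocks to aggregate
        "avgpool_patchtokens": True,
        "use_amp": True,
        "grad_accum_steps": 1,
        "num_workers": 8,
    }
```

Then launch training: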
```bash
python train_linear.py
```

The script automatically detects available GPUs and uses DDP:

```bash
python train_linear.py
```

Extract and visualize features using PCA:

```bash
python dense_features_PCA.py
```

The classifier consists of:
- Frozen DINOv3 Backbone: Pre-trained ViT-B/16 (embedding dim: 768)
- Feature Aggregation: Concatenates features from the last N transformer blocks
- Linear Classifier: Single linear layer for classification
Feature dimension calculation:
- Without avgpool: `n_last_blocks × 768` (e.g., 4 × 768 = 3072)
- With avgpool: `(n_last_blocks + 1) × 768` (e.g., 5 × 768 = 3840)
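Below is a minimal sketch of how the aggregation and the linear head fit together, assuming the backbone exposes a DINO-style `get_intermediate_layers()`; it mirrors the dimension arithmetic above but is not the project's exact code.

```python
import torch
import torch.nn as nn


class LinearClassifier(nn.Module):
    """Single linear layer on top of frozen backbone features."""

    def __init__(self, dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(dim, num_classes)

    def forward(self, feats):
        return self.linear(feats)


@torch.no_grad()
def extract_features(backbone, images, n_last_blocks=4, avgpool=True):
    # DINO-style API assumed: each element is (patch_tokens, class_token) for one block
    blocks = backbone.get_intermediate_layers(images, n_last_blocks, return_class_token=True)
    feats = torch.cat([cls for _, cls in blocks], dim=-1)    # (B, n_last_blocks * 768)
    if avgpool:
        patch_avg = blocks[-1][0].mean(dim=1)                 # (B, 768) averaged patch tokens
        feats = torch.cat([feats, patch_avg], dim=-1)         # (B, (n_last_blocks + 1) * 768)
    return feats
```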
```
Loss = α × CrossEntropy + β × SupervisedContrastive
```

- **Cross-Entropy**: Standard classification loss
- **Supervised Contrastive**: Encourages same-class features to be closer and different-class features to be farther apart
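A rough sketch of this combined objective is shown below; the contrastive term follows the standard SupCon formulation on L2-normalized features, and `alpha`/`beta` stand in for the configured loss weights.

```python
import torch
import torch.nn.functional as F


def supervised_contrastive_loss(features, labels, temperature=0.07):
    """Pull same-class features together and push different-class features apart."""
    feats = F.normalize(features, dim=-1)
    sim = feats @ feats.t() / temperature                                   # (B, B) similarities
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=feats.device)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & not_self      # same-class pairs
    # log-softmax over all other samples, averaged over each anchor's positives
    log_prob = sim - torch.logsumexp(sim.masked_fill(~not_self, float("-inf")), dim=1, keepdim=True)
    mean_log_prob_pos = (log_prob * pos_mask).sum(1) / pos_mask.sum(1).clamp(min=1)
    return -mean_log_prob_pos.mean()


def combined_loss(logits, features, labels, alpha=1.0, beta=0.5):
    ce = F.cross_entropy(logits, labels)
    sdc = supervised_contrastive_loss(features, labels)
    return alpha * ce + beta * sdc
```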
An expectation-maximization variant with learnable class centers:
- E-step: Calculate responsibility (soft assignment) of samples to classes
- M-step: Update class centers based on responsibilities
- Supports multiple similarity metrics: dot product, cosine, and Euclidean distance
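One way the E/M steps could be realized is sketched below; the momentum-style center update and the temperature are assumptions, not the project's exact implementation.

```python
import torch
import torch.nn.functional as F


class EMClassCenters(torch.nn.Module):
    """Learnable class centers with an EM-style update (illustrative sketch)."""

    def __init__(self, num_classes, dim, momentum=0.9):
        super().__init__()
        self.centers = torch.nn.Parameter(torch.randn(num_classes, dim))
        self.momentum = momentum

    def e_step(self, features, temperature=0.1, metric="cosine"):
        # E-step: soft assignment (responsibility) of each sample to each class
        if metric == "cosine":
            sim = F.normalize(features, dim=-1) @ F.normalize(self.centers, dim=-1).t()
        elif metric == "euclidean":
            sim = -torch.cdist(features, self.centers)
        else:  # dot product
            sim = features @ self.centers.t()
        return F.softmax(sim / temperature, dim=-1)            # (B, num_classes)

    @torch.no_grad()
    def m_step(self, features, resp):
        # M-step: responsibility-weighted feature means, blended into the current centers
        new_centers = (resp.t() @ features) / resp.sum(0).clamp(min=1e-6).unsqueeze(1)
        self.centers.data.mul_(self.momentum).add_(new_centers, alpha=1 - self.momentum)
```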
Specialized augmentation for medical images:
- Random rotation (30°)
- Random horizontal/vertical flip
- Color jitter (brightness, contrast, saturation)
- Random affine transformations
- Gaussian blur
- Normalization with ImageNet statistics
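A rough torchvision equivalent of this pipeline is shown below; the exact magnitudes live in `argument.py`, so the values here are placeholders.

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    # ImageNet statistics
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```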
Class imbalance is handled with automatic weighted sampling based on the class distribution (sketched below).
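A minimal single-process sketch of frequency-based weighting; the project's `data_sampler.py` implements the distributed variant used with DDP.

```python
from collections import Counter

from torch.utils.data import WeightedRandomSampler


def make_weighted_sampler(labels):
    """Sample each image inversely to its class frequency so rare classes are seen as often as common ones."""
    counts = Counter(labels)
    weights = [1.0 / counts[y] for y in labels]
    return WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
```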
- Gradient accumulation: for limited GPU memory (see the sketch after this list)
- Automatic mixed precision (AMP): enable for faster training
- Learning rate scheduling: supports cosine annealing and step decay
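A minimal sketch of how gradient accumulation and AMP typically combine in the training loop; `model`, `loader`, `criterion`, `optimizer`, `use_amp`, and `accum_steps` are assumed to be defined elsewhere.

```python
import torch

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
optimizer.zero_grad()
for step, (images, labels) in enumerate(loader):
    with torch.cuda.amp.autocast(enabled=use_amp):
        # divide by accum_steps so the accumulated gradient matches a large-batch update
        loss = criterion(model(images.cuda()), labels.cuda()) / accum_steps
    scaler.scale(loss).backward()
    if (step + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping (see Tips)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```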
The training script automatically computes:
- Top-1 Accuracy: Percentage of correct top predictions
- Top-3 Accuracy: Percentage when true class is in top 3 predictions
- F1 Score: Harmonic mean of precision and recall (micro-average)
- Recall: True positive rate (micro-average)
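A small sketch of how these metrics could be computed from raw logits; the torchmetrics >= 0.11 functional API is assumed.

```python
from torchmetrics.functional import accuracy, f1_score, recall


def evaluate(logits, targets, num_classes):
    kwargs = dict(task="multiclass", num_classes=num_classes)
    return {
        "top1": accuracy(logits, targets, top_k=1, **kwargs),
        "top3": accuracy(logits, targets, top_k=3, **kwargs),
        "f1_micro": f1_score(logits, targets, average="micro", **kwargs),
        "recall_micro": recall(logits, targets, average="micro", **kwargs),
    }
```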
Models are automatically saved:
- `last.pth`: latest model checkpoint
- `best.pth`: best model based on validation accuracy
- `epoch_N.pth`: periodic snapshots every N epochs
- Batch Size: Start with 96 and adjust based on GPU memory
- Learning Rate: 0.01 works well for linear classifiers with SGD
- Feature Layers: Using the last 4 blocks (`n_last_blocks=4`) is a good balance
- Gradient Clipping: Set to 1.0 to prevent gradient explosion
- Validation Interval: Validate every epoch for small datasets
**Out of memory:**
- Reduce `batch_size`
- Enable gradient accumulation
- Reduce `img_size`
- Use fewer feature blocks (`n_last_blocks`)

**Slow training:**
- Enable DDP for multi-GPU training
- Increase `num_workers` for data loading
- Enable `use_amp` for mixed precision

**Poor performance:**
- Adjust the learning rate
- Try different optimizers (SGD vs. AdamW)
- Tune the SDC loss weight
- Check data augmentation strength
This project is released under the MIT License.
- DINOv3 model from Meta AI Research
- PyTorch team for the excellent deep learning framework
- The open-source community for various tools and libraries
For questions and feedback, please open an issue on GitHub.
If you find this project helpful, please consider giving it a star!
Made with ❤️ by the community