Deep learning plant disease detection system using PyTorch and transfer learning. Classifies 38 crop diseases across 14 plant species with 95% accuracy on the PlantVillage dataset.
PlantGuard-AI is a computer vision system that helps farmers and agricultural professionals identify plant diseases from leaf images. Using PyTorch and ResNet18 transfer learning, the model achieves 95.02% accuracy on a held-out test set of 8,147 images spanning 38 disease classes.
Key Features:
- Transfer Learning with ResNet18 (ImageNet pre-trained weights)
- Custom Data Augmentation pipeline for robust predictions
- Production-Ready training loop with checkpointing and learning rate scheduling
- Comprehensive Evaluation with confusion matrices and per-class metrics
- Automated Visualizations for model performance analysis
- Test Accuracy: 95.02%
- Precision: 95.75%
- Recall: 95.02%
- F1-Score: 95.01%
- Training Time: ~2 hours (10 epochs on CPU)
- Dataset Size: 54,305 images across 38 classes
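
The reported recall matches the test accuracy, which is what support-weighted averaging produces: weighted recall is mathematically the same quantity as accuracy. Whether `utils/metrics.py` uses weighted averaging is an assumption here, and `weighted_metrics` is a hypothetical helper. The sketch only illustrates the relationship on a toy confusion matrix.

```python
def weighted_metrics(confusion):
    """Support-weighted precision, recall and F1 from a confusion matrix.

    confusion[i][j] = number of samples with true class i predicted as class j.
    """
    n = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(n))
    precision = recall = f1 = 0.0
    for c in range(n):
        support = sum(confusion[c])                        # true samples of class c
        predicted = sum(confusion[r][c] for r in range(n))  # samples predicted as c
        p = confusion[c][c] / predicted if predicted else 0.0
        r = confusion[c][c] / support if support else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        weight = support / total                           # class share of the data
        precision += weight * p
        recall += weight * r
        f1 += weight * f
    return {
        "accuracy": correct / total,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Toy 2-class example (made-up counts, not PlantVillage results)
toy_metrics = weighted_metrics([[8, 2], [1, 9]])
```

In `toy_metrics`, `recall` and `accuracy` agree, while `precision` and `f1` can differ from them, as they do in the figures above.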
# Clone the repository
git clone https://github.com/SNMiguel/PlantGuard-AI.git
cd PlantGuard-AI
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt

- Download the PlantVillage dataset from Kaggle
- Extract to data/raw/ (should contain 38 disease class folders)
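
Before training, it can help to confirm that the extraction produced the expected folder layout. The following is a minimal sketch, not part of the repository: `count_images_per_class` is a hypothetical helper, and the set of image extensions is an assumption.

```python
from pathlib import Path

# Assumed image extensions; adjust if the dataset uses others.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_class(data_dir):
    """Return {class_folder_name: image_count} for each subfolder of data_dir."""
    root = Path(data_dir)
    counts = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


# Example usage (run from the repository root once the data is extracted):
#   counts = count_images_per_class("data/raw")
#   print(len(counts), "class folders,", sum(counts.values()), "images")
```

If the first number printed is not 38, the archive was probably extracted one directory level too deep or too shallow.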
# Quick training (5 epochs)
python main.py --data_dir data/raw --epochs 5 --batch_size 32
# Full training (10 epochs, recommended)
python main.py --data_dir data/raw --epochs 10 --batch_size 32 --verbose
# Advanced training with frozen backbone
python main.py --data_dir data/raw --epochs 15 --freeze_backbone

# Predict on a single image
python predict.py --image_path path/to/leaf_image.jpg --model_path models/saved/resnet18_best.pth

PlantGuard-AI/
├── data/
│   └── raw/                  # PlantVillage dataset (38 disease folders)
├── models/
│   ├── resnet_model.py       # ResNet18 implementation with transfer learning
│   └── saved/                # Trained model checkpoints (.pth files)
├── utils/
│   ├── data_loader.py        # Dataset loading and splitting
│   ├── transforms.py         # Data augmentation pipeline
│   ├── metrics.py            # Evaluation metrics
│   └── visualizer.py         # Plotting and visualization
├── results/
│   └── visualizations/       # Training curves, confusion matrix, predictions
├── train.py                  # Training script
├── main.py                   # Main pipeline orchestrator
├── predict.py                # Single image prediction
├── requirements.txt          # Dependencies
└── README.md                 # Project documentation

The model classifies 38 disease classes across 14 crop species:
Crops: Apple, Blueberry, Cherry, Corn, Grape, Orange, Peach, Pepper, Potato, Raspberry, Soybean, Squash, Strawberry, Tomato
Example Diseases:
- Apple: Scab, Black rot, Cedar apple rust
- Tomato: Bacterial spot, Early blight, Late blight, Leaf mold, Septoria leaf spot, Spider mites, Target spot, Yellow leaf curl virus, Mosaic virus
- Potato: Early blight, Late blight
- Grape: Black rot, Esca, Leaf blight
- And 21 more disease and healthy classes...
- Base Model: ResNet18 (pre-trained on ImageNet)
- Transfer Learning: Fine-tuned final layer for 38-class classification
- Input Size: 224x224 RGB images
- Total Parameters: 11.2M (19K trainable when backbone frozen)
- Random horizontal/vertical flips
- Random rotation (±20°)
- Color jitter (brightness, contrast, saturation, hue)
- Random affine transformations
- ImageNet normalization
- Optimizer: Adam (lr=0.001)
- Loss Function: CrossEntropyLoss
- Scheduler: ReduceLROnPlateau
- Batch Size: 32
- Train/Val/Test Split: 70/15/15
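
One way to realize a 70/15/15 split is to shuffle the sample indices once and give the rounding remainder to the test set. For 54,305 images this yields 38,013 / 8,145 / 8,147, which matches the 8,147 test images mentioned earlier. The seed, the rounding rule, and the helper name `split_indices` are assumptions and are not taken from `utils/data_loader.py`.

```python
import random


def split_indices(n, train_pct=70, val_pct=15, seed=42):
    """Shuffle range(n) and cut it into train/val/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for a reproducible split

    n_train = n * train_pct // 100  # integer arithmetic avoids float rounding
    n_val = n * val_pct // 100

    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains goes to the test set
    return train, val, test


train_idx, val_idx, test_idx = split_indices(54305)
```

Fixing the seed keeps the test set identical across runs, so accuracy figures from different experiments stay comparable.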
(Check results/visualizations/ for generated plots)
# In main.py, change model parameter
python main.py --model resnet34  # Options: resnet18, resnet34, resnet50

python main.py --epochs 20 --batch_size 64 --freeze_backbone

- Add new disease folders to data/raw/
- Model will automatically detect and train on new classes
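
Automatic class detection of this kind usually amounts to treating each subfolder name as a label. The sketch below assumes that convention, which is also how torchvision's `ImageFolder` assigns indices (in sorted folder order). `discover_classes` is a hypothetical helper, not the repository's code.

```python
from pathlib import Path


def discover_classes(data_dir):
    """Map each subfolder of data_dir to a class index, in sorted name order."""
    classes = sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Example usage:
#   classes, class_to_idx = discover_classes("data/raw")
#   num_classes = len(classes)  # sizes the model's final layer
```

Because indices follow sorted order, adding a folder can shift the indices of existing classes and changes the output size, so checkpoints trained on the old class set need retraining.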
- Transfer learning significantly reduces training time and data requirements
- Data augmentation is critical for model generalization on agricultural images
- ResNet18 provides excellent balance between accuracy and computational efficiency
- Proper train/val/test splitting keeps evaluation honest: a held-out test set exposes overfitting that training accuracy alone would hide
- Checkpointing enables resuming training and deploying best models
- Deep learning with PyTorch
- Transfer learning and fine-tuning
- Computer vision for image classification
- Custom data augmentation pipelines
- Production ML workflows (training loops, checkpointing, evaluation)
- Large-scale dataset handling (54K+ images)
- Model performance analysis and visualization
Contributions are welcome! Feel free to:
- Report bugs or issues
- Suggest new features or improvements
- Submit pull requests
- Improve documentation
This project is licensed under the MIT License.
Miguel Shema Ngabonziza
- LinkedIn: linkedin.com/in/migztech
- GitHub: github.com/SNMiguel
- Portfolio: migztech.vercel.app
- PlantVillage dataset provided by Penn State University
- ResNet architecture from Deep Residual Learning for Image Recognition
- Built as part of technical skill development for AI/ML engineering roles
- Inspired by the need for accessible agricultural technology in developing regions
- Deploy as REST API with FastAPI
- Build Gradio web interface for live predictions
- Add explainability with Grad-CAM visualizations
- Implement mobile app for field use
- Expand to additional crop species
- Add disease severity classification
- Multi-language support for global accessibility
⭐ If you found this project helpful, please consider giving it a star!


