This repository is an unofficial PyTorch implementation of the paper "Transformers Meet Small Datasets" [1]. TSD-Net is a transformer-CNN hybrid architecture with several novel layers, including the Convolutional Parameter Sharing Attention (CPSA) block and the Local Feed-Forward Network (LFFN) block. The core idea behind these additions is to complement CNNs, which excel at capturing local patterns, with attention mechanisms, which excel at capturing global patterns.
For more details, please refer to the paper.
Figure 1: Overall architecture of TSD-Net.
This repo contains code for training the TSD-T (TSD-Tiny) and TSD-B (TSD-Big) variants on the CIFAR-10 dataset. The training logic can be customized to train on other datasets, and both the training parameters and the model parameters can be easily modified in the config/config.yaml file. There is also support for model checkpointing. To enable it, create a checkpoints folder with two subfolders, 'latest' and 'best', which store the latest and the best versions of the model during training, respectively.
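The checkpoint folder layout can also be created from Python, which is handy in notebooks or job scripts (in a shell, `mkdir -p checkpoints/latest checkpoints/best` does the same). This is a minimal sketch: `make_checkpoint_dirs` and `checkpoint_targets` are hypothetical helpers, not functions from this repo, and the "best" policy assumes a higher-is-better validation metric such as accuracy, which may differ from what the training script actually tracks.

```python
from pathlib import Path
from typing import Dict, List


def make_checkpoint_dirs(base: str = "checkpoints") -> Dict[str, Path]:
    """Create <base>/latest and <base>/best; safe to call if they already exist."""
    dirs = {name: Path(base) / name for name in ("latest", "best")}
    for path in dirs.values():
        path.mkdir(parents=True, exist_ok=True)
    return dirs


def checkpoint_targets(val_metric: float, best_metric: float) -> List[str]:
    """Hypothetical policy: always refresh 'latest', also refresh 'best' on improvement.

    Assumes a higher-is-better metric (e.g. validation accuracy).
    """
    targets = ["latest"]
    if val_metric > best_metric:
        targets.append("best")
    return targets


if __name__ == "__main__":
    print(make_checkpoint_dirs())
```

Keeping "latest" separate from "best" means an interrupted run can resume from the most recent epoch without overwriting the best-performing weights.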
Training can also auto-resume from a checkpoint: set 'AUTO_RESUME' to True in the config.yaml file. By default, this resumes training from the model inside checkpoints/latest, but this behavior can be changed by modifying the utils.resume_checkpoint function and the relevant functions in tools.train.
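If you customize the resume logic, one simple approach is to pick the most recently modified file in the latest folder. The sketch below is a hypothetical stand-in, not the repo's utils.resume_checkpoint: `find_resume_checkpoint` is an illustrative name, and it assumes each checkpoint is a single file (e.g. a `.pth` file) inside `<checkpoint_dir>/latest`. It only locates the file; loading it (for example with `torch.load`) is left to the training code.

```python
from pathlib import Path
from typing import Optional


def find_resume_checkpoint(checkpoint_dir: str, auto_resume: bool) -> Optional[Path]:
    """Return the newest file in <checkpoint_dir>/latest, or None if there is nothing to resume.

    Hypothetical helper: assumes one file per saved checkpoint.
    """
    if not auto_resume:
        return None
    latest_dir = Path(checkpoint_dir) / "latest"
    if not latest_dir.is_dir():
        return None
    candidates = [p for p in latest_dir.iterdir() if p.is_file()]
    if not candidates:
        return None
    # Newest by modification time, so a stray older file is not picked up.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(find_resume_checkpoint("checkpoints", auto_resume=True))
```

Returning None instead of raising lets the training script fall back to starting from scratch when no checkpoint exists yet.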
git clone https://github.com/soulsharp/ADTransformer
Please ensure that PyTorch supports CUDA on your setup, and install the requirements in your virtual environment.
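Once PyTorch is installed, a quick way to confirm that it can see a GPU is shown below. This is a small sketch using the standard `torch.cuda.is_available()`, `torch.cuda.get_device_name()`, and `torch.version.cuda` calls; `cuda_status` is an illustrative name and not part of this repo. It imports PyTorch lazily so that it still reports something useful when PyTorch is missing.

```python
import importlib.util


def cuda_status() -> str:
    """Report whether the installed PyTorch build can use CUDA."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch  # imported lazily so the check works even without PyTorch

    if not torch.cuda.is_available():
        return f"PyTorch {torch.__version__} found, but CUDA is not available"
    return f"CUDA available on {torch.cuda.get_device_name(0)} (CUDA {torch.version.cuda})"


if __name__ == "__main__":
    print(cuda_status())
```

If the message says CUDA is not available, the installed wheel is usually a CPU-only build; reinstalling a CUDA-enabled PyTorch build that matches your driver typically resolves it.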
pip install -r requirements.txt
To train the classification model (TSD-T by default):
To train on a different dataset, modify the logic inside dataset.py. By default, the model is trained on the CIFAR-10 dataset.
All training-related parameters can be modified in the config.yaml file. Important: Do not modify the tsd_params section to avoid breaking the training pipeline.
Make sure that, in the config file, train_params["checkpoint_dir"] points to the base of the checkpoint directory (the folder that contains 'latest' and 'best').
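Before launching a long run, it can help to verify the checkpoint setting programmatically. The sketch below works on an already-parsed config dictionary, for example one produced by PyYAML's `yaml.safe_load` (assuming PyYAML is installed). Only the `train_params` / `checkpoint_dir` keys come from this README. `check_checkpoint_dir` is a hypothetical helper, and the expected subfolder names follow the 'latest' and 'best' layout described above.

```python
from pathlib import Path
from typing import List


def check_checkpoint_dir(config: dict) -> List[str]:
    """Return a list of problems with train_params["checkpoint_dir"]; empty means OK."""
    train_params = config.get("train_params") or {}
    ckpt = train_params.get("checkpoint_dir")
    if ckpt is None:
        return ['train_params["checkpoint_dir"] is not set']
    base = Path(ckpt)
    problems = []
    for sub in ("latest", "best"):
        if not (base / sub).is_dir():
            problems.append(f"missing folder: {base / sub}")
    return problems


if __name__ == "__main__":
    import yaml  # assumption: PyYAML is available in the environment

    with open("config/config.yaml") as f:
        cfg = yaml.safe_load(f)
    for problem in check_checkpoint_dir(cfg):
        print(problem)
```

Collecting all problems instead of stopping at the first one reports every missing folder in a single pass.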
Run the following command:
python -m tools.train
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
[1] @ARTICLE{9944625,
author={Shao, Ran and Bi, Xiao-Jun},
journal={IEEE Access},
title={Transformers Meet Small Datasets},
year={2022},
volume={10},
number={},
pages={118454-118464},
keywords={Transformers;Convolutional neural networks;Computational modeling;Computer architecture;Training data;Feature extraction;Data models;Visual analytics;Convolutional neural networks;small datasets;transformer;vision transformer},
doi={10.1109/ACCESS.2022.3221138}}