This is the PyTorch implementation of the paper TEP-ones.
The repository is organized as follows:
- In the `classification` folder, scripts are provided to train and prune your own backbones on ImageNet-1k.
- In the `transfer_learning` folder, scripts are provided to fine-tune the pruned models on several downstream tasks and compute the transferability scores.
Requirements are saved in environment.yml.
To compute the correlation values, you need both the transferability scores and the final test accuracy on each downstream task.
All pre-computed transferability scores are saved in `transfer_learning/transferability_scores`, split by model and organized by category, while the test accuracies are saved in `transfer_learning/test_accs_best_hyp.json`.
To obtain the correlation results, run:
```sh
cd transfer_learning
python correlation.py --transferability-scores transferability_scores/ --model resnet50
python correlation.py --transferability-scores transferability_scores/ --model swin
```
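The correlation step measures how well the ranking of models induced by a transferability score agrees with their ranking by final test accuracy. Which coefficient `correlation.py` reports is not stated here; the sketch below assumes Kendall's tau (tau-a, with tied pairs counted as neither concordant nor discordant), a common choice for ranking agreement. The helper `kendall_tau` and the numbers are hypothetical.

```python
from itertools import combinations


def kendall_tau(scores, accuracies):
    """Kendall's tau-a between transferability scores and test accuracies.

    Each pair of models is concordant if both lists order it the same way,
    discordant if they disagree; tied pairs count as neither.
    """
    if len(scores) != len(accuracies):
        raise ValueError("scores and accuracies must have the same length")
    n = len(scores)
    if n < 2:
        raise ValueError("need at least two models")
    concordant = discordant = 0
    for i, j in combinations(range(n), 2):
        sign = (scores[i] - scores[j]) * (accuracies[i] - accuracies[j])
        if sign > 0:
            concordant += 1
        elif sign < 0:
            discordant += 1
    return (concordant - discordant) / (n * (n - 1) / 2)


# Made-up values for four pruned models on one downstream dataset.
scores = [0.91, 0.85, 0.70, 0.40]
accs = [93.1, 92.4, 90.8, 85.0]
print(kendall_tau(scores, accs))  # 1.0: the score ranks the models exactly as accuracy does
```

A value of 1.0 means the score orders the models exactly as their test accuracy does, and -1.0 means the order is fully reversed.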
For the ablation studies:
```sh
cd transfer_learning
python correlation.py --transferability-scores transferability_scores/ --model resnet50 --ablation
python correlation.py --transferability-scores transferability_scores/ --model swin --ablation
```
To compute your own transferability scores, first download and unzip the directory containing the pruned models: drive
Then save the activations of each model on each dataset by running the following script.
Remember to replace the values of the following arguments:
- `--base-model-path /path/to/saved/pruned_models/`: the path where you saved the downloaded pruned models
- `--root /path/to/datasets`: the datasets root; each dataset is downloaded automatically inside it
```sh
cd transfer_learning
sh ../script/save_acts.sh
```
Compute the transferability score of each model on each dataset by running the following scripts.
Remember to replace the values of the following arguments:
- `--base-model-path /path/to/saved/pruned_models/`: the path where you saved the downloaded pruned models
- `--root /path/to/datasets`: the datasets root; each dataset is downloaded automatically inside it
```sh
cd transfer_learning
sh ../script/transferabilty_metrics/metrics_resnet50.sh
sh ../script/transferabilty_metrics/swin.sh
```
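A transferability score rates frozen features on a downstream dataset without fine-tuning. The baseline metrics come from the NCTI and SFDA repositories; the sketch below is not one of them and not this repository's method. It is a deliberately simple, hypothetical proxy (`centroid_accuracy`) that assigns each sample to its nearest class centroid in feature space and reports the resulting accuracy, assuming the saved activations are available as plain float vectors.

```python
import math


def centroid_accuracy(features, labels):
    """Fraction of samples whose nearest class centroid carries their own label.

    features: list of equal-length float vectors (e.g. saved activations)
    labels: list of class ids, one per feature vector
    """
    if len(features) != len(labels) or not labels:
        raise ValueError("need one label per feature vector")
    # Per-class feature sums and counts, used to build class centroids.
    sums, counts = {}, {}
    for x, y in zip(features, labels):
        acc = sums.setdefault(y, [0.0] * len(x))
        for d, v in enumerate(x):
            acc[d] += v
        counts[y] = counts.get(y, 0) + 1
    centroids = {y: [s / counts[y] for s in sums[y]] for y in sums}

    # Classify every sample by Euclidean distance to the centroids.
    correct = 0
    for x, y in zip(features, labels):
        nearest = min(centroids, key=lambda c: math.dist(x, centroids[c]))
        correct += nearest == y
    return correct / len(labels)


feats = [[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.9, 1.1]]
print(centroid_accuracy(feats, [0, 0, 1, 1]))  # 1.0: the two classes are well separated
```

Higher values suggest more linearly separable features; real metrics such as NCTI or SFDA model the feature geometry far more carefully.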
For the ablation studies:
```sh
cd transfer_learning
sh ../script/transferabilty_metrics/ablations_resnet50.sh
sh ../script/transferabilty_metrics/ablations_swin.sh
```
To use the results after one epoch of training as a transferability metric:
```sh
cd transfer_learning
sh ../script/transferabilty_metrics/ablations_train_one_epoch_resnet50.sh
sh ../script/transferabilty_metrics/ablations_train_one_epoch_swin.sh
```
To estimate the computation time:
```sh
cd transfer_learning
sh ../script/times/metrics_resnet50.sh
sh ../script/times/swin.sh
sh ../script/times/ablations_resnet50.sh
sh ../script/times/ablations_swin.sh
```
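How the `times` scripts measure time internally is not described here. As a general illustration, timing a scoring function reliably usually means one or more untimed warm-up runs followed by several timed repetitions. The helper `time_call` below is a hypothetical stdlib-only sketch of that pattern.

```python
import statistics
import time


def time_call(fn, *args, repeats=5, warmup=1):
    """Return (mean, stdev) wall-clock seconds of fn(*args) over `repeats` runs."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn(*args)  # untimed: absorbs caching and lazy initialisation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    stdev = statistics.stdev(timings) if repeats > 1 else 0.0
    return statistics.mean(timings), stdev


mean_s, std_s = time_call(sorted, list(range(100_000, 0, -1)))
print(f"{mean_s * 1e3:.2f} ms +/- {std_s * 1e3:.2f} ms")
```

For GPU code, call `torch.cuda.synchronize()` before each clock read, since CUDA kernels run asynchronously and would otherwise be under-counted.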
To search for the best learning rate and weight decay, set:
- `$MODEL`: `resnet50` / `swin`
- `$DATASET`: `Flowers102` / `Cifar10` / `Cifar100` / `Food101` / `DTD` / `Aircraft` / `Pet`
- `$HYP_CONF`: the index of the hyperparameter configuration, from 0 to 19 (inclusive)
```sh
cd transfer_learning
python train.py --batch-size 64 --checkpoint /path/to/save/checkpoints/pruned_${MODEL}/${MODEL}_iter_0_seed_1/checkpoint.pth --dataset ${DATASET} --freeze-bn 1 --loss cross_entropy --lr-scheduler easy_cosine --lr-wd-set ${HYP_CONF} --model ${MODEL} --opt sgd --root /path/to/datasets --search-hyp 1 --seed 42
```
A sweep file to perform this hyperparameter search is provided in `script/sweep/hyp_search.yaml`.
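After the search, one learning rate and weight decay must be picked per model and dataset. How `train.py` reports search results, and on which split the configurations are compared, is not described here; the sketch below assumes each configuration index yields an `(lr, weight_decay, accuracy)` triple and simply keeps the most accurate one. The function `best_config` and the data layout are hypothetical.

```python
def best_config(results):
    """Pick the (lr, weight decay) pair with the highest accuracy.

    results: hypothetical mapping {hyp_conf_index: (lr, weight_decay, accuracy)}
    """
    if not results:
        raise ValueError("no search results")
    conf = max(results, key=lambda k: results[k][2])
    lr, wd, acc = results[conf]
    return conf, lr, wd, acc


# Made-up results for three of the 20 configurations on one dataset.
results = {
    0: (0.1, 1e-4, 88.2),
    7: (0.01, 5e-4, 91.5),
    19: (0.001, 0.0, 89.9),
}
print(best_config(results))  # (7, 0.01, 0.0005, 91.5)
```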
Then fine-tune the pruned models with the selected hyperparameters, setting:
- `$LR`: the learning rate found by the search
- `$WD`: the weight decay found by the search
- `$SEED`: 0, 1, 2
- `$MODEL`: `resnet50` / `swin`
- `$DATASET`: `Flowers102` / `Cifar10` / `Cifar100` / `Food101` / `DTD` / `Aircraft` / `Pet`
- `$ITER`: the iteration of the pruned model to fine-tune
```sh
cd transfer_learning
python train.py --batch-size 64 --checkpoint /path/to/save/checkpoints/pruned_${MODEL}/${MODEL}_iter_${ITER}_seed_1/checkpoint.pth --dataset ${DATASET} --freeze-bn 1 --loss cross_entropy --lr ${LR} --weight-decay ${WD} --lr-scheduler easy_cosine --model ${MODEL} --opt sgd --root /path/to/datasets --save-dir path/to/save/finetuned/model/ --search-hyp 0 --seed ${SEED}
```
Sweep files to run these fine-tunings with the learning rate and weight decay found for each dataset are provided in `script/sweep/fine_tune_resnet50` for ResNet-50 and in `script/sweep/fine_tune_swin` for SwinV2-T.
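The provided sweep files are the intended way to launch the fine-tunings. As an alternative for running without a sweep tool, the hypothetical helper below enumerates the same `train.py` invocation shown for fine-tuning, once per dataset, pruning iteration and seed. The paths are the README's placeholders, and `hyps` (dataset to `(lr, weight_decay)`) is assumed to come from the hyperparameter search.

```python
import shlex
from itertools import product


def finetune_commands(model, datasets, hyps, iters, seeds=(0, 1, 2),
                      ckpt_root="/path/to/save/checkpoints",
                      data_root="/path/to/datasets",
                      save_dir="path/to/save/finetuned/model/"):
    """Yield one shell-quoted train.py command per (dataset, iteration, seed)."""
    for dataset, it, seed in product(datasets, iters, seeds):
        lr, wd = hyps[dataset]  # hyperparameters selected for this dataset
        ckpt = f"{ckpt_root}/pruned_{model}/{model}_iter_{it}_seed_1/checkpoint.pth"
        args = [
            "python", "train.py", "--batch-size", "64", "--checkpoint", ckpt,
            "--dataset", dataset, "--freeze-bn", "1", "--loss", "cross_entropy",
            "--lr", str(lr), "--weight-decay", str(wd),
            "--lr-scheduler", "easy_cosine", "--model", model, "--opt", "sgd",
            "--root", data_root, "--save-dir", save_dir,
            "--search-hyp", "0", "--seed", str(seed),
        ]
        yield shlex.join(args)


hyps = {"DTD": (0.01, 5e-4)}  # made-up values
for cmd in finetune_commands("resnet50", ["DTD"], hyps, iters=[0, 1]):
    print(cmd)
```

Printing the commands instead of executing them makes it easy to pipe them into a job scheduler or review them before launching.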
To save the test accuracies of the fine-tuned models, run:
```sh
cd transfer_learning
python save_test_acc.py --saved-checkpoints path/to/save/finetuned/model/
```
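Since each configuration is fine-tuned with three seeds, results are typically reported as a mean and standard deviation over seeds. The sketch below assumes a hypothetical flat list of `(model, dataset, iteration, seed, test_acc)` tuples; the layout that `save_test_acc.py` actually writes may differ, and `aggregate_accs` is not part of this repository.

```python
import json
import statistics
from collections import defaultdict


def aggregate_accs(runs):
    """Average test accuracy over seeds.

    runs: iterable of (model, dataset, iteration, seed, test_acc) tuples.
    Returns {model: {dataset: {iteration: {"mean", "std", "n"}}}}.
    """
    grouped = defaultdict(list)
    for model, dataset, it, _seed, acc in runs:
        grouped[(model, dataset, it)].append(acc)
    summary = defaultdict(lambda: defaultdict(dict))
    for (model, dataset, it), accs in grouped.items():
        std = statistics.stdev(accs) if len(accs) > 1 else 0.0
        summary[model][dataset][it] = {
            "mean": statistics.mean(accs), "std": std, "n": len(accs),
        }
    return summary


# Made-up accuracies for seeds 0, 1, 2 of one pruned ResNet-50 on DTD.
runs = [("resnet50", "DTD", 0, s, a) for s, a in zip((0, 1, 2), (70.0, 71.0, 72.0))]
print(json.dumps(aggregate_accs(runs), indent=2))
```

Note that `json.dumps` turns the integer iteration keys into strings.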
- The K-means algorithm comes from this repository.
- The transferability baselines come from the NCTI and SFDA repositories.
- All training scripts are based on the official torchvision repository.