diff --git a/predict_imagenet_label.py b/predict_imagenet_label.py index 576f57d..79bf532 100644 --- a/predict_imagenet_label.py +++ b/predict_imagenet_label.py @@ -10,7 +10,8 @@ import time import torch from timm.data import ImageDataset, create_loader, resolve_data_config -from timm.models import apply_test_time_pool, create_model +from timm.models import create_model +from timm.models.layers import apply_test_time_pool from timm.utils import AverageMeter, setup_default_logging torch.backends.cudnn.benchmark = True