This is the official implementation of the paper *Kolmogorov-Arnold Attention: Is Learnable Attention Better for Vision Transformers?* by S. Maity, K. Hitsman, X. Li, and A. Dutta.
- **8 Jul 2025**: Initial code is now available. If you face any problems, please read the FAQ & Issues section and raise an issue if necessary.
- Requirements
- Getting Started
- FAQ & Issues
- Credits
- Citation
Top-1 accuracy (%):

| Model | CIFAR-10 | CIFAR-100 | ImageNet-1K |
|---|---|---|---|
| ViT-Base | 83.45 | 58.07 | 72.90 |
| + G1B | 81.81 | 55.92 | 68.03 |
| + G1U | 80.75 | 57.36 | 68.83 |
| ViT-Small | 81.08 | 53.47 | 70.50 |
| + G3B | 79.78 | 54.11 | 67.77 |
| + G3U | 79.52 | 53.86 | 67.76 |
| ViT-Tiny | 72.76 | 43.53 | 59.15 |
| + G3B | 76.69 | 46.29 | 59.11 |
| + G3U | 75.56 | 46.75 | 57.97 |
- Python 3.11
- torch 2.4.0, torchvision 0.19.0, torchaudio 2.4.0
- timm 1.0.9
- einops 0.8.0
- fvcore 0.1.5.post20221221
- scipy 1.9.3
- numpy 1.26.3
- pillow 10.2
- wandb 0.18.7
- Dataset
- Environment
- Training
- Evaluation
- Finetuning
- This implementation includes provisioning for six datasets: CIFAR-10, CIFAR-100, ImageNet-1K, SVHN, Flowers 102, and STL-10. Use the following arguments to select each of these datasets:
  - `--data-set CIFAR10` for CIFAR-10
  - `--data-set CIFAR100` for CIFAR-100
  - `--data-set IMNET` for ImageNet-1K
  - `--data-set SVHN` for SVHN
  - `--data-set Flowers102` for Flowers 102
  - `--data-set STL10` for STL-10
- All datasets except ImageNet-1K are downloaded automatically. The default directory is `./data/` and can be changed using the `--data-path` argument.
- For ImageNet-1K, the code expects the data in ImageFolder format as shown below. Pass the ImageNet root directory using the `--data-path` argument.
```
path/to/imagenet            # ImageNet-1K root directory
├── train                   # train set root
│   ├── class 1
│   │   ├── image 1.jpeg
│   │   ├── image 2.jpeg
│   │   └── ...
│   ├── class 2
│   │   ├── image 1.jpeg
│   │   ├── image 2.jpeg
│   │   └── ...
│   ├── class 3
│   │   └── ...
│   └── class 1000
│       ├── image 1.jpeg
│       ├── image 2.jpeg
│       └── ...
└── validation              # validation set root
    ├── class 1
    │   ├── image 1.jpeg
    │   ├── image 2.jpeg
    │   └── ...
    ├── class 2
    │   ├── image 1.jpeg
    │   ├── image 2.jpeg
    │   └── ...
    ├── class 3
    │   └── ...
    └── class 1000
        ├── image 1.jpeg
        ├── image 2.jpeg
        └── ...
```
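If you are unsure whether your local copy matches this layout, a quick stdlib-only check can catch a missing split or an empty class directory before a long training run. This helper is a convenience sketch, not part of the repository:

```python
import os

def check_imagenet_layout(root, splits=("train", "validation")):
    """Return a list of problems found in an ImageFolder-style tree."""
    problems = []
    for split in splits:
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            problems.append(f"missing split directory: {split_dir}")
            continue
        class_dirs = [d for d in os.listdir(split_dir)
                      if os.path.isdir(os.path.join(split_dir, d))]
        if not class_dirs:
            problems.append(f"no class subdirectories in: {split_dir}")
        for c in class_dirs:
            class_dir = os.path.join(split_dir, c)
            if not os.listdir(class_dir):
                problems.append(f"empty class directory: {class_dir}")
    return problems
```

An empty return list means the tree looks right; otherwise each entry describes one problem to fix.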
You can install the required packages on a Python 3.11 installation individually as listed in the requirements, or with the command below.
pip install -r requirements.txt
This implementation is based on the DeiT codebase and retains most of its original functionality, although much of it is not used here. Below is a brief overview of the available arguments.
- The `--model` argument specifies the name of the model to train. Six models are available:
  - `vit_tiny_patch16_224`
  - `vit_small_patch16_224`
  - `vit_base_patch16_224`
  - `vit_tiny_patch16_224_karat`
  - `vit_small_patch16_224_karat`
  - `vit_base_patch16_224_karat`
- We keep the standard arguments as set by the DeiT repository; please follow the same argument setup for this implementation. This includes the common hyperparameters:
  - Model-related hyperparameters: `--drop` and `--drop-path`.
  - Common training hyperparameters: `--batch-size` (default 64), `--epochs` (default 100).
  - Choice of optimizer using `--opt` (default AdamW) and related hyperparameters: `--opt-eps`, `--opt-betas`, `--clip-grad`, `--momentum`, `--weight-decay`.
  - Choice of learning rate scheduler `--sched` (default cosine), base learning rate `--lr` (default 5×10⁻⁴), and other scheduler-related hyperparameters: `--warmup-lr`, `--min-lr`, `--decay-epochs`, `--warmup-epochs`, `--decay-rate`.
  - Arguments related to data augmentation: `--color-jitter`, `--smoothing`, etc.
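For intuition on how the default cosine schedule with warmup behaves, here is a minimal sketch of linear warmup followed by cosine decay. The actual schedule is implemented by the timm library and may differ in details; the default values below (other than the base learning rate) are illustrative assumptions:

```python
import math

def lr_at_epoch(epoch, *, base_lr=5e-4, warmup_lr=1e-6, min_lr=1e-5,
                warmup_epochs=5, total_epochs=100):
    """Linear warmup to base_lr, then cosine decay down to min_lr."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr up to base_lr.
        t = epoch / warmup_epochs
        return warmup_lr + t * (base_lr - warmup_lr)
    # Cosine decay over the remaining epochs.
    t = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))
```

The learning rate starts at `warmup_lr`, reaches `base_lr` at the end of warmup, and decays smoothly to `min_lr` by the final epoch.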
- Please use the `--output_dir` argument to set the directory where training logs and checkpoints are saved. The default is an empty string, which disables saving.
- Please use `--device` to set the training device. The default device is `cuda`.
- The `--run_name` and `--run_id` arguments are used for Weights and Biases logging. `--run_name` provides the job name in Weights and Biases; if it is not passed, Weights and Biases logging is disabled. `--run_id` is only used to provide the existing job identifier when resuming a job.
- The hyperparameters related to KArAt are as follows:
  - `--base_fn` sets the base activation function. The default is a custom function `ZeroModule`, which ensures no base function. Any standard function from `torch.nn` can be used, e.g., `--base_fn nn.SiLU` or `--base_fn nn.ReLU`.
  - `--basis_type` sets the type of basis function for KArAt. The default is `FourierBasis`; the other choices are `SplineBasis`, `WaveletBasis`, and `RationalBasis`.
  - `--grid` sets the grid size G for the Fourier and Spline bases.
  - `--order` sets the order of the Spline.
  - `--grid_range` sets the grid range of the Spline.
  - `--wavelet_type` determines the type of wavelet when using the Wavelet basis. The choices are `mexican_hat`, `morlet`, `dog`, `meyer`, and `shannon`.
  - `--rational_order` specifies the numerator and denominator of the Rational basis.
  - `--depth` sets the total number of layers in a KArAt module. The default is 2: one KAN layer and one linear projection layer.
  - `--hidden_dim` is the low-rank inner dimension r. The default is 12.
  - `--mixed_head` enables training with KArAt and softmax attention heads simultaneously.
  - `--num_attention` determines the number of KArAt heads; the remaining heads use softmax attention. If not specified, this defaults to 1, 3, and 6 for ViT-Tiny, ViT-Small, and ViT-Base, respectively. This is not used at all in the main work.
  - `--universal` enables the Universal mode and `--blockwise` the Blockwise mode of KArAt. The default is Blockwise mode.
  - `--modular_attn` enables different types of layers in a KArAt module instead of only KAN layers. This is used in the main work.
  - `--modular_mode` determines the order of the KAN layer and the linear projector. The available choices are `w1_phi2` and `phi1_w2`. The default is `phi1_w2`, which is presented in the main work.
  - `--project_l1` enables projection onto the L1 ball after the KArAt activation, ensuring the values are bounded between 0 and 1 and sum to 1. This is not used in the main work.
  - `--basis_not_trainable` and `--base_not_trainable` freeze the learnable parameters in the basis function and the base activation function, respectively.
  - `--attn_lr` sets a different learning rate for the KArAt modules. The default is none, which falls back to the base learning rate.
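For intuition, a Fourier basis with grid size G parameterizes a learnable scalar function as a truncated Fourier series applied element-wise. The snippet below is a toy NumPy illustration of that idea, not the repository's implementation; the coefficient names `a` and `b` and the exact series form are illustrative assumptions:

```python
import numpy as np

def fourier_activation(x, a, b):
    """Toy learnable activation: phi(x) = sum_k a_k*cos(k*x) + b_k*sin(k*x).

    a and b have shape (G,), where G is the grid size (number of frequencies);
    in a KAN layer, each edge would carry its own learnable (a, b) pair.
    """
    G = len(a)
    k = np.arange(1, G + 1)                    # frequencies 1..G
    x = np.asarray(x, dtype=float)[..., None]  # broadcast over a frequency axis
    return (a * np.cos(k * x) + b * np.sin(k * x)).sum(axis=-1)

# Example with G = 3: coefficients that pick out only the first sine
# term reduce phi to phi(x) = sin(x).
a = np.zeros(3)
b = np.array([1.0, 0.0, 0.0])
```

Increasing `--grid` adds higher-frequency terms, making the learned function more expressive at the cost of more parameters.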
To train ViT-Tiny + Fourier KArAt G3B with Blockwise mode, use the following command.
```shell
python main.py --model vit_tiny_patch16_224_karat --batch-size 128 --data-path path/to/imagenet --data-set IMNET --output_dir ./output/vit_tiny_patch16_224_FourierKArAt_G3B --epochs 100 --num_workers 8 --grid 3 --hidden_dim 12 --basis_type FourierBasis --base_fn ZeroModule --base_not_trainable --modular_attn --modular_mode phi1_w2 --blockwise
```
To train ViT-Tiny + Fourier KArAt G3U with Universal mode, use the following command.
```shell
python main.py --model vit_tiny_patch16_224_karat --batch-size 128 --data-path path/to/imagenet --data-set IMNET --output_dir ./output/vit_tiny_patch16_224_FourierKArAt_G3U --epochs 100 --num_workers 8 --grid 3 --hidden_dim 12 --basis_type FourierBasis --base_fn ZeroModule --base_not_trainable --modular_attn --modular_mode phi1_w2 --universal
```
To train on *n* GPUs, please use `torchrun --nproc_per_node=n main.py` instead of `python main.py`, and keep the rest of the arguments as they are. Please note that the effective batch size in the multi-GPU scenario is `batch size × n`.
For evaluation, use the `--resume /path/to/checkpoint.pth` argument to provide the weights and the `--eval` flag, along with the rest of the arguments.
For finetuning a pre-trained model, use `--finetune /path/to/pretrained_checkpoint.pth` to provide the pre-trained checkpoint, along with the rest of the arguments.
- Our experimental setup used 32 GB of memory and H100 GPUs with 8 workers in the dataloader. Training may be slower if these specifications are not met.
If you are facing a problem not listed here, please raise an issue.
The codebase is directly adapted from DeiT and is heavily inspired by the following wonderful open-source repositories.
If you use our code for your research, please cite our paper. Many thanks!
```bibtex
@article{maity2025karat,
  title={Kolmogorov-Arnold Attention: Is Learnable Attention Better For Vision Transformers?},
  author={Maity, Subhajit and Hitsman, Killian and Li, Xin and Dutta, Aritra},
  journal={arXiv preprint arXiv:2503.10632},
  year={2025}
}
```
