Elastic Attention enables models to achieve both strong performance and efficient inference by dynamically assigning a computation mode (Full Attention or Sparse Attention) to each attention head via our Attention Router, which adapts sparsity ratios to the characteristics of the input.
Elastic Attention features:
- High training efficiency: Training completes within only 12 hours on 8× A800 GPUs for 8B-scale models.
- Long-sequence performance: Matches backbone models and surpasses baseline methods (experiments on Meta-Llama-3.1-8B-Instruct and Qwen3-series models).
- Inference efficiency: Achieves higher sparsity and faster inference on a subset of long-context tasks.
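To make the per-head routing idea concrete, below is a minimal, illustrative PyTorch sketch of dispatching each attention head to either a full-attention path or a sparse path based on a router's decision. All names here (`AttentionRouter`, `elastic_attention`, the local-window stand-in for the sparse kernel, the thresholded gate) are hypothetical simplifications for exposition and do not mirror the actual `elasticattn` implementation, which relies on trained routers and block-sparse CUDA kernels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionRouter(nn.Module):
    """Illustrative per-head router: predicts, for a given input, whether each
    attention head should run full attention or a cheaper sparse variant."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        # One routing logit per attention head, computed from pooled hidden states.
        self.gate = nn.Linear(hidden_size, num_heads)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        pooled = hidden_states.mean(dim=1)        # (batch, hidden_size)
        return torch.sigmoid(self.gate(pooled))   # (batch, num_heads): P(full attention)


def local_window_attention(q, k, v, window: int = 128):
    """Stand-in 'sparse' path: causal attention restricted to a local window."""
    seq_len = q.shape[-2]
    idx = torch.arange(seq_len, device=q.device)
    keep = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=keep)


def elastic_attention(q, k, v, router_probs, threshold: float = 0.5):
    """Dispatch each head to the full or sparse path based on the router output.
    q, k, v: (batch, num_heads, seq_len, head_dim); router_probs: (batch, num_heads)."""
    full_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    sparse_out = local_window_attention(q, k, v)
    use_full = (router_probs > threshold)[:, :, None, None]  # broadcast over seq & head_dim
    return torch.where(use_full, full_out, sparse_out)
```

Because the routing decision is recomputed per input, the effective sparsity ratio adapts to the input rather than being fixed ahead of time.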
We recommend the following experimental environment, which can reproduce the results in the paper:
| Component | Specification | Notes |
|---|---|---|
| OS | Ubuntu 22.04.4 LTS | Tested distribution |
| Python | 3.11+ | Recommended |
| PyTorch | 2.6.0 | Ecosystem compatible |
| CUDA | 12.4+ | Required |
| GPU | NVIDIA A100/H100 (80GB) | High VRAM required |
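Once the installation below is finished, you can confirm that your setup matches this table with a short Python check. This snippet is only a convenience for verification and is not part of the repository:

```python
# Quick environment sanity check (illustrative; not part of the repo).
import sys
import torch

print("Python :", sys.version.split()[0])   # expect 3.10 / 3.11+
print("PyTorch:", torch.__version__)        # expect 2.6.0
print("CUDA   :", torch.version.cuda)       # expect 12.4+
print("GPU    :", torch.cuda.is_available())
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print("Device :", props.name, f"({props.total_memory / 1e9:.0f} GB)")
```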
Clone the repository and set up the basic PyTorch ecosystem.
# 1: Create a new Python environment
conda create -n elastic_attn python=3.10
conda activate elastic_attn
# Install PyTorch ecosystem (CUDA 12.4)
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

This project relies on Block-Sparse-Attention and other libraries.
⚠️ IMPORTANT:
Compilation of the CUDA kernels may take 5-10 minutes. Please ensure `nvcc` is in your PATH.
# 2.1 Install Block-Sparse-Attention (Custom CUDA Ops)
git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
# Ensure CUDA_HOME matches your local path (adjust if necessary)
export CUDA_HOME=/usr/local/cuda-12.4/
python setup.py install
cd ..
# 2.2 Install other python dependencies
pip install -r requirements.txt
pip install modelscope # Required for data download

# Clone the repository
git clone https://github.com/LCM-Lab/Elastic-Attention.git
cd Elastic-Attention
pip install -e .

We use ModelScope to host the datasets. You can download the training data for the different backbone models programmatically with the following Python snippet:
from modelscope.msdatasets import MsDataset
# Download Qwen Mix SFT (64K)
dataset_qwen = MsDataset.load('LCM_group/qwen_mix_sft_64K6')
# Download LLaMA Mix SFT (64K)
dataset_llama = MsDataset.load('LCM_group/llama_mix_sft_64K6')

Tip: For debugging or small-scale experiments, we provide a cached dataset at:
elasticattn/public_data/data_cache/demo_data_qwen_packed_maxseq65536.parquet
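To take a quick look at this cached demo data, you can load the parquet file with pandas. This is only an illustrative inspection snippet; the actual column names depend on the packing format used by the repository:

```python
import pandas as pd

# Inspect the packed demo dataset (max sequence length 65536, per the file name).
df = pd.read_parquet(
    "elasticattn/public_data/data_cache/demo_data_qwen_packed_maxseq65536.parquet"
)
print(df.shape)
print(df.columns.tolist())  # column names depend on the packing format
print(df.head(1))
```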
Pre-trained models and checkpoints are available on ModelScope.
| Model Series | Models |
|---|---|
| Elastic-Attention Collection | Qwen3-4B / Qwen3-8B / Llama3.1-8B-Instruct |
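Checkpoints can also be fetched programmatically via ModelScope's `snapshot_download`. The model ID below is a placeholder, not an actual repository ID; substitute the ID of the checkpoint you need from the collection:

```python
from modelscope import snapshot_download

# Placeholder model ID -- replace with the actual ID from the Elastic-Attention collection.
local_dir = snapshot_download("LCM_group/<elastic-attention-checkpoint>")
print("Checkpoint downloaded to:", local_dir)
```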
To start training with the provided demo data, use the included startup script.
# Grant execution permissions
chmod +x elasticattn/run_scripts/training.sh
cd elasticattn
# Run the training script
bash run_scripts/training.sh
Configuration: Batch size, learning rate, and other hyperparameters can be modified inside
elasticattn/run_scripts/training.sh.
Here is a minimal example of how to use Elastic Attention for text generation.
Click to expand the Inference Code
import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_sparse_model(model_path):
    """
    Dynamically loads the correct sparse architecture based on config.
    """
    config_path = f"{model_path}/config.json"
    with open(config_path, "r") as f:
        config_data = json.load(f)
    arch = config_data.get("architectures", [])
    if not arch:
        raise ValueError("No architecture found in config.json")
    arch_name = arch[0]
    print(f"Detected architecture: {arch_name}")

    # Register custom architectures
    if "PawLlama" in arch_name:
        from elasticattn.training.eval.modeling_flash_llama_moe import (
            PawLlamaForCausalLM, PawLlamaConfig
        )
        AutoModelForCausalLM.register(PawLlamaConfig, PawLlamaForCausalLM)
        model_cls = PawLlamaForCausalLM
    elif "PawQwen" in arch_name:
        from elasticattn.training.eval.modeling_flash_qwen_moe import (
            PawQwen3ForCausalLM, PawQwen3Config
        )
        AutoModelForCausalLM.register(PawQwen3Config, PawQwen3ForCausalLM)
        model_cls = PawQwen3ForCausalLM
    else:
        raise ValueError(f"Unsupported architecture: {arch_name}")

    # Load model
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model

# --- Execution ---
model_path = "****"  # <--- Replace with your checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print("Loading Elastic Attention Model...")
model = load_sparse_model(model_path)
model.eval()

# Generate
input_text = "Explain quantum mechanics in one sentence."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

print("Generating...")
outputs = model.generate(**inputs, max_new_tokens=100)
print("\nOutput:\n" + tokenizer.decode(outputs[0], skip_special_tokens=True))
We recommend using LOOM-Eval for comprehensive evaluation of long-context capabilities.
# 1. Clone and Install
git clone https://github.com/LCM-Lab/LOOM-Eval.git
cd LOOM-Eval
pip install -e .
# 2. Run Evaluation
loomeval.run \
--model_path /path/to/model \
--cfg_path /benchmarks/General/RULER/configs/RULER.yaml \
--server transformers \
--acceleration elasticattn \
--device 0 1 2 3 4 5 6 7 \
--gp_num 1 \
--output_dir /path/to/results
We acknowledge and reference the following open-source implementations:
| Method | Repository |
|---|---|
| NSA (Native Sparse Attention) | XunhaoLai/native-sparse-attention-triton |
| MoBA | MoonshotAI/MoBA |
| InfLLM-V2 | OpenBMB/infllmv2_cuda_impl |
| XAttention | mit-han-lab/x-attention |
If you have any questions, please contact us at zecheng.tang@foxmail.com or q_qtang@163.com.
If you find this project useful in your research, please consider citing:
@misc{tang2026elasticattentiontesttimeadaptive,
title={Elastic Attention: Test-time Adaptive Sparsity Ratios for Efficient Transformers},
author={Zecheng Tang and Quantong Qiu and Yi Yang and Zhiyi Hong and Haiya Xiang and Kebin Liu and Qingqing Dang and Juntao Li and Min Zhang},
year={2026},
eprint={2601.17367},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2601.17367},
}