🚀 Elastic Attention: Test-time Adaptive Sparsity Ratios for Efficient Transformers



📖 Quick Scan

Elastic Attention lets models achieve both strong performance and efficient inference by dynamically assigning a computation mode (Full Attention or Sparse Attention) to each attention head through our Attention Router, which adapts the sparsity ratio to the characteristics of the input.

Method Overview

Elastic Attention features:

  • High training efficiency: 8B-scale models are trained within 12 hours on 8× A800 GPUs.
  • Long-sequence performance: Matches the backbone models and surpasses baseline methods (experiments on Meta-Llama-3.1-8B-Instruct and the Qwen3 series).
  • Inference efficiency: Achieves higher sparsity and faster inference on a subset of long-context tasks.
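
To make the head-wise routing idea concrete, here is a minimal, hypothetical sketch of a router that assigns each attention head either full or sparse attention based on a summary of the input. The class name, gating layer, and 0.5 threshold are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn

class ToyAttentionRouter(nn.Module):
    """Illustrative router: scores each head and picks full vs. sparse attention."""
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_heads)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        summary = hidden_states.mean(dim=1)          # (batch, hidden_size)
        scores = torch.sigmoid(self.gate(summary))   # (batch, num_heads)
        # 1 = full attention, 0 = sparse attention (0.5 is an arbitrary cutoff here)
        return (scores > 0.5).long()

# Example: route 32 heads for a random batch
router = ToyAttentionRouter(hidden_size=4096, num_heads=32)
modes = router(torch.randn(2, 1024, 4096))
print(modes.shape)  # torch.Size([2, 32])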

💻 System Environment

We recommend the following experimental environment, which can reproduce the results in the paper:

Component | Specification            | Notes
OS        | Ubuntu 22.04.4 LTS       | Tested on ID: ubuntu
Python    | 3.11+                    | Recommended
PyTorch   | 2.6.0                    | Ecosystem compatible
CUDA      | 12.4+                    | Required
GPU       | NVIDIA A100/H100 (80GB)  | High VRAM required
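
A quick way to confirm that your local stack matches the table above, using only standard PyTorch calls:

import torch

print(torch.__version__)              # expect 2.6.0
print(torch.version.cuda)             # expect 12.4
print(torch.cuda.is_available())      # expect True
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. an A100/H100 with 80 GB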

βš™οΈ Installation

1. Setup Python Environment

Clone the repository and set up the basic PyTorch ecosystem.

# 1: Create a new Python environment
conda create -n elastic_attn python=3.10
conda activate elastic_attn

# Install PyTorch ecosystem (CUDA 12.4)
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

2. Install Dependencies

This project relies on Block-Sparse-Attention and other libraries.

⚠️ IMPORTANT:
Compiling the CUDA kernels may take 5-10 minutes. Please ensure nvcc is on your PATH.
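
Before building, you can verify that nvcc is visible with a small standard-library check; this snippet is only a convenience and not part of the project:

import shutil
import subprocess

# Locate nvcc and print its version; the kernel build will fail if this is None.
nvcc = shutil.which("nvcc")
print("nvcc found at:", nvcc)
if nvcc:
    print(subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout)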

# 2.1 Install Block-Sparse-Attention (Custom CUDA Ops)
git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention

# Ensure CUDA_HOME matches your local path (adjust if necessary)
export CUDA_HOME=/usr/local/cuda-12.4/
python setup.py install
cd ..

# 2.2 Install other python dependencies
pip install -r requirements.txt
pip install modelscope  # Required for data download

3. Install Elastic Attention

# Clone the repository
git clone https://github.com/LCM-Lab/Elastic-Attention.git
cd Elastic-Attention
pip install -e .
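
After installation, one can check that the key packages import cleanly. The module names block_sparse_attn and elasticattn are assumptions based on the repositories above and may differ in your build:

import importlib

# Hypothetical post-install check; adjust module names if they differ.
for name in ("torch", "block_sparse_attn", "elasticattn"):
    try:
        importlib.import_module(name)
        print(f"{name}: OK")
    except ImportError as err:
        print(f"{name}: MISSING ({err})")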

📚 Data Preparation

We use ModelScope to host the datasets. The training data for different models is provided as follows:

Download Datasets in Code

You can use the following Python snippets to download the datasets programmatically:

from modelscope.msdatasets import MsDataset

# Download Qwen Mix SFT (64K)
dataset_qwen = MsDataset.load('LCM_group/qwen_mix_sft_64K6')

# Download LLaMA Mix SFT (64K)
dataset_llama = MsDataset.load('LCM_group/llama_mix_sft_64K6')

Tip: For debugging or small-scale experiments, we provide a cached dataset at elasticattn/public_data/data_cache/demo_data_qwen_packed_maxseq65536.parquet.
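
To take a quick look at the cached demo data, you can read the Parquet file directly; this assumes pandas with a Parquet backend such as pyarrow is installed:

import pandas as pd

# Inspect the packed demo dataset shipped with the repository.
df = pd.read_parquet(
    "elasticattn/public_data/data_cache/demo_data_qwen_packed_maxseq65536.parquet"
)
print(df.shape)
print(df.columns.tolist())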

🏰 Model Zoo

Pre-trained models and checkpoints are available on ModelScope.

Model Series      | Models                                     | Model Collection
Elastic-Attention | Qwen3-4B / Qwen3-8B / Llama3.1-8B-Instruct | Hugging Face / ModelScope
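
Checkpoints can also be fetched programmatically via ModelScope's snapshot_download. The model ID below is a placeholder; replace it with the actual name from the collection above:

from modelscope import snapshot_download

# Placeholder ID; substitute the model name from the Elastic-Attention collection.
local_dir = snapshot_download("LCM_group/<model-name>")
print("Checkpoint downloaded to:", local_dir)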

πŸƒ Training

To start training with the provided demo data, utilize the included startup script.

# Grant execution permissions
chmod +x elasticattn/run_scripts/training.sh

cd elasticattn
# Run the training script
bash run_scripts/training.sh

Configuration: Batch size, learning rate, and other hyperparameters can be modified inside elasticattn/run_scripts/training.sh.

⚡ Quick Start (Inference)

Here is a minimal example of how to use Elastic Attention for text generation.

import torch
import json
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_sparse_model(model_path):
    """
    Dynamically loads the correct sparse architecture based on config.
    """
    config_path = f"{model_path}/config.json"
    with open(config_path, "r") as f:
        config_data = json.load(f)

    arch = config_data.get("architectures", [])
    if not arch:
        raise ValueError("No architecture found in config.json")

    arch_name = arch[0]
    print(f"πŸš€ Detected architecture: {arch_name}")

    # Register custom architectures
    if "PawLlama" in arch_name:
        from elasticattn.training.eval.modeling_flash_llama_moe import (
            PawLlamaForCausalLM, PawLlamaConfig
        )
        AutoModelForCausalLM.register(PawLlamaConfig, PawLlamaForCausalLM)
        model_cls = PawLlamaForCausalLM
        
    elif "PawQwen" in arch_name:
        from elasticattn.training.eval.modeling_flash_qwen_moe import (
            PawQwen3ForCausalLM, PawQwen3Config
        )
        AutoModelForCausalLM.register(PawQwen3Config, PawQwen3ForCausalLM)
        model_cls = PawQwen3ForCausalLM
    else:
        raise ValueError(f"Unsupported architecture: {arch_name}")

    # Load model
    model = model_cls.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model

# --- Execution ---
model_path = "****" # <--- Replace with your checkpoint path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

print("Loading Elastic Attention Model...")
model = load_sparse_model(model_path)
model.eval()

# Generate
input_text = "Explain quantum mechanics in one sentence."
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")

print("Generating...")
outputs = model.generate(**inputs, max_new_tokens=100)
print("\nOutput:\n" + tokenizer.decode(outputs[0], skip_special_tokens=True))

βš–οΈ Evaluation

We recommend using LOOM-Eval for comprehensive evaluation of long-context capabilities.

# 1. Clone and Install
git clone https://github.com/LCM-Lab/LOOM-Eval.git
cd LOOM-Eval
pip install -e .

# 2. Run Evaluation
loomeval.run \
  --model_path /path/to/model \
  --cfg_path /benchmarks/General/RULER/configs/RULER.yaml \
  --server transformers \
  --acceleration elasticattn \
  --device 0 1 2 3 4 5 6 7 \
  --gp_num 1 \
  --output_dir /path/to/results

🔗 Related Implementations

We acknowledge and reference the following open-source implementations:

Method                        | Repository
NSA (Native Sparse Attention) | XunhaoLai/native-sparse-attention-triton
MoBA                          | MoonshotAI/MoBA
InfLLM-V2                     | OpenBMB/infllmv2_cuda_impl
XAttention                    | mit-han-lab/x-attention

📬 Contact

If you have any questions, please contact us at zecheng.tang@foxmail.com or q_qtang@163.com.

πŸ“ Citation

If you find this project useful in your research, please consider citing:

@misc{tang2026elasticattentiontesttimeadaptive,
      title={Elastic Attention: Test-time Adaptive Sparsity Ratios for Efficient Transformers}, 
      author={Zecheng Tang and Quantong Qiu and Yi Yang and Zhiyi Hong and Haiya Xiang and Kebin Liu and Qingqing Dang and Juntao Li and Min Zhang},
      year={2026},
      eprint={2601.17367},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2601.17367}, 
}
Built with ❤️ by the LCM Group
