22 changes: 7 additions & 15 deletions .github/workflows/test.yml
@@ -11,27 +11,19 @@ jobs:
runs-on: linux-amd64-gpu-l4-latest-1
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: 3.10.11

- name: Setup CUDA
uses: Jimver/cuda-toolkit@v0.2.16
with:
cuda: '12.5.0'

- name: Set CUDA_HOME
run: echo "CUDA_HOME=/usr/local/cuda" >> $GITHUB_ENV

- name: Verify environment
run: |
nvidia-smi
python3 --version
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true

- name: Install dependencies
run: uv sync --all-groups

run: |
uv sync --all-groups
uv pip install torch==2.10
- run: make test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
8 changes: 1 addition & 7 deletions Makefile
@@ -41,20 +41,14 @@ reports:

.PHONY: test
test: reports
$(UV) pip install optimum-quanto
$(UV) pip install flash-attn
$(UV) pip install flash-attn --no-build-isolation --find-links https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/expanded_assets/v0.7.12
PYTHONPATH=. \
$(UV) run pytest \
--cov-report xml:reports/coverage.xml \
--cov=kvpress/ \
--junitxml=./reports/junit.xml \
-v \
tests/ | tee reports/pytest_output.log
@if grep -q "SKIPPED" reports/pytest_output.log; then \
echo "Error: Tests were skipped. All tests must run."; \
grep "SKIPPED" reports/pytest_output.log; \
exit 1; \
fi
@if grep -q "FAILED" reports/pytest_output.log; then \
echo "Error: Some tests failed."; \
grep "FAILED" reports/pytest_output.log; \
2 changes: 1 addition & 1 deletion README.md
@@ -150,7 +150,7 @@ Finally we provide wrapper presses that can be combined with other presses:
- `BlockPress` ([source](kvpress/presses/block_press.py), [paper](https://arxiv.org/abs/2504.15364)): segments the input sequence into non-overlapping blocks and compresses iteratively.
- `DecodingPress` ([source](kvpress/presses/decoding_press.py)): allows for compression during decoding; see the decoding section in this README.
- `PrefillDecodingPress` ([source](kvpress/presses/prefill_decoding_press.py)): allows compression both during prefilling and during decoding.
- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values whose scores from any `ScorerPress` fall below a given threshold instead of relying on top-k scores. Supports both prefilling and decoding (if decoding=True).
- `DMSPress` ([source](kvpress/presses/dms_press.py), [paper](https://arxiv.org/abs/2506.05345)): evicts keys and values whose scores from any `ScorerPress` fall below a given threshold instead of relying on top-k scores. Supports both prefilling and decoding (if decoding=True), but only supports dense-prefill, not sparse-prefill.
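
To make the `DMSPress` entry above concrete, here is a minimal usage sketch in the style of the other pipeline examples in the README. The `kv-press-text-generation` pipeline and `ExpectedAttentionPress` are existing kvpress components, but the `DMSPress` constructor arguments shown here (`press`, `threshold`, `decoding`) are assumptions inferred from the description, not a verified signature.

```python
from transformers import pipeline

from kvpress import DMSPress, ExpectedAttentionPress

# Assumed arguments: wrap a ScorerPress and evict KV pairs whose score falls
# below the threshold, during both prefilling and decoding.
press = DMSPress(press=ExpectedAttentionPress(), threshold=0.01, decoding=True)

pipe = pipeline("kv-press-text-generation", model="meta-llama/Llama-3.1-8B-Instruct", device="cuda:0")
context = "KVPress compresses the KV cache of transformer models."
question = "What does KVPress compress?"
answer = pipe(context, question=question, press=press)["answer"]
print(answer)
```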

For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)

2 changes: 1 addition & 1 deletion kvpress/pipeline.py
@@ -311,7 +311,7 @@ def generate_answer(
generated_ids.append(new_id)
if new_id.item() in should_stop_token_ids:
break
answer = self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
answer = str(self.tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True))
return answer

def postprocess(self, model_outputs, single_question):
1 change: 1 addition & 0 deletions kvpress/presses/dms_press.py
@@ -17,6 +17,7 @@ class DMSPress(BasePress):
"""
Based on Dynamic Memory Sparsification (DMS, https://arxiv.org/abs/2506.05345) inference.
Wraps a ScorerPress and evicts keys/values with scores below a given threshold.
This press implements a dense-prefill version of DMS, not the sparse-prefill version.

Unlike most presses that use a fixed compression_ratio, DMSPress uses a score threshold
to determine which KV pairs to evict. This allows for adaptive compression where the actual
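
The docstring above (truncated in this excerpt) contrasts the score threshold with a fixed compression_ratio. Below is a minimal sketch of that adaptive-eviction idea, under an assumed score shape of (batch, num_kv_heads, seq_len); it is not the actual implementation, which also has to handle ragged per-head eviction counts.

```python
import torch

def adaptive_keep_mask(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # Keep KV pairs whose score reaches the threshold; the number of evicted
    # pairs adapts to the score distribution instead of being fixed by a ratio.
    return scores >= threshold

scores = torch.rand(1, 8, 1024)        # (batch, num_kv_heads, seq_len), assumed shape
mask = adaptive_keep_mask(scores, threshold=0.3)
print(1 - mask.float().mean().item())  # achieved compression ratio varies per input
```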
8 changes: 4 additions & 4 deletions kvpress/presses/kvzap_press.py
@@ -24,6 +24,7 @@ class KVzapModel(PreTrainedModel):

def __init__(self, config):
super().__init__(config)
self.all_tied_weights_keys = {}
if config.hidden_dim is None:
# Linear model
self.layers = nn.ModuleList(
@@ -72,8 +73,7 @@ def score(
attentions: torch.Tensor,
kwargs: dict,
) -> torch.Tensor:
module = self.kvzap_model.layers[module.layer_idx]
module = module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
with torch.no_grad():
scores = module(hidden_states).transpose(1, 2)
kvzap_module = self.kvzap_model.layers[module.layer_idx]
kvzap_module = kvzap_module.to(hidden_states.device, dtype=hidden_states.dtype).eval()
scores = kvzap_module(hidden_states).transpose(1, 2)
return scores
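
For readers unfamiliar with the score contract, the snippet below sketches the shapes in the new code path: the per-layer KVzap module maps hidden states to one score per KV head and token, and the transpose yields the (batch, num_kv_heads, seq_len) layout expected from a scorer. The dimensions are illustrative assumptions, and a plain nn.Linear stands in for the hidden_dim=None case.

```python
import torch
import torch.nn as nn

batch, seq_len, hidden_size, num_kv_heads = 1, 16, 4096, 8   # illustrative only

# Stand-in for self.kvzap_model.layers[layer_idx] when config.hidden_dim is None.
kvzap_layer = nn.Linear(hidden_size, num_kv_heads).eval()

hidden_states = torch.randn(batch, seq_len, hidden_size)
scores = kvzap_layer(hidden_states).transpose(1, 2)
print(scores.shape)  # torch.Size([1, 8, 16]) -> (batch, num_kv_heads, seq_len)
```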
2 changes: 1 addition & 1 deletion kvzap/data.py
@@ -201,7 +201,7 @@ def _forward_hook(self, module, input, kwargs, output):
scale = scale.repeat_interleave(module.o_proj.block_size[0], dim=0)
scale = scale.repeat_interleave(module.o_proj.block_size[1], dim=1)
Wo = Wo.to(V.dtype) * scale
Wo = Wo.view(module.config.num_attention_heads, module.head_dim, module.config.hidden_size)
Wo = Wo.view(module.config.num_attention_heads, V.shape[-1], module.config.hidden_size)
WoV_norm = torch.einsum("h i j, b h t i -> b h t j", Wo.to(dtype=V.dtype), V).norm(dim=-1)
scores = torch.einsum("b h t i, b h i -> b h t i", scores, WoV_norm)

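
As a shape check for the two einsums above, here is a self-contained sketch with assumed dimensions (illustrative, not taken from a real model config). The first einsum contracts Wo's middle axis against V's last axis, which is why the view now uses V.shape[-1] rather than module.head_dim.

```python
import torch

batch, heads, tokens, v_head_dim, hidden_size = 1, 8, 16, 128, 4096   # assumed

Wo = torch.randn(heads, v_head_dim, hidden_size)    # per-head slice of the output projection
V = torch.randn(batch, heads, tokens, v_head_dim)   # value states
attn = torch.rand(batch, heads, tokens, tokens)     # attention weights (query t over key i)

# Norm of each value after projection through W_o: one scalar per key token.
WoV_norm = torch.einsum("h i j, b h t i -> b h t j", Wo, V).norm(dim=-1)  # (batch, heads, tokens)

# Rescale each attention weight by the output norm of the key token it attends to.
scores = torch.einsum("b h t i, b h i -> b h t i", attn, WoV_norm)        # (batch, heads, tokens, tokens)
print(WoV_norm.shape, scores.shape)
```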
6 changes: 4 additions & 2 deletions kvzap/train.py
@@ -115,8 +115,10 @@ def train_linear(X: torch.Tensor, y: torch.Tensor) -> KVzapModel:
KVzapConfig(input_dim=X.shape[2], hidden_dim=None, output_dim=y.shape[2], n_modules=X.shape[1])
)
for layer_idx, (W, b) in enumerate(params):
linear_model.layers[layer_idx].weight.data = torch.tensor(W, dtype=X.dtype) # type: ignore[index]
linear_model.layers[layer_idx].bias.data = torch.tensor(b, dtype=X.dtype) # type: ignore[index]
W = torch.tensor(np.atleast_2d(W), dtype=X.dtype)
b = torch.tensor(np.atleast_1d(b), dtype=X.dtype)
linear_model.layers[layer_idx].weight.data = W # type: ignore[index]
linear_model.layers[layer_idx].bias.data = b # type: ignore[index]
return linear_model


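
The np.atleast_2d / np.atleast_1d wrapping above matters when a layer has a single output: fitted coefficients then typically come back as a 1-D array and a scalar, while nn.Linear expects a (out_features, in_features) weight and a (out_features,) bias. A small sketch of that corner case follows; the fitting routine and dimensions are assumptions for illustration.

```python
import numpy as np
import torch
import torch.nn as nn

in_features = 4096                 # illustrative only
W = np.random.randn(in_features)   # 1-D coefficients, e.g. from a single-target fit
b = float(np.random.randn())       # scalar intercept

layer = nn.Linear(in_features, 1)
layer.weight.data = torch.tensor(np.atleast_2d(W), dtype=torch.float32)  # -> (1, in_features)
layer.bias.data = torch.tensor(np.atleast_1d(b), dtype=torch.float32)    # -> (1,)
print(layer.weight.shape, layer.bias.shape)
```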
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "kvpress"
version = "0.4.2"
version = "0.4.3"
description = "Efficiently compress the KV cache of any pretrained transformer"
authors = [
{ name = "Simon Jegou" },
@@ -15,7 +15,7 @@ dependencies = [
"torch>=2.3.1,<3",
# transformers<4.54 is not supported due to refactoring of the transformers library.
# transformers 4.54-4.55.2 are not compatible with kvpress due to flash attention bugs in transformers
"transformers>=4.56",
"transformers>=4.56,<5.0.0",
"sentencepiece>=0.2.0,<0.3",
"protobuf>=5.27.2,<6",
"datasets>=2.21.0,<3",