From effad9203c1bb18e550c03bc1dcd4e29e8838162 Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Tue, 23 Sep 2025 11:54:02 +0200
Subject: [PATCH 1/8] Add KVCompose

Signed-off-by: Dmitry Akulov
---
 evaluation/evaluate_registry.py    |   3 +
 kvpress/__init__.py                |   2 +
 kvpress/presses/kvcompose_press.py | 425 +++++++++++++++++++++++++++++
 3 files changed, 430 insertions(+)
 create mode 100644 kvpress/presses/kvcompose_press.py

diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py
index 0dc4bd03..a6d35528 100644
--- a/evaluation/evaluate_registry.py
+++ b/evaluation/evaluate_registry.py
@@ -22,6 +22,7 @@
     FinchPress,
     KeyDiffPress,
     KnormPress,
+    KVComposePress,
     KVzipPress,
     ObservedAttentionPress,
     PyramidKVPress,
@@ -72,6 +73,8 @@
     "expected_attention": ExpectedAttentionPress(),
     "finch": FinchPress(),
     "keydiff": KeyDiffPress(),
+    "kvcompose": KVComposePress(),
+    "kvcompose_unstructured": KVComposePress(structured=False),
     "kvzip": KVzipPress(),
     "knorm": KnormPress(),
     "observed_attention": ObservedAttentionPress(),
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index cccd1f02..ff37ce24 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -18,6 +18,7 @@
 from kvpress.presses.key_rerotation_press import KeyRerotationPress
 from kvpress.presses.keydiff_press import KeyDiffPress
 from kvpress.presses.knorm_press import KnormPress
+from kvpress.presses.kvcompose_press import KVComposePress
 from kvpress.presses.kvzip_press import KVzipPress
 from kvpress.presses.lagkv_press import LagKVPress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
@@ -65,4 +66,5 @@
     "KeyDiffPress",
     "KVzipPress",
     "ExpectedAttentionStatsPress",
+    "KVComposePress",
 ]
diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
new file mode 100644
index 00000000..a0ee79c1
--- /dev/null
+++ b/kvpress/presses/kvcompose_press.py
@@ -0,0 +1,425 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from dataclasses import dataclass
+from functools import partial
+from typing import Generator, Union
+
+
+import numpy as np
+import torch
+import types
+from torch import nn
+from transformers.models.llama import LlamaForCausalLM
+from transformers.modeling_utils import PreTrainedModel
+from transformers.cache_utils import DynamicCache
+from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
+
+from kvpress.presses.base_press import BasePress
+
+logger = logging.getLogger(__name__)
+
+
+class Aggregator(ABC):
+    n: int
+    data: torch.Tensor
+    neutral: float
+    device: str
+
+    def __init__(self, n, device):
+        self.n = n
+        self.device = device
+        self.data = torch.full((self.n, ), self.neutral, device=self.device)
+
+    def reset(self):
+        self.__init__(self.n, self.device)
+
+    def partial_fit(self, nd_data: Union[torch.Tensor, np.ndarray]):
+        if isinstance(nd_data, np.ndarray):
+            nd_data = torch.from_numpy(nd_data)
+        if nd_data.ndim == 1:
+            nd_data = nd_data.unsqueeze(0)
+        return self._partial_fit(nd_data)
+
+    @abstractmethod
+    def _partial_fit(self, nd_data: torch.Tensor):
+        pass
+
+    def transform(self):
+        return self.data
+
+    def fit(self, *args):
+        self.__init__(self.n, self.device)
+        self.partial_fit(*args)
+
+    def fit_transform(self, *args):
+        self.fit(*args)
+        return self.transform()
+
+
+class MaxAggregator(Aggregator):
+    def __init__(self, n, device):
+        self.neutral = -torch.inf
+        super().__init__(n, device)
+
+    def _partial_fit(self, nd_data: torch.Tensor):
+        new_max_data = nd_data.amax(dim=tuple(range(len(nd_data.shape)-1)))
+        self.data = torch.maximum(self.data, new_max_data)
+
+
+class MeanAggregator(Aggregator):
+    sum_data: torch.Tensor
+    count_data: torch.Tensor
+
+    def __init__(self, n, device):
+        self.neutral = 0.
+        super().__init__(n, device)
+        self.sum_data = torch.full((n, ), self.neutral, device=self.device)
+        self.count_data = torch.full((n, ), self.neutral, device=self.device)
+
+    def _partial_fit(self, nd_data: torch.Tensor):
+        new_sum_data = nd_data.sum(dim=tuple(range(len(nd_data.shape)-1)))
+        new_count_data = torch.ones_like(nd_data, device=self.device).sum(dim=tuple(range(len(nd_data.shape)-1)))
+        self.sum_data += new_sum_data
+        self.count_data += new_count_data
+        self.data = self.sum_data / self.count_data
+
+
+aggregator_by_name = {
+    "mean": MeanAggregator,
+    "max": MaxAggregator,
+}
+
+
+def get_aggregator(aggregator_name: str) -> type[Aggregator]:
+    aggregator_name = aggregator_name.lower()
+    if aggregator_name not in aggregator_by_name:
+        raise ValueError(f"Unknown aggregator_name: {aggregator_name}. "
+                         f"Available: {list(aggregator_by_name)}")
+    return aggregator_by_name[aggregator_name]
+
+
+def copy_cache(cache: DynamicCache):
+    new_cache = DynamicCache()
+    for (layer, (k, v)) in enumerate(cache):
+        new_cache.update(layer_idx=layer, key_states=k.clone(), value_states=v.clone())
+    return new_cache
+
+
+@dataclass
+class KVComposePress(BasePress):
+    """
+    KVComposePress implements KVCompose: a structured KV cache compression
+    method that remains compatible with standard inference pipelines.
+
+    Setting `structured=False` enables the unstructured variant where each head
+    retains tokens independently (no composite alignment). This generally yields
+    better theoretical performance but is incompatible with standard KV cache layouts
+    unless the attention mechanism is modified.
+
+    Based on KVCompose (https://arxiv.org/abs/2509.05165).
+
+    Parameters
+    ----------
+    structured : bool, default=True
+        Whether to use the structured or unstructured method.
+    compression_ratio : float, default=0.0
+        Global fraction of KV tokens to remove.
+    agg_task : str, default="max"
+        Strategy to form per-context-token importance score per layer/head from
+        attention (e.g. 'max', 'mean').
+    agg_group : str, default="mean"
+        Aggregation within each head across groups (for grouped query attention).
+    agg_head : str, default="mean"
+        Aggregation across heads to form composite importance score (used for
+        structured alignment).
+    add_v_norm : bool, default=False
+        Whether to multiply token score by the norm of its value vector.
+    add_mean_across_heads : bool, default=True
+        Whether to augment token scores with the mean score across all heads
+        to improve stability.
+    keep_token_lower_bound : int, default=0
+        Minimum number of tokens to keep in each layer.
+    """
+
+    def __init__(self,
+                 structured: bool = True,
+                 compression_ratio: float = 0,
+                 agg_task: str = "max",
+                 agg_group: str = "mean",
+                 agg_head: str = "mean",
+                 add_v_norm: bool = False,
+                 add_mean_across_heads: bool = True,
+                 keep_token_lower_bound: int = 0,
+                 ):
+        self.structured = structured
+        self.compression_ratio = compression_ratio
+        self.agg_task = agg_task
+        self.agg_group = agg_group
+        self.agg_head = agg_head
+
+        self.add_v_norm = add_v_norm
+        self.add_mean_across_heads = add_mean_across_heads
+        self.keep_token_lower_bound = keep_token_lower_bound
+
+        super().__init__()
+
+    def __post_init__(self):
+        assert 0 <= self.compression_ratio < 1, "Compression ratio must be between 0 and 1"
+
+    def _init_statistics(self):
+        self.task_aggregators = [
+            [get_aggregator(self.agg_task)(self.context_len, self.device) for _ in range(self.num_att_heads)]
+            for _ in range(self.num_layers)
+        ]
+
+    def _register_model(self, model: PreTrainedModel):
+        self.model = model
+        self.num_layers: int = getattr(model.config, "num_hidden_layers")
+        self.num_att_heads: int = getattr(model.config, "num_attention_heads")
+        self.num_kv_heads: int = getattr(model.config, "num_key_value_heads")
+        self.num_kv_groups: int = self.num_att_heads // self.num_kv_heads
+        self.device = next(model.parameters()).device
+
+    def register_context_ids(self, context_ids: torch.Tensor):
+        self.context_ids = context_ids
+        self.context_len = self.context_ids.shape[-1]
+        self.prompt_ids = []
+        self._init_statistics()
+
+    def register_prompt_ids(self, prompt_ids: list[torch.Tensor]):
+        self.prompt_ids = prompt_ids
+
+    def _register_cache(self, cache: DynamicCache):
+        self.cache = cache
+
+    def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
+        """
+        Fitting self.task_aggregators with the attention scores from the forward pass.
+        Attentions are cleaned up from the output to save memory.
+        """
+        layer = int(module.layer_idx)
+        if not self.structured and len(self.modules) == layer:
+            # Record modules for later masking (used only in unstructured mode).
+            self.modules.append(module)
+
+        layer_attentions = output[1]
+        assert layer_attentions is not None
+
+        for att_head in range(self.num_att_heads):
+            if layer_attentions.shape[3] == layer_attentions.shape[2]:
+                # Skip self-to-self attention (prefill step), only record context-to-query attentions
+                continue
+            layer_att_head_attention = layer_attentions[:, att_head, :, :self.context_len]
+            self.task_aggregators[layer][att_head].partial_fit(layer_att_head_attention)
+
+        # Clean up attention to save memory.
+        output = list(output)
+        del output[1]
+        output.append(None)
+
+        return output
+
+    def get_scores(self):
+        """
+        Obtaining token scores by doing aggregation over groups for every kv head.
+        """
+        self.scores = torch.zeros((self.num_layers, self.num_kv_heads, self.context_len), device=self.device)
+        for layer in range(self.num_layers):
+            for kv_head in range(self.num_kv_heads):
+                group_aggregator = get_aggregator(self.agg_group)(self.context_len, self.device)
+                for att_head in range(kv_head * self.num_kv_groups, (kv_head + 1) * self.num_kv_groups):
+                    group_aggregator.partial_fit(self.task_aggregators[layer][att_head].transform())
+                self.scores[layer, kv_head] = group_aggregator.transform()
+
+    def enhance_scores(self):
+        """
+        Enhance token scores by incorporating value vector norms and mean scores.
+        """
+        for layer in range(self.num_layers):
+            if self.add_v_norm:
+                for kv_head in range(self.num_kv_heads):
+                    v = self.cache.value_cache[layer][0, kv_head].detach()
+                    self.scores[layer, kv_head] = self.scores[layer, kv_head] * v.norm(dim=1)
+            if self.add_mean_across_heads:
+                self.scores[layer] += self.scores[layer].mean(dim=0, keepdim=True)
+
+    def get_composite_scores(self):
+        self.composite_scores_per_head = self.scores.sort(dim=-1, descending=True)[0]
+
+        self.composite_scores_per_layer = torch.full((self.num_layers, self.context_len), 0., device=self.device)
+        for layer in range(self.num_layers):
+            layer_aggregator = get_aggregator(self.agg_head)(self.context_len, self.device)
+            for kv_head in range(self.num_kv_heads):
+                layer_aggregator.partial_fit(self.scores[layer, kv_head].sort(descending=True)[0])
+            self.composite_scores_per_layer[layer] = layer_aggregator.transform()
+
+    def get_important_per_layer(self):
+        """
+        Calculates how many tokens to keep per layer (and per head for unstructured).
+        """
+        self.get_composite_scores()
+
+        self.composite_scores_per_layer[..., :self.keep_token_lower_bound] += 1e9
+        self.composite_scores_per_layer[0] = \
+            self.composite_scores_per_layer.max(dim=0).values  # Ensures first layer is the largest.
+        threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio)
+        self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy()
+
+        self.composite_scores_per_head[..., :self.keep_token_lower_bound] += 1e9
+        threshold_head = self.composite_scores_per_head.quantile(self.compression_ratio)
+        self.important_per_head = (self.composite_scores_per_head >= threshold_head).sum(dim=-1).cpu().numpy()
+
+    def prepare_important_masks(self):
+        """
+        Building masks of tokens to keep per kv head.
+        """
+        self.get_scores()
+        self.enhance_scores()
+        self.get_important_per_layer()
+
+        self.important_mask_per_kv_head = [
+            [
+                torch.zeros(size=(self.context_len, ), device=self.device, dtype=torch.bool)
+                for _ in range(self.num_kv_heads)
+            ]
+            for _ in range(self.num_layers)
+        ]
+
+        for layer in range(self.num_layers):
+            for kv_head in range(self.num_kv_heads):
+                count_of_important = (
+                    self.important_per_layer[layer]
+                    if self.structured
+                    else self.important_per_head[layer, kv_head]
+                )
+                important_indices = torch.argsort(self.scores[layer, kv_head], descending=True)[:count_of_important]
+                self.important_mask_per_kv_head[layer][kv_head][important_indices] = True
+
+    def compress_structured(self) -> None:
+        """
+        Preparing compressed version of the cache.
+        For KVPress, we modify the cache in-place (stored in self.cache).
+        For general use, it's accessible via self.compressed_cache.
+        """
+        self.compressed_cache = DynamicCache()
+
+        for layer in range(self.num_layers):
+            kv_over_layer = [[], []]
+            for kv_head in range(self.num_kv_heads):
+                important_mask = self.important_mask_per_kv_head[layer][kv_head]
+
+                keys = self.cache.layers[layer].keys[0, kv_head][:self.context_len].detach().clone()
+                values = self.cache.layers[layer].values[0, kv_head][:self.context_len].detach().clone()
+                keys = keys[important_mask]
+                values = values[important_mask]
+                kv_over_layer[0].append(keys)
+                kv_over_layer[1].append(values)
+            new_key_states = torch.stack(kv_over_layer[0], dim=0).unsqueeze(0)
+            new_value_states = torch.stack(kv_over_layer[1], dim=0).unsqueeze(0)
+            self.compressed_cache.update(layer_idx=layer, key_states=new_key_states, value_states=new_value_states)
+
+            self.cache.layers[layer].keys = new_key_states
+            self.cache.layers[layer].values = new_value_states
+
+    def compress_unstructured(self) -> None:
+        """
+        Storing evicted indices in module.masked_key_indices.
+        Relies on attention_patch.py implementation that simulates real eviction.
+        Supports only batch size 1.
+        """
+        if self.context_ids.shape[0] != 1:
+            raise NotImplementedError("Unstructured compression supports only batch size 1.")
+        for layer in range(self.num_layers):
+            masked_over_layer = [[], [], []]
+
+            for kv_head in range(self.num_kv_heads):
+                non_important_mask = ~self.important_mask_per_kv_head[layer][kv_head]
+                num_non_important_tokens = int(non_important_mask.sum().item())
+                batch_indices = torch.full((num_non_important_tokens, ), 0, device=self.device)
+                head_indices = torch.full((num_non_important_tokens, ), kv_head, device=self.device)
+                seq_indices = non_important_mask.nonzero(as_tuple=True)[0]
+                masked_over_layer[0].append(batch_indices)
+                masked_over_layer[1].append(head_indices)
+                masked_over_layer[2].append(seq_indices)
+            module = self.modules[layer]
+            masked_over_layer = tuple(map(lambda x: torch.cat(x, dim=0), masked_over_layer))
+            module.masked_key_indices = masked_over_layer
+
+    def compress_cache(self) -> None:
+        if self.structured:
+            self.compress_structured()
+        else:
+            self.compress_unstructured()
+
+    @contextmanager
+    def __call__(self, model: PreTrainedModel) -> Generator:
+        """
+        Context manager to apply a compression method to a model.
+        Apply this context manager during the pre-filling phase to compress the context.
+
+        Parameters
+        ----------
+        model : PreTrainedModel
+            Model to apply the compression method to
+        """
+
+        if not isinstance(model, (LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM)):
+            logger.warning(f"Model {type(model)} not tested")
+
+        self._register_model(model)
+
+        def new_forward(self,
+                        input_ids,
+                        past_key_values,
+                        *args,
+                        press: KVComposePress,
+                        **kwargs,
+                        ):
+            press.register_context_ids(input_ids)
+
+            original_attn_implementation = self.model.config._attn_implementation
+            self.model.config._attn_implementation = "eager"
+            self.original_forward_KVComposePress(
+                input_ids=input_ids,
+                past_key_values=past_key_values,
+                *args,
+                **kwargs,
+            )
+
+            press._register_cache(past_key_values)
+            for prompt_ids in (press.prompt_ids or [press.context_ids]):
+                cache = copy_cache(past_key_values)
+                self.original_forward_KVComposePress(
+                    input_ids=prompt_ids.to(self.model.device),
+                    past_key_values=cache,
+                    *args,
+                    **kwargs,
+                )
+
+            self.model.config._attn_implementation = original_attn_implementation
+
+        hooks = []
+        self.modules = []
+        try:
+            for layer in model.model.layers:
+                layer.self_attn.rotary_emb = model.model.rotary_emb
+                hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
+
+            setattr(model, "original_forward_KVComposePress", model.model.forward)
+            new_forward_with_press = partial(new_forward, press=self)
+            model.model.forward = types.MethodType(new_forward_with_press, model)
+
+            yield
+        finally:
+            model.model.forward = getattr(model, "original_forward_KVComposePress")
+            delattr(model, "original_forward_KVComposePress")
+            for forward_hook in hooks:
+                forward_hook.remove()
+            self.prepare_important_masks()
+            self.compress_cache()

From b0b3715bd75a567cc94f1438761f9c7d2f2a7bed Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Tue, 23 Sep 2025 12:26:00 +0200
Subject: [PATCH 2/8] Add KVCompose to tests/default_presses.py and README
 description

Signed-off-by: Dmitry Akulov
---
 README.md                |  1 +
 tests/default_presses.py | 10 ++++++++++
 2 files changed, 11 insertions(+)

diff --git a/README.md b/README.md
index d47c63a2..009463f7 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,7 @@ Some presses rely on a different logic:
 - `DuoAttentionPress` ([source](kvpress/presses/duo_attention_press.py), [paper](https://arxiv.org/abs/2410.10819)): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
 - `FinchPress` ([source](kvpress/presses/finch_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): similar to SnapKV with a dynamic window size and key value re-rotation
 - `KVzipPress` ([source](kvpress/presses/kvzip_press.py), [paper](https://arxiv.org/abs/2505.23416)): identifies redundant KV pairs through context reconstruction. Achieves near-lossless compression at the cost of multiple forward passes.
+- `KVComposePress` ([source](kvpress/presses/kvcompose_press.py), [paper](https://arxiv.org/abs/2509.05165)): attention-guided eviction, aligning per-head selections into composite tokens to preserve cache structure.
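For context, a minimal usage sketch of the press described in the README entry above. The `kv-press-text-generation` task is the pipeline this repository registers; the checkpoint name and the toy context/question are illustrative only:

```python
from transformers import pipeline

from kvpress import KVComposePress

# Illustrative checkpoint; any model supported by kvpress should work the same way.
pipe = pipeline(
    "kv-press-text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    device_map="auto",
)

context = "KVCompose scores tokens with aggregated attention and aligns per-head selections."
question = "How does KVCompose score tokens?"

# Structured variant (default): the compressed cache keeps a standard layout.
press = KVComposePress(compression_ratio=0.5)
answer = pipe(context, question=question, press=press)["answer"]
```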
 
 Finally we provide wrapper presses that can be combined with other presses:
 - `AdaKVPress` ([source](kvpress/presses/adakv_press.py), [paper](https://arxiv.org/abs/2407.11550)): prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions
diff --git a/tests/default_presses.py b/tests/default_presses.py
index 5c954766..31dc276f 100644
--- a/tests/default_presses.py
+++ b/tests/default_presses.py
@@ -9,6 +9,7 @@
     ExpectedAttentionStatsPress,
     KeyDiffPress,
     KnormPress,
+    KVComposePress,
     KVzipPress,
     LagKVPress,
     PyramidKVPress,
@@ -74,4 +75,13 @@ def load_attention_pattern(model):
         "cls": KVzipPress,
         "kwargs": [{"compression_ratio": 0.5, "layerwise": False}, {"compression_ratio": 0.8, "layerwise": True}],
     },
+    {
+        "cls": KVComposePress,
+        "kwargs": [
+            {"compression_ratio": 0.5},
+            {"compression_ratio": 0.8},
+            {"structured": False, "compression_ratio": 0.5},
+            {"structured": False, "compression_ratio": 0.8},
+        ],
+    },
 ]

From 21a9220774ecbcd4687469a2198428b5c8659b84 Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Mon, 29 Sep 2025 18:05:00 +0200
Subject: [PATCH 3/8] KVCompose fix style

Signed-off-by: Dmitry Akulov
---
 kvpress/presses/kvcompose_press.py | 37 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 20 deletions(-)

diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index a0ee79c1..b7b90e22 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -34,10 +34,13 @@ class Aggregator(ABC):
     def __init__(self, n, device):
         self.n = n
         self.device = device
-        self.data = torch.full((self.n, ), self.neutral, device=self.device)
+        self._init_data()
+
+    def _init_data(self):
+        self.data = torch.full((self.n,), self.neutral, device=self.device)
 
     def reset(self):
-        self.__init__(self.n, self.device)
+        self._init_data()
 
     def partial_fit(self, nd_data: Union[torch.Tensor, np.ndarray]):
         if isinstance(nd_data, np.ndarray):
@@ -54,7 +57,7 @@ def transform(self):
         return self.data
 
     def fit(self, *args):
-        self.__init__(self.n, self.device)
+        self._init_data()
         self.partial_fit(*args)
 
     def fit_transform(self, *args):
@@ -189,7 +192,7 @@ def _register_model(self, model: PreTrainedModel):
     def register_context_ids(self, context_ids: torch.Tensor):
         self.context_ids = context_ids
         self.context_len = self.context_ids.shape[-1]
-        self.prompt_ids = []
+        self.prompt_ids: list[torch.Tensor] = []
         self._init_statistics()
 
     def register_prompt_ids(self, prompt_ids: list[torch.Tensor]):
@@ -204,9 +207,6 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
         Attentions are cleaned up from the output to save memory.
         """
         layer = int(module.layer_idx)
-        if not self.structured and len(self.modules) == layer:
-            # Record modules for later masking (used only in unstructured mode).
-            self.modules.append(module)
 
         layer_attentions = output[1]
         assert layer_attentions is not None
@@ -244,7 +244,7 @@ def enhance_scores(self):
         for layer in range(self.num_layers):
             if self.add_v_norm:
                 for kv_head in range(self.num_kv_heads):
-                    v = self.cache.value_cache[layer][0, kv_head].detach()
+                    v = self.cache.layers[layer].values[0, kv_head].detach()
                     self.scores[layer, kv_head] = self.scores[layer, kv_head] * v.norm(dim=1)
             if self.add_mean_across_heads:
                 self.scores[layer] += self.scores[layer].mean(dim=0, keepdim=True)
@@ -310,7 +310,7 @@ def compress_structured(self) -> None:
         self.compressed_cache = DynamicCache()
 
         for layer in range(self.num_layers):
-            kv_over_layer = [[], []]
+            kv_over_layer: list[list[torch.Tensor]] = [[], []]
             for kv_head in range(self.num_kv_heads):
                 important_mask = self.important_mask_per_kv_head[layer][kv_head]
 
@@ -327,7 +327,7 @@ def compress_structured(self) -> None:
             self.cache.layers[layer].keys = new_key_states
             self.cache.layers[layer].values = new_value_states
 
-    def compress_unstructured(self) -> None:
+    def compress_unstructured(self, model: PreTrainedModel) -> None:
         """
         Storing evicted indices in module.masked_key_indices.
         Relies on attention_patch.py implementation that simulates real eviction.
@@ -335,11 +335,11 @@ def compress_unstructured(self) -> None:
         """
         if self.context_ids.shape[0] != 1:
             raise NotImplementedError("Unstructured compression supports only batch size 1.")
-        for layer in range(self.num_layers):
-            masked_over_layer = [[], [], []]
+        for layer_idx, layer in enumerate(model.model.layers):
+            masked_over_layer: list[list[torch.Tensor]] = [[], [], []]
 
             for kv_head in range(self.num_kv_heads):
-                non_important_mask = ~self.important_mask_per_kv_head[layer][kv_head]
+                non_important_mask = ~self.important_mask_per_kv_head[layer_idx][kv_head]
                 num_non_important_tokens = int(non_important_mask.sum().item())
                 batch_indices = torch.full((num_non_important_tokens, ), 0, device=self.device)
                 head_indices = torch.full((num_non_important_tokens, ), kv_head, device=self.device)
                 seq_indices = non_important_mask.nonzero(as_tuple=True)[0]
                 masked_over_layer[0].append(batch_indices)
                 masked_over_layer[1].append(head_indices)
@@ -347,15 +347,13 @@ def compress_unstructured(self) -> None:
                 masked_over_layer[2].append(seq_indices)
-            module = self.modules[layer]
-            masked_over_layer = tuple(map(lambda x: torch.cat(x, dim=0), masked_over_layer))
-            module.masked_key_indices = masked_over_layer
+            layer.self_attn.masked_key_indices = tuple(map(lambda x: torch.cat(x, dim=0), masked_over_layer))
 
-    def compress_cache(self) -> None:
+    def compress_cache(self, model: PreTrainedModel) -> None:
         if self.structured:
             self.compress_structured()
         else:
-            self.compress_unstructured()
+            self.compress_unstructured(model)
 
     @contextmanager
     def __call__(self, model: PreTrainedModel) -> Generator:
@@ -405,7 +403,6 @@ def new_forward(self,
             self.model.config._attn_implementation = original_attn_implementation
 
         hooks = []
-        self.modules = []
         try:
             for layer in model.model.layers:
                 layer.self_attn.rotary_emb = model.model.rotary_emb
                 hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
@@ -422,4 +419,4 @@ def new_forward(self,
             for forward_hook in hooks:
                 forward_hook.remove()
             self.prepare_important_masks()
-            self.compress_cache()
+            self.compress_cache(model)

From a2824beeaa3c1e474a6a558d34d3693fae90ab59 Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Tue, 30 Sep 2025 14:22:25 +0200
Subject: [PATCH 4/8] KVCompose fix test_presses_run

Signed-off-by: Dmitry Akulov
---
 kvpress/presses/kvcompose_press.py | 3 ++-
 tests/presses/test_presses.py      | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index b7b90e22..562aad5f 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -383,7 +383,7 @@ def new_forward(self,
 
             original_attn_implementation = self.model.config._attn_implementation
             self.model.config._attn_implementation = "eager"
-            self.original_forward_KVComposePress(
+            outputs = self.original_forward_KVComposePress(
                 input_ids=input_ids,
                 past_key_values=past_key_values,
                 *args,
@@ -401,6 +401,7 @@ def new_forward(self,
             )
 
             self.model.config._attn_implementation = original_attn_implementation
+            return outputs
 
         hooks = []
         try:
diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py
index d70d4328..63e5dbf1 100644
--- a/tests/presses/test_presses.py
+++ b/tests/presses/test_presses.py
@@ -16,6 +16,7 @@
     CriticalKVPress,
     KeyRerotationPress,
     KnormPress,
+    KVComposePress,
     KVzipPress,
     ObservedAttentionPress,
     ScorerPress,
@@ -70,6 +71,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press):  # noqa: F811
     if hasattr(press, "__post_init_from_model__"):
         press.__post_init_from_model__(unit_test_model)
     if issubclass(wrapper_press, ComposedPress):
+        if isinstance(press, KVComposePress):  # KVComposePress is currently not compatible with ComposedPress
+            return
         if isinstance(press, KVzipPress):  # KVzipPress is currently not compatible with ComposedPress
             return
         press = ComposedPress(presses=[press])

From 7ff0457525fa8745be7d0c4c69da341d3f0dad5f Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Sat, 15 Nov 2025 03:02:47 +0100
Subject: [PATCH 5/8] Get rid of doubling the cache

Signed-off-by: Dmitry Akulov
---
 kvpress/presses/kvcompose_press.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index 562aad5f..cf6e8ac2 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -305,24 +305,20 @@ def compress_structured(self) -> None:
         """
         Preparing compressed version of the cache.
         For KVPress, we modify the cache in-place (stored in self.cache).
-        For general use, it's accessible via self.compressed_cache.
         """
-        self.compressed_cache = DynamicCache()
-
         for layer in range(self.num_layers):
             kv_over_layer: list[list[torch.Tensor]] = [[], []]
             for kv_head in range(self.num_kv_heads):
                 important_mask = self.important_mask_per_kv_head[layer][kv_head]
 
-                keys = self.cache.layers[layer].keys[0, kv_head][:self.context_len].detach().clone()
-                values = self.cache.layers[layer].values[0, kv_head][:self.context_len].detach().clone()
+                keys = self.cache.layers[layer].keys[0, kv_head][:self.context_len]
+                values = self.cache.layers[layer].values[0, kv_head][:self.context_len]
                 keys = keys[important_mask]
                 values = values[important_mask]
                 kv_over_layer[0].append(keys)
                 kv_over_layer[1].append(values)
             new_key_states = torch.stack(kv_over_layer[0], dim=0).unsqueeze(0)
             new_value_states = torch.stack(kv_over_layer[1], dim=0).unsqueeze(0)
-            self.compressed_cache.update(layer_idx=layer, key_states=new_key_states, value_states=new_value_states)
 
             self.cache.layers[layer].keys = new_key_states
             self.cache.layers[layer].values = new_value_states
@@ -392,7 +388,7 @@ def new_forward(self,
 
             press._register_cache(past_key_values)
             for prompt_ids in (press.prompt_ids or [press.context_ids]):
-                cache = copy_cache(past_key_values)
+                cache = past_key_values
                 self.original_forward_KVComposePress(
                     input_ids=prompt_ids.to(self.model.device),
                     past_key_values=cache,

From 21d3d2b48a859d7beed77036405bcb952dfa5f77 Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Sat, 15 Nov 2025 05:56:23 +0100
Subject: [PATCH 6/8] Style and documentation fix

Signed-off-by: Dmitry Akulov
---
 kvpress/presses/kvcompose_press.py | 89 +++++++++++++++++----------------
 1 file changed, 44 insertions(+), 45 deletions(-)

diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index cf6e8ac2..a5410b36 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -99,14 +99,6 @@ def _partial_fit(self, nd_data: torch.Tensor):
 }
 
 
-def get_aggregator(aggregator_name: str) -> type[Aggregator]:
-    aggregator_name = aggregator_name.lower()
-    if aggregator_name not in aggregator_by_name:
-        raise ValueError(f"Unknown aggregator_name: {aggregator_name}. "
-                         f"Available: {list(aggregator_by_name)}")
-    return aggregator_by_name[aggregator_name]
-
-
 def copy_cache(cache: DynamicCache):
     new_cache = DynamicCache()
     for (layer, (k, v)) in enumerate(cache):
@@ -125,6 +117,10 @@ class KVComposePress(BasePress):
     better theoretical performance but is incompatible with standard KV cache layouts
     unless the attention mechanism is modified.
 
+    Requirements:
+    - Requires attention weights (attn) to be present for the forward hook.
+    - Attention weights are deleted after use to save memory.
+
     Based on KVCompose (https://arxiv.org/abs/2509.05165).
 
     Parameters
@@ -150,34 +146,24 @@ class KVComposePress(BasePress):
         Minimum number of tokens to keep in each layer.
     """
 
-    def __init__(self,
-                 structured: bool = True,
-                 compression_ratio: float = 0,
-                 agg_task: str = "max",
-                 agg_group: str = "mean",
-                 agg_head: str = "mean",
-                 add_v_norm: bool = False,
-                 add_mean_across_heads: bool = True,
-                 keep_token_lower_bound: int = 0,
-                 ):
-        self.structured = structured
-        self.compression_ratio = compression_ratio
-        self.agg_task = agg_task
-        self.agg_group = agg_group
-        self.agg_head = agg_head
-
-        self.add_v_norm = add_v_norm
-        self.add_mean_across_heads = add_mean_across_heads
-        self.keep_token_lower_bound = keep_token_lower_bound
-
-        super().__init__()
+    structured: bool = True
+    compression_ratio: float = 0
+    agg_task: str = "max"
+    agg_group: str = "mean"
+    agg_head: str = "mean"
+    add_v_norm: bool = False
+    add_mean_across_heads: bool = True
+    keep_token_lower_bound: int = 0
 
     def __post_init__(self):
         assert 0 <= self.compression_ratio < 1, "Compression ratio must be between 0 and 1"
 
     def _init_statistics(self):
+        """
+        Initializing the task aggregators for each layer and head.
+        """
         self.task_aggregators = [
-            [get_aggregator(self.agg_task)(self.context_len, self.device) for _ in range(self.num_att_heads)]
+            [aggregator_by_name[self.agg_task](self.context_len, self.device) for _ in range(self.num_att_heads)]
             for _ in range(self.num_layers)
         ]
@@ -225,14 +211,15 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic
 
         return output
 
-    def get_scores(self):
+    def compute_scores(self):
         """
         Obtaining token scores by doing aggregation over groups for every kv head.
+        Stored in tensor self.scores of shape (num_layers, num_kv_heads, context_len).
         """
         self.scores = torch.zeros((self.num_layers, self.num_kv_heads, self.context_len), device=self.device)
         for layer in range(self.num_layers):
             for kv_head in range(self.num_kv_heads):
-                group_aggregator = get_aggregator(self.agg_group)(self.context_len, self.device)
+                group_aggregator = aggregator_by_name[self.agg_group](self.context_len, self.device)
                 for att_head in range(kv_head * self.num_kv_groups, (kv_head + 1) * self.num_kv_groups):
                     group_aggregator.partial_fit(self.task_aggregators[layer][att_head].transform())
                 self.scores[layer, kv_head] = group_aggregator.transform()
@@ -240,6 +227,7 @@ def enhance_scores(self):
     def enhance_scores(self):
         """
         Enhance token scores by incorporating value vector norms and mean scores.
+        Modifies tensor self.scores in-place.
         """
         for layer in range(self.num_layers):
             if self.add_v_norm:
@@ -249,39 +237,50 @@ def enhance_scores(self):
             if self.add_mean_across_heads:
                 self.scores[layer] += self.scores[layer].mean(dim=0, keepdim=True)
 
-    def get_composite_scores(self):
+    def compute_composite_scores(self):
+        """
+        Calculating composite scores per head and layer.
+        Stored in tensors:
+        - self.composite_scores_per_head of shape (num_layers, num_kv_heads, context_len): unstructured compression.
+        - self.composite_scores_per_layer of shape (num_layers, context_len): structured compression.
+        """
         self.composite_scores_per_head = self.scores.sort(dim=-1, descending=True)[0]
+        self.composite_scores_per_head[..., :self.keep_token_lower_bound] += 1e9
 
         self.composite_scores_per_layer = torch.full((self.num_layers, self.context_len), 0., device=self.device)
         for layer in range(self.num_layers):
-            layer_aggregator = get_aggregator(self.agg_head)(self.context_len, self.device)
+            layer_aggregator = aggregator_by_name[self.agg_head](self.context_len, self.device)
             for kv_head in range(self.num_kv_heads):
                 layer_aggregator.partial_fit(self.scores[layer, kv_head].sort(descending=True)[0])
             self.composite_scores_per_layer[layer] = layer_aggregator.transform()
+        self.composite_scores_per_layer[..., :self.keep_token_lower_bound] += 1e9
+        self.composite_scores_per_layer[0] = \
+            self.composite_scores_per_layer.max(dim=0).values  # Ensures first layer is the largest.
 
-    def get_important_per_layer(self):
+    def compute_important_per_layer(self):
         """
         Calculates how many tokens to keep per layer (and per head for unstructured).
+        Stored in tensors:
+        - self.important_per_head of shape (num_layers, num_kv_heads): unstructured compression.
+        - self.important_per_layer of shape (num_layers): structured compression.
         """
-        self.get_composite_scores()
+        self.compute_composite_scores()
 
-        self.composite_scores_per_layer[..., :self.keep_token_lower_bound] += 1e9
-        self.composite_scores_per_layer[0] = \
-            self.composite_scores_per_layer.max(dim=0).values  # Ensures first layer is the largest.
-        threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio)
-        self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy()
-
-        self.composite_scores_per_head[..., :self.keep_token_lower_bound] += 1e9
         threshold_head = self.composite_scores_per_head.quantile(self.compression_ratio)
         self.important_per_head = (self.composite_scores_per_head >= threshold_head).sum(dim=-1).cpu().numpy()
 
+        threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio)
+        self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy()
+
     def prepare_important_masks(self):
         """
         Building masks of tokens to keep per kv head.
+        Stored in tensor:
+        - self.important_mask_per_kv_head of shape (num_layers, num_kv_heads, context_len).
         """
-        self.get_scores()
+        self.compute_scores()
         self.enhance_scores()
-        self.get_important_per_layer()
+        self.compute_important_per_layer()
 
         self.important_mask_per_kv_head = [
             [

From a2b81bdd7e57ca379f7ab9c44b9f803506c00119 Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Sat, 15 Nov 2025 06:50:27 +0100
Subject: [PATCH 7/8] Adapt test_head_compression for KVCompose

Signed-off-by: Dmitry Akulov
---
 tests/presses/test_head_compression.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/tests/presses/test_head_compression.py b/tests/presses/test_head_compression.py
index 83be7dd4..931ccc1f 100644
--- a/tests/presses/test_head_compression.py
+++ b/tests/presses/test_head_compression.py
@@ -4,7 +4,7 @@
 import torch
 from transformers import DynamicCache
 
-from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress
+from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, KVComposePress
 from tests.fixtures import unit_test_model  # noqa: F401
 
 
@@ -40,12 +40,19 @@ def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ra
     assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2  # tolerate small differences
 
 
-# Only for KVzipPress, since it's the only non-wrapper press with head compression (apart from Duo)
-@pytest.mark.parametrize("press", [KVzipPress])
+# Only for KVzipPress and unstructured KVComposePress, since they are
+# the only non-wrapper presses with head compression (apart from Duo)
+@pytest.mark.parametrize(
+    "press_cls, kwargs",
+    [
+        (KVzipPress, {"layerwise": True}),
+        (KVzipPress, {"layerwise": False}),
+        (KVComposePress, {"structured": False}),
+    ],
+)
 @pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8])
-@pytest.mark.parametrize("layerwise", [True, False])
-def test_head_compression(unit_test_model, press, compression_ratio, layerwise):  # noqa: F811
-    press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise)
+def test_head_compression(unit_test_model, press_cls, kwargs, compression_ratio):
+    press = press_cls(compression_ratio=compression_ratio, **kwargs)
     with press(unit_test_model):
         input_ids = torch.randint(0, 1024, (1, 128))
         unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values

From 1cc445e83a5a413870b1ee095893f94c7f8265db Mon Sep 17 00:00:00 2001
From: Dmitry Akulov
Date: Fri, 16 Jan 2026 16:32:16 +0100
Subject: [PATCH 8/8] KVCompose memory warning

Signed-off-by: Dmitry Akulov
---
 README.md                          |  3 +++
 kvpress/presses/kvcompose_press.py | 10 +++-------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 009463f7..15334765 100644
--- a/README.md
+++ b/README.md
@@ -103,6 +103,9 @@ Some presses rely on a different logic:
 - `KVzipPress` ([source](kvpress/presses/kvzip_press.py), [paper](https://arxiv.org/abs/2505.23416)): identifies redundant KV pairs through context reconstruction. Achieves near-lossless compression at the cost of multiple forward passes.
 - `KVComposePress` ([source](kvpress/presses/kvcompose_press.py), [paper](https://arxiv.org/abs/2509.05165)): attention-guided eviction, aligning per-head selections into composite tokens to preserve cache structure.
 
+> [!NOTE]
+> `KVComposePress` performs an extra pass over the full context, temporarily creating a KV cache of ~2x the context length, which adds memory overhead during prefill.
+
 Finally we provide wrapper presses that can be combined with other presses:
 - `AdaKVPress` ([source](kvpress/presses/adakv_press.py), [paper](https://arxiv.org/abs/2407.11550)): prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions
 - `PerLayerCompressionPress` ([source](kvpress/presses/per_layer_compression_press.py)): compress each layer with a different compression ratio (experimental)
diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
index a5410b36..b2203761 100644
--- a/kvpress/presses/kvcompose_press.py
+++ b/kvpress/presses/kvcompose_press.py
@@ -99,13 +99,6 @@ def _partial_fit(self, nd_data: torch.Tensor):
 }
 
 
-def copy_cache(cache: DynamicCache):
-    new_cache = DynamicCache()
-    for (layer, (k, v)) in enumerate(cache):
-        new_cache.update(layer_idx=layer, key_states=k.clone(), value_states=v.clone())
-    return new_cache
-
-
 @dataclass
 class KVComposePress(BasePress):
     """
@@ -362,6 +355,9 @@ def __call__(self, model: PreTrainedModel) -> Generator:
             Model to apply the compression method to
         """
 
+        logger.warning(
+            "KVComposePress temporarily creates a KV cache of ~2x the context length during prefill."
+        )
         if not isinstance(model, (LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM)):
             logger.warning(f"Model {type(model)} not tested")
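Lastly, for reviewers who want to exercise the scoring machinery in isolation, a small sketch of the streaming-aggregator contract introduced in patch 1. The numbers are made up; the import path assumes the new module as added by this series:

```python
import torch

from kvpress.presses.kvcompose_press import MeanAggregator

# Feed attention rows chunk by chunk, then read back the running per-position
# mean, the same partial_fit/transform protocol that forward_hook relies on.
agg = MeanAggregator(n=4, device="cpu")
agg.partial_fit(torch.tensor([[0.1, 0.2, 0.3, 0.4]]))
agg.partial_fit(torch.tensor([[0.3, 0.2, 0.1, 0.0],
                              [0.5, 0.6, 0.7, 0.8]]))
print(agg.transform())  # tensor([0.3000, 0.3333, 0.3667, 0.4000])
```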