diff --git a/README.md b/README.md
index d47c63a2..15334765 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,10 @@ Some presses rely on a different logic:
 - `DuoAttentionPress` ([source](kvpress/presses/duo_attention_press.py), [paper](https://arxiv.org/abs/2410.10819)): split heads into retrieval heads (no compression) and streaming heads (StreamingLLM approach)
 - `FinchPress` ([source](kvpress/presses/finch_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): similar to SnapKV with a dynamic window size and key value re-rotation
 - `KVzipPress` ([source](kvpress/presses/kvzip_press.py), [paper](https://arxiv.org/abs/2505.23416)): identifies redundant KV pairs through context reconstruction. Achieves near-lossless compression at the cost of multiple forward passes.
+- `KVComposePress` ([source](kvpress/presses/kvcompose_press.py), [paper](https://arxiv.org/abs/2509.05165)): attention-guided eviction, aligning per-head selections into composite tokens to preserve cache structure.
+
+> [!NOTE]
+> `KVComposePress` performs an extra pass over the full context, temporarily creating a KV cache of ~2x the context length, which adds memory overhead during prefill.
 
 Finally we provide wrapper presses that can be combined with other presses:
 - `AdaKVPress` ([source](kvpress/presses/adakv_press.py), [paper](https://arxiv.org/abs/2407.11550)): prune bottom scores of any `ScorerPress` but across all heads, achieving head-wise compressions
diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py
index 0dc4bd03..a6d35528 100644
--- a/evaluation/evaluate_registry.py
+++ b/evaluation/evaluate_registry.py
@@ -22,6 +22,7 @@
     FinchPress,
     KeyDiffPress,
     KnormPress,
+    KVComposePress,
     KVzipPress,
     ObservedAttentionPress,
     PyramidKVPress,
@@ -72,6 +73,8 @@
     "expected_attention": ExpectedAttentionPress(),
     "finch": FinchPress(),
     "keydiff": KeyDiffPress(),
+    "kvcompose": KVComposePress(),
+    "kvcompose_unstructured": KVComposePress(structured=False),
     "kvzip": KVzipPress(),
     "knorm": KnormPress(),
     "observed_attention": ObservedAttentionPress(),
diff --git a/kvpress/__init__.py b/kvpress/__init__.py
index cccd1f02..ff37ce24 100644
--- a/kvpress/__init__.py
+++ b/kvpress/__init__.py
@@ -18,6 +18,7 @@
 from kvpress.presses.key_rerotation_press import KeyRerotationPress
 from kvpress.presses.keydiff_press import KeyDiffPress
 from kvpress.presses.knorm_press import KnormPress
+from kvpress.presses.kvcompose_press import KVComposePress
 from kvpress.presses.kvzip_press import KVzipPress
 from kvpress.presses.lagkv_press import LagKVPress
 from kvpress.presses.observed_attention_press import ObservedAttentionPress
@@ -65,4 +66,5 @@
     "KeyDiffPress",
     "KVzipPress",
     "ExpectedAttentionStatsPress",
+    "KVComposePress",
 ]
diff --git a/kvpress/presses/kvcompose_press.py b/kvpress/presses/kvcompose_press.py
new file mode 100644
index 00000000..b2203761
--- /dev/null
+++ b/kvpress/presses/kvcompose_press.py
@@ -0,0 +1,414 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + + +import logging +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import Generator, Union + + +import numpy as np +import torch +import types +from torch import nn +from transformers.models.llama import LlamaForCausalLM +from transformers.modeling_utils import PreTrainedModel +from transformers.cache_utils import DynamicCache +from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM + +from kvpress.presses.base_press import BasePress + +logger = logging.getLogger(__name__) + + +class Aggregator(ABC): + n: int + data: torch.Tensor + neutral: float + device: str + + def __init__(self, n, device): + self.n = n + self.device = device + self._init_data() + + def _init_data(self): + self.data = torch.full((self.n,), self.neutral, device=self.device) + + def reset(self): + self._init_data() + + def partial_fit(self, nd_data: Union[torch.Tensor, np.ndarray]): + if isinstance(nd_data, np.ndarray): + nd_data = torch.from_numpy(nd_data) + if nd_data.ndim == 1: + nd_data = nd_data.unsqueeze(0) + return self._partial_fit(nd_data) + + @abstractmethod + def _partial_fit(self, nd_data: torch.Tensor): + pass + + def transform(self): + return self.data + + def fit(self, *args): + self._init_data() + self.partial_fit(*args) + + def fit_transform(self, *args): + self.fit(*args) + return self.transform() + + +class MaxAggregator(Aggregator): + def __init__(self, n, device): + self.neutral = -torch.inf + super().__init__(n, device) + + def _partial_fit(self, nd_data: torch.Tensor): + new_max_data = nd_data.amax(dim=tuple(range(len(nd_data.shape)-1))) + self.data = torch.maximum(self.data, new_max_data) + + +class MeanAggregator(Aggregator): + sum_data: torch.Tensor + count_data: torch.Tensor + + def __init__(self, n, device): + self.neutral = 0. + super().__init__(n, device) + self.sum_data = torch.full((n, ), self.neutral, device=self.device) + self.count_data = torch.full((n, ), self.neutral, device=self.device) + + def _partial_fit(self, nd_data: torch.Tensor): + new_sum_data = nd_data.sum(dim=tuple(range(len(nd_data.shape)-1))) + new_count_data = torch.ones_like(nd_data, device=self.device).sum(dim=tuple(range(len(nd_data.shape)-1))) + self.sum_data += new_sum_data + self.count_data += new_count_data + self.data = self.sum_data / self.count_data + + +aggregator_by_name = { + "mean": MeanAggregator, + "max": MaxAggregator, +} + + +@dataclass +class KVComposePress(BasePress): + """ + KVComposePress implements KVCompose: a structured KV cache compression + method that remains compatible with standard inference pipelines. + + Setting `structured=False` enables the unstructured variant where each head + retains tokens independently (no composite alignment). This generally yields + better theoretical performance but is incompatible with standard KV cache layouts + unless the attention mechanism is modified. + + Requirements: + - Requires attention weights (attn) to be present for the forward hook. + - Attention weights are deleted after use to save memory. + + Based on KVCompose (https://arxiv.org/abs/2509.05165). + + Parameters + ---------- + structured : bool, default=True + Whether to use the structured or unstructured method. + compression_ratio : float, default=0.0 + Global fraction of KV tokens to remove. 
+ agg_task : str, default="max" + Strategy to form per-context-token importance score per layer/head from + attention (e.g. 'max', 'mean'). + agg_group : str, default="mean" + Aggregation within each head across groups (for grouped query attention). + agg_head : str, default="mean" + Aggregation across heads to form composite importance score (used for + structured alignment). + add_v_norm : bool, default=False + Whether to multiply token score by the norm of its value vector. + add_mean_across_heads : bool, default=True + Whether to augment token scores with the mean score across all heads + to improve stability. + keep_token_lower_bound : int, default=0 + Minimum number of tokens to keep in each layer. + """ + + structured: bool = True + compression_ratio: float = 0 + agg_task: str = "max" + agg_group: str = "mean" + agg_head: str = "mean" + add_v_norm: bool = False + add_mean_across_heads: bool = True + keep_token_lower_bound: int = 0 + + def __post_init__(self): + assert 0 <= self.compression_ratio < 1, "Compression ratio must be between 0 and 1" + + def _init_statistics(self): + """ + Initializing the task aggregators for each layer and head. + """ + self.task_aggregators = [ + [aggregator_by_name[self.agg_task](self.context_len, self.device) for _ in range(self.num_att_heads)] + for _ in range(self.num_layers) + ] + + def _register_model(self, model: PreTrainedModel): + self.model = model + self.num_layers: int = getattr(model.config, "num_hidden_layers") + self.num_att_heads: int = getattr(model.config, "num_attention_heads") + self.num_kv_heads: int = getattr(model.config, "num_key_value_heads") + self.num_kv_groups: int = self.num_att_heads // self.num_kv_heads + self.device = next(model.parameters()).device + + def register_context_ids(self, context_ids: torch.Tensor): + self.context_ids = context_ids + self.context_len = self.context_ids.shape[-1] + self.prompt_ids: list[torch.Tensor] = [] + self._init_statistics() + + def register_prompt_ids(self, prompt_ids: list[torch.Tensor]): + self.prompt_ids = prompt_ids + + def _register_cache(self, cache: DynamicCache): + self.cache = cache + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Fitting self.task_aggregators with the attention scores from the forward pass. + Attentions are cleaned up from the output to save memory. + """ + layer = int(module.layer_idx) + + layer_attentions = output[1] + assert layer_attentions is not None + + for att_head in range(self.num_att_heads): + if layer_attentions.shape[3] == layer_attentions.shape[2]: + # Skip self-to-self attention (prefill step), only record context-to-query attentions + continue + layer_att_head_attention = layer_attentions[:, att_head, :, :self.context_len] + self.task_aggregators[layer][att_head].partial_fit(layer_att_head_attention) + + # Clean up attention to save memory. + output = list(output) + del output[1] + output.append(None) + + return output + + def compute_scores(self): + """ + Obtaining token scores by doing aggregation over groups for every kv head. + Stored in tensor self.scores of shape (num_layers, num_kv_heads, context_len). 
+ """ + self.scores = torch.zeros((self.num_layers, self.num_kv_heads, self.context_len), device=self.device) + for layer in range(self.num_layers): + for kv_head in range(self.num_kv_heads): + group_aggregator = aggregator_by_name[self.agg_group](self.context_len, self.device) + for att_head in range(kv_head * self.num_kv_groups, (kv_head + 1) * self.num_kv_groups): + group_aggregator.partial_fit(self.task_aggregators[layer][att_head].transform()) + self.scores[layer, kv_head] = group_aggregator.transform() + + def enhance_scores(self): + """ + Enhance token scores by incorporating value vector norms and mean scores. + Modifies tensor self.scores in-place. + """ + for layer in range(self.num_layers): + if self.add_v_norm: + for kv_head in range(self.num_kv_heads): + v = self.cache.layers[layer].values[0, kv_head].detach() + self.scores[layer, kv_head] = self.scores[layer, kv_head] * v.norm(dim=1) + if self.add_mean_across_heads: + self.scores[layer] += self.scores[layer].mean(dim=0, keepdim=True) + + def compute_composite_scores(self): + """ + Calculating composite scores per head and layer. + Stored in tensors: + - self.composite_scores_per_head of shape (num_layers, num_kv_heads, context_len): unstructured compression. + - self.composite_scores_per_layer of shape (num_layers, context_len): structured compression. + """ + self.composite_scores_per_head = self.scores.sort(dim=-1, descending=True)[0] + self.composite_scores_per_head[..., :self.keep_token_lower_bound] += 1e9 + + self.composite_scores_per_layer = torch.full((self.num_layers, self.context_len), 0., device=self.device) + for layer in range(self.num_layers): + layer_aggregator = aggregator_by_name[self.agg_head](self.context_len, self.device) + for kv_head in range(self.num_kv_heads): + layer_aggregator.partial_fit(self.scores[layer, kv_head].sort(descending=True)[0]) + self.composite_scores_per_layer[layer] = layer_aggregator.transform() + self.composite_scores_per_layer[..., :self.keep_token_lower_bound] += 1e9 + self.composite_scores_per_layer[0] = \ + self.composite_scores_per_layer.max(dim=0).values # Ensures first layer is the largest. + + def compute_important_per_layer(self): + """ + Calculates how many tokens to keep per layer (and per head for unstructured). + Stored in tensors: + - self.important_per_head of shape (num_layers, num_kv_heads): unstructured compression. + - self.important_per_layer of shape (num_layers): structured compression. + """ + self.compute_composite_scores() + + threshold_head = self.composite_scores_per_head.quantile(self.compression_ratio) + self.important_per_head = (self.composite_scores_per_head >= threshold_head).sum(dim=-1).cpu().numpy() + + threshold_layer = self.composite_scores_per_layer.quantile(self.compression_ratio) + self.important_per_layer = (self.composite_scores_per_layer >= threshold_layer).sum(dim=-1).cpu().numpy() + + def prepare_important_masks(self): + """ + Building masks of tokens to keep per kv head. + Stored in tensor: + - self.important_mask_per_kv_head of shape (num_layers, num_kv_heads, context_len). 
+ """ + self.compute_scores() + self.enhance_scores() + self.compute_important_per_layer() + + self.important_mask_per_kv_head = [ + [ + torch.zeros(size=(self.context_len, ), device=self.device, dtype=torch.bool) + for _ in range(self.num_kv_heads) + ] + for _ in range(self.num_layers) + ] + + for layer in range(self.num_layers): + for kv_head in range(self.num_kv_heads): + count_of_important = ( + self.important_per_layer[layer] + if self.structured + else self.important_per_head[layer, kv_head] + ) + important_indices = torch.argsort(self.scores[layer, kv_head], descending=True)[:count_of_important] + self.important_mask_per_kv_head[layer][kv_head][important_indices] = True + + def compress_structured(self) -> None: + """ + Preparing compressed version of the cache. + For KVPress, we modify the cache in-place (stored in self.cache). + """ + for layer in range(self.num_layers): + kv_over_layer: list[list[torch.Tensor]] = [[], []] + for kv_head in range(self.num_kv_heads): + important_mask = self.important_mask_per_kv_head[layer][kv_head] + + keys = self.cache.layers[layer].keys[0, kv_head][:self.context_len] + values = self.cache.layers[layer].values[0, kv_head][:self.context_len] + keys = keys[important_mask] + values = values[important_mask] + kv_over_layer[0].append(keys) + kv_over_layer[1].append(values) + new_key_states = torch.stack(kv_over_layer[0], dim=0).unsqueeze(0) + new_value_states = torch.stack(kv_over_layer[1], dim=0).unsqueeze(0) + + self.cache.layers[layer].keys = new_key_states + self.cache.layers[layer].values = new_value_states + + def compress_unstructured(self, model: PreTrainedModel) -> None: + """ + Storing evicted indices in module.masked_key_indices. + Relies on attention_patch.py implementation that simulates real eviction. + Supports only batch size 1. + """ + if self.context_ids.shape[0] != 1: + raise NotImplementedError("Unstructured compression supports only batch size 1.") + for layer_idx, layer in enumerate(model.model.layers): + masked_over_layer: list[list[torch.Tensor]] = [[], [], []] + + for kv_head in range(self.num_kv_heads): + non_important_mask = ~self.important_mask_per_kv_head[layer_idx][kv_head] + num_non_important_tokens = int(non_important_mask.sum().item()) + batch_indices = torch.full((num_non_important_tokens, ), 0, device=self.device) + head_indices = torch.full((num_non_important_tokens, ), kv_head, device=self.device) + seq_indices = non_important_mask.nonzero(as_tuple=True)[0] + masked_over_layer[0].append(batch_indices) + masked_over_layer[1].append(head_indices) + masked_over_layer[2].append(seq_indices) + layer.self_attn.masked_key_indices = tuple(map(lambda x: torch.cat(x, dim=0), masked_over_layer)) + + def compress_cache(self, model: PreTrainedModel) -> None: + if self.structured: + self.compress_structured() + else: + self.compress_unstructured(model) + + @contextmanager + def __call__(self, model: PreTrainedModel) -> Generator: + """ + Context manager to apply a compression method to a model. + Apply this context manager during the pre-filling phase to compress the context. 
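+        The cache is compressed (or, for the unstructured variant, eviction masks are set) when the context manager exits.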
+ + Parameters + ---------- + model : PreTrainedModel + Model to apply the compression method to + """ + + logger.warning( + "KVComposePress temporarily creates a KV cache of ~2x the context length during prefill; " + ) + if not isinstance(model, (LlamaForCausalLM, Qwen2ForCausalLM, Qwen3ForCausalLM)): + logger.warning(f"Model {type(model)} not tested") + + self._register_model(model) + + def new_forward(self, + input_ids, + past_key_values, + *args, + press: KVComposePress, + **kwargs, + ): + press.register_context_ids(input_ids) + + original_attn_implementation = self.model.config._attn_implementation + self.model.config._attn_implementation = "eager" + outputs = self.original_forward_KVComposePress( + input_ids=input_ids, + past_key_values=past_key_values, + *args, + **kwargs, + ) + + press._register_cache(past_key_values) + for prompt_ids in (press.prompt_ids or [press.context_ids]): + cache = past_key_values + self.original_forward_KVComposePress( + input_ids=prompt_ids.to(self.model.device), + past_key_values=cache, + *args, + **kwargs, + ) + + self.model.config._attn_implementation = original_attn_implementation + return outputs + + hooks = [] + try: + for layer in model.model.layers: + layer.self_attn.rotary_emb = model.model.rotary_emb + hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) + + setattr(model, "original_forward_KVComposePress", model.model.forward) + new_forward_with_press = partial(new_forward, press=self) + model.model.forward = types.MethodType(new_forward_with_press, model) + + yield + finally: + model.model.forward = getattr(model, "original_forward_KVComposePress") + delattr(model, "original_forward_KVComposePress") + for forward_hook in hooks: + forward_hook.remove() + self.prepare_important_masks() + self.compress_cache(model) diff --git a/tests/default_presses.py b/tests/default_presses.py index 5c954766..31dc276f 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -9,6 +9,7 @@ ExpectedAttentionStatsPress, KeyDiffPress, KnormPress, + KVComposePress, KVzipPress, LagKVPress, PyramidKVPress, @@ -74,4 +75,13 @@ def load_attention_pattern(model): "cls": KVzipPress, "kwargs": [{"compression_ratio": 0.5, "layerwise": False}, {"compression_ratio": 0.8, "layerwise": True}], }, + { + "cls": KVComposePress, + "kwargs": [ + {"compression_ratio": 0.5}, + {"compression_ratio": 0.8}, + {"structured": False, "compression_ratio": 0.5}, + {"structured": False, "compression_ratio": 0.8}, + ], + }, ] diff --git a/tests/presses/test_head_compression.py b/tests/presses/test_head_compression.py index 83be7dd4..931ccc1f 100644 --- a/tests/presses/test_head_compression.py +++ b/tests/presses/test_head_compression.py @@ -4,7 +4,7 @@ import torch from transformers import DynamicCache -from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress +from kvpress import AdaKVPress, CriticalAdaKVPress, KnormPress, KVzipPress, KVComposePress from tests.fixtures import unit_test_model # noqa: F401 @@ -40,12 +40,19 @@ def test_wrapper_head_compression(unit_test_model, wrapper_press, compression_ra assert abs(cumulative_compression_ratio - press.compression_ratio) < 1e-2 # tolerate small differences -# Only for KVzipPress, since it's the only non-wrapper press with head compression (apart from Duo) -@pytest.mark.parametrize("press", [KVzipPress]) +# Only for KVzipPress and unstructured KVComposePress, since they are +# the only non-wrapper presses with head compression (apart from Duo) +@pytest.mark.parametrize( + 
"press_cls, kwargs", + [ + (KVzipPress, {"layerwise": True}), + (KVzipPress, {"layerwise": False}), + (KVComposePress, {"structured": False}), + ], +) @pytest.mark.parametrize("compression_ratio", [0.2, 0.4, 0.6, 0.8]) -@pytest.mark.parametrize("layerwise", [True, False]) -def test_head_compression(unit_test_model, press, compression_ratio, layerwise): # noqa: F811 - press = KVzipPress(compression_ratio=compression_ratio, layerwise=layerwise) +def test_head_compression(unit_test_model, press_cls, kwargs, compression_ratio): + press = press_cls(compression_ratio=compression_ratio, **kwargs) with press(unit_test_model): input_ids = torch.randint(0, 1024, (1, 128)) unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index d70d4328..63e5dbf1 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -16,6 +16,7 @@ CriticalKVPress, KeyRerotationPress, KnormPress, + KVComposePress, KVzipPress, ObservedAttentionPress, ScorerPress, @@ -70,6 +71,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 if hasattr(press, "__post_init_from_model__"): press.__post_init_from_model__(unit_test_model) if issubclass(wrapper_press, ComposedPress): + if isinstance(press, KVComposePress): # KVComposePress is currently not compatible with ComposedPress + return if isinstance(press, KVzipPress): # KVzipPress is currently not compatible with ComposedPress return press = ComposedPress(presses=[press])