From 9a3df1af34dfa7e3ebe2721ce928850d81c34e08 Mon Sep 17 00:00:00 2001 From: Jang-Hyun Date: Tue, 3 Feb 2026 07:57:00 +0000 Subject: [PATCH 1/4] fast kvzip Signed-off-by: Jang-Hyun --- evaluation/evaluate_registry.py | 8 +- kvpress/__init__.py | 4 +- kvpress/presses/fastkvzip_press.py | 277 +++++++++++++++++++++++++++++ tests/default_presses.py | 2 + tests/presses/test_presses.py | 7 +- 5 files changed, 292 insertions(+), 6 deletions(-) create mode 100644 kvpress/presses/fastkvzip_press.py diff --git a/evaluation/evaluate_registry.py b/evaluation/evaluate_registry.py index 2274c43a..e81e7fda 100644 --- a/evaluation/evaluate_registry.py +++ b/evaluation/evaluate_registry.py @@ -20,25 +20,26 @@ ComposedPress, CriticalAdaKVPress, CriticalKVPress, + CURPress, DecodingPress, + DMSPress, DuoAttentionPress, ExpectedAttentionPress, + FastKVzipPress, FinchPress, KeyDiffPress, KnormPress, KVzapPress, KVzipPress, + LagKVPress, ObservedAttentionPress, PyramidKVPress, QFilterPress, RandomPress, SnapKVPress, StreamingLLMPress, - DMSPress, ThinKPress, TOVAPress, - CURPress, - LagKVPress, ) # These dictionaries define the available datasets, scorers, and KVPress methods for evaluation. @@ -82,6 +83,7 @@ "duo_attention": DuoAttentionPress(), "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True), "expected_attention": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)), + "fastkvzip": FastKVzipPress(), "finch": FinchPress(), "keydiff": KeyDiffPress(), "kvzip": KVzipPress(), diff --git a/kvpress/__init__.py b/kvpress/__init__.py index 7f124d4a..0f08519e 100644 --- a/kvpress/__init__.py +++ b/kvpress/__init__.py @@ -14,9 +14,11 @@ from kvpress.presses.criticalkv_press import CriticalAdaKVPress, CriticalKVPress from kvpress.presses.cur_press import CURPress from kvpress.presses.decoding_press import DecodingPress +from kvpress.presses.dms_press import DMSPress from kvpress.presses.duo_attention_press import DuoAttentionPress from kvpress.presses.expected_attention_press import ExpectedAttentionPress from kvpress.presses.expected_attention_with_stats import ExpectedAttentionStatsPress +from kvpress.presses.fastkvzip_press import FastKVzipPress from kvpress.presses.finch_press import FinchPress from kvpress.presses.key_rerotation_press import KeyRerotationPress from kvpress.presses.keydiff_press import KeyDiffPress @@ -37,7 +39,6 @@ from kvpress.presses.snapkv_press import SnapKVPress from kvpress.presses.streaming_llm_press import StreamingLLMPress from kvpress.presses.think_press import ThinKPress -from kvpress.presses.dms_press import DMSPress from kvpress.presses.tova_press import TOVAPress # Patch the attention functions to support head-wise compression @@ -81,4 +82,5 @@ "NonCausalAttnPress", "KVzapPress", "DMSPress", + "FastKVzipPress", ] diff --git a/kvpress/presses/fastkvzip_press.py b/kvpress/presses/fastkvzip_press.py new file mode 100644 index 00000000..919914fc --- /dev/null +++ b/kvpress/presses/fastkvzip_press.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import math
+import os
+import re
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import Generator
+
+import torch
+from huggingface_hub import hf_hub_download
+from torch import nn
+from transformers import AutoConfig, PreTrainedModel
+from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
+
+from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
+
+logger = logging.getLogger(__name__)
+
+
+class Weight(nn.Module):
+    def __init__(
+        self,
+        index: int,
+        input_dim: int,
+        output_dim: int,
+        nhead: int,
+        ngroup: int,
+        dtype,
+        sink=1,
+    ):
+        super().__init__()
+        self.index = index
+        self.output_dim = output_dim
+        self.nhead = nhead
+        self.ngroup = ngroup
+        self.sink = sink
+
+        self.q_proj = nn.Linear(input_dim, nhead * ngroup * output_dim, bias=True, dtype=dtype)
+        self.k_proj = nn.Linear(input_dim, nhead * output_dim, bias=False, dtype=dtype)
+        self.q_norm = Qwen3RMSNorm(output_dim)
+        self.k_norm = Qwen3RMSNorm(output_dim)
+        self.k_base = nn.Parameter(torch.zeros([nhead, 1, sink, output_dim]))
+        self.b = nn.Parameter(torch.zeros([nhead, 1, ngroup], dtype=dtype))
+
+        self.d = math.sqrt(self.output_dim)
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = hidden_states.squeeze(0) # bsz = 1
+        nseq = hidden_states.shape[0] # sequence x dim
+        hidden_shape = (nseq, self.nhead, -1, self.output_dim)
+
+        queries = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
+        keys = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
+        queries = queries.transpose(0, 1).transpose(-1, -2)
+        keys = keys.transpose(0, 1)
+
+        # head x seq x 1 x group
+        logit = torch.matmul(keys, queries) / self.d + self.b.unsqueeze(2)
+        # head x 1 x sink x group
+        logit_base = torch.matmul(self.k_base, queries) / self.d
+        score = 1 / (1 + torch.exp(logit_base - logit).sum(2, keepdim=True))
+
+        score = score.mean(-1) # n_head, seq, 1
+        return score.squeeze(-1).unsqueeze(0) # bsz x n_head x seq
+
+    def extra_repr(self):
+        # Customize the print output
+        repr_str = f"index={self.index}, output_dim={self.output_dim}, nhead={self.nhead}, ngroup={self.ngroup}\n"
+        if self.sink != 0:
+            repr_str += f"k_base shape: {self.k_base.shape}\n"
+        repr_str += f"b shape: {self.b.shape}\n"
+        return repr_str
+
+
+def init_fastkvzip(model_config, device="cuda"):
+    dtype = model_config.dtype
+    input_dim = model_config.hidden_size
+    sink, output_dim = 16, 16
+    ngroup = model_config.num_attention_heads // model_config.num_key_value_heads
+    nhead = model_config.num_key_value_heads
+
+    modules = []
+    for idx in range(model_config.num_hidden_layers):
+        module = Weight(idx, input_dim, output_dim, nhead, ngroup, dtype, sink=sink).to(device)
+        modules.append(module)
+    print(f"load random gate ({module})")
+    return modules
+
+
+def load_fastkvzip(model_name="Qwen/Qwen3-8B", file_name="fastkvzip", device="cuda"):
+    if not model_name:
+        raise AssertionError("Model_name is empty. Please check load_gate.")
+    state_dict, gate_id = get_gate_weight(model_name, file_name)
+
+    dtype = state_dict[0]["q_proj.weight"].dtype
+    head_group_outdim, input_dim = state_dict[0]["q_proj.weight"].shape
+    head_outdim, _ = state_dict[0]["k_proj.weight"].shape
+    output_dim = state_dict[0]["q_norm.weight"].shape[-1]
+    nhead = head_outdim // output_dim
+    ngroup = head_group_outdim // head_outdim
+
+    m = re.search(r"sink(\d+)", gate_id)
+    sink = int(m.group(1)) if m else 0
+
+    modules = []
+    for idx, weight in enumerate(state_dict):
+        module = Weight(idx, input_dim, output_dim, nhead, ngroup, dtype, sink=sink).to(device)
+        module.load_state_dict(weight)
+        modules.append(module)
+
+    print(f"load gate {gate_id} ({module})")
+    return modules
+
+
+def get_gate_id(model_name, file_name="fastkvzip"):
+    if file_name == "fastkvzip":
+        config = AutoConfig.from_pretrained(model_name)
+        if hasattr(config, "text_config"):
+            config = config.text_config
+        ngroup = config.num_attention_heads // config.num_key_value_heads
+        file_name = f"q{ngroup}_dim16_sink16"
+
+    model_name = model_name.split("/")[-1].lower()
+    gate_id = os.path.join(model_name, file_name + ".pt")
+    return gate_id
+
+
+def get_gate_weight(model_name, file_name):
+    gate_id = get_gate_id(model_name, file_name)
+    file_path = hf_hub_download(repo_id="Jang-Hyun/Fast-KVzip", filename=gate_id, repo_type="model")
+
+    # Load the PyTorch tensor/dictionary
+    weights = torch.load(file_path, weights_only=False)["module"]
+    return weights, gate_id
+
+
+@dataclass
+class FastKVzipPress(BasePress):
+    """
+    Fast KVzip estimates KV importance scores using gates trained on KVzip scores.
+
+    Based on Fast KVzip (https://arxiv.org/abs/2601.17668).
+    Authors: Jang-Hyun Kim, Dongyoon Han, Sangdoo Yun
+    Affiliation: NAVER AI Lab
+
+    Parameters
+    ----------
+    compression_ratio : float, default=0.0
+        Fraction of key-value pairs to remove during compression.
+    layerwise : bool, default=False
+        Whether to apply a uniform compression ratio to every layer.
+        When False, the overall KV cache compression ratio is maintained,
+        but each layer can have a different compression ratio.
+    n_sink : int, default=4
+        Number of initial tokens to preserve as attention sinks.
+    window_size : int, default=4096
+        Number of tokens in the local window retained during chunked prefilling.
+    window_size : float, default=0.02
+        Fraction of the context length used to calculate the local window size retained during short-context prefilling.
+    """
+
+    compression_ratio: float = 0.0
+    layerwise: bool = False
+
+    n_sink: int = 4
+    window_size: int = 4096 # for chunked prefilling with long contexts
+    window_ratio: float = 0.02
+
+    gates: list[nn.Module] | None = field(init=False, default=None)
+    score_val: list[torch.Tensor] | torch.Tensor | None = field(init=False, default=None)
+
+    def __post_init_from_model__(self, model):
+        """
+        Automatically load gates for the model.
+ """ + if self.gates is None: + try: + self.gates = load_fastkvzip(model_name=model.config.name_or_path, device=model.device) + except Exception: + print("The gates for the given model are not released!") + self.gates = init_fastkvzip(model.config, device=model.device) + + @contextmanager + def __call__(self, model: PreTrainedModel) -> Generator: + self.__post_init_from_model__(model) + + if not isinstance(model, SUPPORTED_MODELS): + logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}") + + hooks = [] + try: + self.score_val = [None for _ in range(len(model.model.layers))] # reset every prefilling + for layer in model.model.layers: + layer.self_attn.rotary_emb = model.model.rotary_emb + hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True)) + yield + + self.compress_post(model) # Perform compression + + finally: + for hook in hooks: + hook.remove() + + def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list): + """ + Override the forward_hook of BasePress. + During the forward_hook, KVzip only calculates importance scores, + aggregates scores across all layers, and then performs compression. + """ + + hidden_states = kwargs["hidden_states"] + q_len = hidden_states.shape[1] + + # Don't compress after pre-filling + if kwargs["cache_position"][-1] > q_len: + return output + + self._score_fast(module, hidden_states) + return output + + def _score_fast(self, module: nn.Module, hidden_states: torch.Tensor): + """ + Calculate the KV importance scores. + """ + + scores = self.gates[int(module.layer_idx)](hidden_states) + scores[:, :, : self.n_sink] = 1.0 + + ctx_len = scores.size(-1) + if ctx_len < 32000: + window_size = int(ctx_len * self.window_ratio) + else: + window_size = self.window_size + scores[:, :, -window_size:] = 1.0 + + self.score_val[int(module.layer_idx)] = scores + + def compress_post(self, model: PreTrainedModel): + """ + Obtain the indices of KV pairs to be evicted. + Adopted from adakv_press.compress (fake compression). KVzip does not rely on safeguards. + """ + self.score_val = torch.stack(self.score_val, dim=0) + + if self.compression_ratio > 0: + n_layer, bsz, num_key_value_heads, ctx_len = self.score_val.shape + + # calculate the pruned KV pairs across layers + if self.layerwise: + nl = int(bsz * num_key_value_heads * ctx_len * self.compression_ratio) + n_pruned_layers = nl * torch.ones(n_layer, device=self.score_val.device, dtype=torch.int) + else: + n_pruned_indices = int(self.score_val.numel() * self.compression_ratio) + pruned_indices = torch.topk(-self.score_val.reshape(-1), n_pruned_indices).indices + n_tokens_per_layer = bsz * num_key_value_heads * ctx_len + n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int() + + for layer in model.model.layers: + module = layer.self_attn + layer_idx = int(module.layer_idx) + + assert module.config._attn_implementation != "eager", "eager mode not supported" + + scores = self.score_val[layer_idx] + + # Compute bottom-k across heads + n_pruned = n_pruned_layers[layer_idx].cpu() + indices = torch.topk(-scores.reshape(bsz, -1), n_pruned, dim=1).indices.flatten().cpu() + + # Save indices to mask during the attention mechanism. 
Please refer to attention_patch.py for details + batch_indices = torch.arange(bsz, device=n_pruned.device).repeat_interleave(n_pruned) + head_indices = indices // ctx_len + seq_indices = indices % ctx_len + module.masked_key_indices = (batch_indices, head_indices, seq_indices) diff --git a/tests/default_presses.py b/tests/default_presses.py index 078a55f3..cd17c6e1 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -9,6 +9,7 @@ DuoAttentionPress, ExpectedAttentionPress, ExpectedAttentionStatsPress, + FastKVzipPress, KeyDiffPress, KnormPress, KVzapPress, @@ -93,6 +94,7 @@ def post_init_from_model(self, model): "cls": KVzipPress, "kwargs": [{"compression_ratio": 0.5, "layerwise": False}, {"compression_ratio": 0.8, "layerwise": True}], }, + {"cls": FastKVzipPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": CURPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": TestKVzapPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, { diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index 6b19fd94..f937b6ab 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -14,13 +14,14 @@ ComposedPress, CriticalAdaKVPress, CriticalKVPress, + DMSPress, + FastKVzipPress, KeyRerotationPress, KnormPress, KVzipPress, ObservedAttentionPress, ScorerPress, SnapKVPress, - DMSPress, ThinKPress, ) from tests.default_presses import default_presses @@ -80,7 +81,9 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 if hasattr(press, "post_init_from_model"): press.post_init_from_model(unit_test_model) if issubclass(wrapper_press, ComposedPress): - if isinstance(press, KVzipPress): # KVzipPress is currently not compatible with ComposedPress + if isinstance(press, KVzipPress) or isinstance( + press, FastKVzipPress + ): # KVzipPress and FastKVzipPress are currently not compatible with ComposedPress return press = ComposedPress(presses=[press]) elif not isinstance(press, ScorerPress): # remaining wrapper presses only support ScorerPress From be4889e9a35e1460edb2fcc02e8c03eebee90fa2 Mon Sep 17 00:00:00 2001 From: Jang-Hyun Date: Wed, 4 Feb 2026 08:13:43 +0000 Subject: [PATCH 2/4] description Signed-off-by: Jang-Hyun --- README.md | 1 + kvpress/presses/fastkvzip_press.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/README.md b/README.md index 5348a793..6dce0964 100644 --- a/README.md +++ b/README.md @@ -131,6 +131,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres - `CompactorPress` ([source](kvpress/presses/compactor_press.py), [paper](https://arxiv.org/abs/2507.08143)): blends `NonCausalAttnPress` and `LeverageScorePress` based on the compression_ratio. - `CURPress` ([source](kvpress/presses/cur_press.py), [paper](https://arxiv.org/abs/2509.15038)): prune keys and values based on the CUR decomposition using approximate leverage scores. - `KVzapPress` ([source](kvpress/presses/kvzap/kvzap_press.py), [paper](https://arxiv.org/abs/2601.07891), [training](kvzap)): approximate KVzip+ using a fast surrogate model. To be used in conjunction with the `DMSPress`. +- `FastKVzipPress` ([source](kvpress/presses/fastkvzip_press.py), [paper](https://arxiv.org/abs/2601.17668)): approximate KVzip through a lightweight gating mechanism. 
 
 Some presses rely on a different logic:
 - `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/abs/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
diff --git a/kvpress/presses/fastkvzip_press.py b/kvpress/presses/fastkvzip_press.py
index 919914fc..ee560513 100644
--- a/kvpress/presses/fastkvzip_press.py
+++ b/kvpress/presses/fastkvzip_press.py
@@ -142,6 +142,10 @@ class FastKVzipPress(BasePress):
     """
     Fast KVzip estimates KV importance scores using gates trained on KVzip scores.
 
+    In this code, we implement Fast KVzip with minimal changes to this repository.
+    For a fully optimized implementation with actual compression and chunked-prefill,
+    please refer to the original repository (https://github.com/Janghyun1230/FastKVzip).
+
     Based on Fast KVzip (https://arxiv.org/abs/2601.17668).
     Authors: Jang-Hyun Kim, Dongyoon Han, Sangdoo Yun
     Affiliation: NAVER AI Lab

From 364728dcd93127aa23a969bf2ed0e1b43aef3d80 Mon Sep 17 00:00:00 2001
From: Jang-Hyun
Date: Fri, 6 Feb 2026 06:30:42 +0000
Subject: [PATCH 3/4] description

Signed-off-by: Jang-Hyun
---
 kvpress/presses/fastkvzip_press.py | 41 +++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/kvpress/presses/fastkvzip_press.py b/kvpress/presses/fastkvzip_press.py
index ee560513..7f911c70 100644
--- a/kvpress/presses/fastkvzip_press.py
+++ b/kvpress/presses/fastkvzip_press.py
@@ -12,7 +12,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from torch import nn
-from transformers import AutoConfig, PreTrainedModel
+from transformers import AutoConfig, PreTrainedModel, Gemma3ForConditionalGeneration
 from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm
 
 from kvpress.presses.base_press import SUPPORTED_MODELS, BasePress
@@ -21,6 +21,9 @@
 
 
 class Weight(nn.Module):
+    """
+    Fast KVzip gate architecture (https://arxiv.org/abs/2601.17668).
+    """
     def __init__(
         self,
         index: int,
@@ -28,8 +31,8 @@ def __init__(
         output_dim: int,
         nhead: int,
         ngroup: int,
-        dtype,
-        sink=1,
+        dtype: torch.dtype,
+        sink: int = 1,
     ):
         super().__init__()
         self.index = index
@@ -75,7 +78,8 @@ def extra_repr(self):
         return repr_str
 
 
-def init_fastkvzip(model_config, device="cuda"):
+def init_fastkvzip(model_config, device: str = "cuda"):
+    """ Random initialization of gate weights """
     dtype = model_config.dtype
     input_dim = model_config.hidden_size
     sink, output_dim = 16, 16
@@ -90,7 +94,8 @@ def init_fastkvzip(model_config, device="cuda"):
     return modules
 
 
-def load_fastkvzip(model_name="Qwen/Qwen3-8B", file_name="fastkvzip", device="cuda"):
+def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", file_name: str = "fastkvzip", device: str = "cuda"):
+    """ Load trained gate weights """
     if not model_name:
         raise AssertionError("Model_name is empty. Please check load_gate.")
     state_dict, gate_id = get_gate_weight(model_name, file_name)
 
@@ -115,7 +120,8 @@
     return modules
 
 
-def get_gate_id(model_name, file_name="fastkvzip"):
+def get_gate_id(model_name: str, file_name: str = "fastkvzip"):
+    """ Get the gate id from model names """
     if file_name == "fastkvzip":
         config = AutoConfig.from_pretrained(model_name)
         if hasattr(config, "text_config"):
@@ -128,7 +134,8 @@
     return gate_id
 
 
-def get_gate_weight(model_name, file_name):
+def get_gate_weight(model_name: str, file_name: str):
+    """ Load trained gate weights from HuggingFace """
     gate_id = get_gate_id(model_name, file_name)
     file_path = hf_hub_download(repo_id="Jang-Hyun/Fast-KVzip", filename=gate_id, repo_type="model")
 
@@ -176,7 +183,7 @@
     gates: list[nn.Module] | None = field(init=False, default=None)
     score_val: list[torch.Tensor] | torch.Tensor | None = field(init=False, default=None)
 
-    def __post_init_from_model__(self, model):
+    def post_init_from_model(self, model):
         """
         Automatically load gates for the model.
         """
@@ -189,16 +196,26 @@
 
     @contextmanager
     def __call__(self, model: PreTrainedModel) -> Generator:
-        self.__post_init_from_model__(model)
+        """
+        Context manager that handles both initial prefilling and Fast KVzip scoring/compression.
+
+        This overrides the base class __call__ method to implement the Fast KVzip algorithm:
+        1. First yield: allows initial prefilling with context and KV importance scoring via gates
+        2. After yield: performs KV eviction based on the importance scores
+        """
         if not isinstance(model, SUPPORTED_MODELS):
             logger.warning(f"Model {type(model)} not tested, supported models: {SUPPORTED_MODELS}")
 
+        self.post_init_from_model(model)
         hooks = []
         try:
             self.score_val = [None for _ in range(len(model.model.layers))] # reset every prefilling
-            for layer in model.model.layers:
-                layer.self_attn.rotary_emb = model.model.rotary_emb
+            language_model = model.model.language_model if hasattr(model.model, "language_model") else model.model
+            for layer in language_model.layers:
+                if isinstance(model, Gemma3ForConditionalGeneration) and layer.self_attn.is_sliding:
+                    # Skip layers with sliding window attention, only for Gemma3
+                    continue
+                layer.self_attn.rotary_emb = language_model.rotary_emb
                 hooks.append(layer.self_attn.register_forward_hook(self.forward_hook, with_kwargs=True))
             yield
 
@@ -211,7 +228,7 @@
     def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dict, output: list):
         """
         Override the forward_hook of BasePress.
-        During the forward_hook, KVzip only calculates importance scores,
+        During the forward_hook, Fast KVzip calculates importance scores,
         aggregates scores across all layers, and then performs compression.
""" From de4e2bd6017974f25e08688400b54a464d1f4e40 Mon Sep 17 00:00:00 2001 From: Jang-Hyun Date: Tue, 10 Feb 2026 06:23:41 +0000 Subject: [PATCH 4/4] clean up Signed-off-by: Jang-Hyun --- kvpress/presses/fastkvzip_press.py | 62 ++++++++++++------------------ tests/default_presses.py | 19 ++++++++- tests/presses/test_presses.py | 5 +-- 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/kvpress/presses/fastkvzip_press.py b/kvpress/presses/fastkvzip_press.py index 7f911c70..eb929eaf 100644 --- a/kvpress/presses/fastkvzip_press.py +++ b/kvpress/presses/fastkvzip_press.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -class Weight(nn.Module): +class FastKVzipGate(nn.Module): """ Fast KVzip gate architecture (https://arxiv.org/abs/2601.17668). """ @@ -28,11 +28,11 @@ def __init__( self, index: int, input_dim: int, - output_dim: int, nhead: int, ngroup: int, dtype: torch.dtype, - sink: int = 1, + output_dim: int = 16, + sink: int = 16, ): super().__init__() self.index = index @@ -78,27 +78,11 @@ def extra_repr(self): return repr_str -def init_fastkvzip(model_config, device: str = "cuda"): - """ Random initialization of gate weights """ - dtype = model_config.dtype - input_dim = model_config.hidden_size - sink, output_dim = 16, 16 - ngroup = model_config.num_attention_heads // model_config.num_key_value_heads - nhead = model_config.num_key_value_heads - - modules = [] - for idx in range(model_config.num_hidden_layers): - module = Weight(idx, input_dim, output_dim, nhead, ngroup, dtype, sink=sink).to(device) - modules.append(module) - print(f"load random gate ({module})") - return modules - - -def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", file_name: str = "fastkvzip", device: str = "cuda"): +def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", device: str = "cuda"): """ Load trained gate weights """ if not model_name: raise AssertionError("Model_name is empty. 
Please check load_gate.") - state_dict, gate_id = get_gate_weight(model_name, file_name) + state_dict, gate_id = get_gate_weight(model_name) dtype = state_dict[0]["q_proj.weight"].dtype head_group_outdim, input_dim = state_dict[0]["q_proj.weight"].shape @@ -112,7 +96,7 @@ def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", file_name: str = "fastkvzi modules = [] for idx, weight in enumerate(state_dict): - module = Weight(idx, input_dim, output_dim, nhead, ngroup, dtype, sink=sink).to(device) + module = FastKVzipGate(idx, input_dim, nhead, ngroup, dtype, output_dim, sink).to(device) module.load_state_dict(weight) modules.append(module) @@ -120,23 +104,22 @@ def load_fastkvzip(model_name: str = "Qwen/Qwen3-8B", file_name: str = "fastkvzi return modules -def get_gate_id(model_name: str, file_name: str = "fastkvzip"): +def get_gate_id(model_name: str): """ Get the gate id from model names """ - if file_name == "fastkvzip": - config = AutoConfig.from_pretrained(model_name) - if hasattr(config, "text_config"): - config = config.text_config - ngroup = config.num_attention_heads // config.num_key_value_heads - file_name = f"q{ngroup}_dim16_sink16" + config = AutoConfig.from_pretrained(model_name) + if hasattr(config, "text_config"): + config = config.text_config + ngroup = config.num_attention_heads // config.num_key_value_heads + file_name = f"q{ngroup}_dim16_sink16" model_name = model_name.split("/")[-1].lower() gate_id = os.path.join(model_name, file_name + ".pt") return gate_id -def get_gate_weight(model_name: str, file_name: str): +def get_gate_weight(model_name: str): """ Load trained gate weights from HuggingFace """ - gate_id = get_gate_id(model_name, file_name) + gate_id = get_gate_id(model_name) file_path = hf_hub_download(repo_id="Jang-Hyun/Fast-KVzip", filename=gate_id, repo_type="model") # Load the PyTorch tensor/dictionary @@ -169,7 +152,7 @@ class FastKVzipPress(BasePress): Number of initial tokens to preserve as attention sinks. window_size : int, default=4096 Number of tokens in the local window retained during chunked prefilling. - window_size : float, default=0.02 + window_ratio : float, default=0.02 Fraction of the context length used to calculate the local window size retained during short-context prefilling. """ @@ -190,9 +173,12 @@ def post_init_from_model(self, model): if self.gates is None: try: self.gates = load_fastkvzip(model_name=model.config.name_or_path, device=model.device) - except Exception: - print("The gates for the given model are not released!") - self.gates = init_fastkvzip(model.config, device=model.device) + except Exception as e: + raise RuntimeError( + "The gates for the given model are not released! " + "Please check the available models at: " + "https://huggingface.co/Jang-Hyun/Fast-KVzip/tree/main" + ) from e @contextmanager def __call__(self, model: PreTrainedModel) -> Generator: @@ -246,8 +232,10 @@ def _score_fast(self, module: nn.Module, hidden_states: torch.Tensor): """ Calculate the KV importance scores. 
""" + layer_idx = int(module.layer_idx) - scores = self.gates[int(module.layer_idx)](hidden_states) + self.gates[layer_idx] = self.gates[layer_idx].to(hidden_states.device) + scores = self.gates[layer_idx](hidden_states) scores[:, :, : self.n_sink] = 1.0 ctx_len = scores.size(-1) @@ -257,7 +245,7 @@ def _score_fast(self, module: nn.Module, hidden_states: torch.Tensor): window_size = self.window_size scores[:, :, -window_size:] = 1.0 - self.score_val[int(module.layer_idx)] = scores + self.score_val[layer_idx] = scores def compress_post(self, model: PreTrainedModel): """ diff --git a/tests/default_presses.py b/tests/default_presses.py index cd17c6e1..5b1278ed 100644 --- a/tests/default_presses.py +++ b/tests/default_presses.py @@ -27,6 +27,7 @@ TOVAPress, ) from kvpress.presses.kvzap_press import KVzapConfig, KVzapModel +from kvpress.presses.fastkvzip_press import FastKVzipGate class TestDuoAttentionPress(DuoAttentionPress): @@ -49,6 +50,22 @@ def post_init_from_model(self, model): self.kvzap_model = KVzapModel(config) +class TestFastKVzipPress(FastKVzipPress): + """Test version of FastKVzipPress that creates a mock model instead of loading from HuggingFace.""" + + def post_init_from_model(self, model): + if self.gates is None: + dtype = model.config.dtype + input_dim = model.config.hidden_size + ngroup = model.config.num_attention_heads // model.config.num_key_value_heads + nhead = model.config.num_key_value_heads + + self.gates = [] + for idx in range(model.config.num_hidden_layers): + module = FastKVzipGate(idx, input_dim, nhead, ngroup, dtype).to(model.device) + self.gates.append(module) + + # contains all presses to be tested # kwargs should be ordered easy to hard compression default_presses = [ @@ -94,7 +111,7 @@ def post_init_from_model(self, model): "cls": KVzipPress, "kwargs": [{"compression_ratio": 0.5, "layerwise": False}, {"compression_ratio": 0.8, "layerwise": True}], }, - {"cls": FastKVzipPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, + {"cls": TestFastKVzipPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": CURPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, {"cls": TestKVzapPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]}, { diff --git a/tests/presses/test_presses.py b/tests/presses/test_presses.py index f937b6ab..b86b7eef 100644 --- a/tests/presses/test_presses.py +++ b/tests/presses/test_presses.py @@ -81,9 +81,8 @@ def test_presses_run(unit_test_model, press_dict, wrapper_press): # noqa: F811 if hasattr(press, "post_init_from_model"): press.post_init_from_model(unit_test_model) if issubclass(wrapper_press, ComposedPress): - if isinstance(press, KVzipPress) or isinstance( - press, FastKVzipPress - ): # KVzipPress and FastKVzipPress are currently not compatible with ComposedPress + if isinstance(press, (KVzipPress, FastKVzipPress)): + # KVzipPress and FastKVzipPress are currently not compatible with ComposedPress return press = ComposedPress(presses=[press]) elif not isinstance(press, ScorerPress): # remaining wrapper presses only support ScorerPress