From 0360ebccfd2cd7ec85b99d5a96a5d00bc66b102d Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 28 Jan 2026 10:55:52 +0000
Subject: [PATCH 1/2] Support transformers v5

Signed-off-by: SimJeg
---
 README.md                                       |  5 ++---
 kvpress/presses/kvzip_press.py                  | 11 +++++------
 notebooks/speed_and_memory.ipynb                | 16 ++++++++--------
 pyproject.toml                                  |  4 +---
 tests/integration/test_ruler.py                 |  6 +++---
 tests/presses/test_key_rerotation_press_rope.py |  3 +++
 tests/test_pipeline.py                          |  4 ++--
 tests/test_press_call.py                        | 12 ++++++------
 8 files changed, 30 insertions(+), 31 deletions(-)

diff --git a/README.md b/README.md
index 6cb4e8ee..5348a793 100644
--- a/README.md
+++ b/README.md
@@ -170,10 +170,9 @@ Below we report the average performance on the RULER dataset with 4k context len
 We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
 
 ```python
-from transformers import QuantizedCacheConfig, QuantoQuantizedCache
+from transformers import QuantizedCache
 
-config = QuantizedCacheConfig(nbits=4)
-cache = QuantoQuantizedCache(config)
+cache = QuantizedCache(backend="quanto", nbits=4)
 pipe(..., cache=cache)
 ```
 
diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py
index 638fa702..d81834af 100644
--- a/kvpress/presses/kvzip_press.py
+++ b/kvpress/presses/kvzip_press.py
@@ -363,14 +363,13 @@ def compress_post(self, model: PreTrainedModel):
 
         # calculate the pruned KV pairs across layers
         if self.layerwise:
-            nl = int(num_key_value_heads * ctx_len * self.compression_ratio)
+            nl = int(bsz * num_key_value_heads * ctx_len * self.compression_ratio)
             n_pruned_layers = nl * torch.ones(n_layer, device=self.score_val.device, dtype=torch.int)
         else:
-            score_sort = torch.sort(self.score_val.reshape(-1)).values  # ascending order
-            n = max(int(len(score_sort) * self.compression_ratio) - 1, 0)
-            thres = score_sort[n].item()
-
-            n_pruned_layers = (self.score_val.reshape(n_layer, -1) <= thres).sum(-1)  # n_prune
+            n_pruned_indices = int(self.score_val.numel() * self.compression_ratio)
+            pruned_indices = torch.topk(-self.score_val.reshape(-1), n_pruned_indices).indices
+            n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
+            n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()
 
         for layer in model.model.layers:
             module = layer.self_attn
diff --git a/notebooks/speed_and_memory.ipynb b/notebooks/speed_and_memory.ipynb
index dac37a0a..008d557b 100644
--- a/notebooks/speed_and_memory.ipynb
+++ b/notebooks/speed_and_memory.ipynb
@@ -19,7 +19,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -35,7 +35,7 @@
     "import numpy as np\n",
     "import torch\n",
     "from transformers import AutoModelForCausalLM, pipeline\n",
-    "from transformers import QuantizedCacheConfig, QuantoQuantizedCache, DynamicCache, QuantizedCache\n",
+    "from transformers import DynamicCache, QuantizedCache\n",
     "from transformers.utils.logging import disable_progress_bar\n",
     "import transformers\n",
     "\n",
@@ -65,19 +65,19 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 5,
+    "execution_count": null,
     "metadata": {},
     "outputs": [],
     "source": [
     "def get_size_of_cache(cache):\n",
-    "    if isinstance(cache, QuantoQuantizedCache):\n",
+    "    if isinstance(cache, QuantizedCache):\n",
     "        # We cannot use x.element_size() * x.nelement() as below to calculate the size of the cache, \n",
    "        # as cache._quantized_value_cache[0].element_size() triggers a call of __torch_dispatch__,\n",
     "        # which, in turn, unpacks the internally packed tensor; and thus does not report the correct internal storage size.\n",
     "        # See also https://github.com/huggingface/optimum-quanto/blob/main/optimum/quanto/tensor/packed.py#L144\n",
     "\n",
-    "        # As QuantoQuantizedCache stores values, as well as shift and scale, \n",
-    "        # we temporarily save the cache to disc and getthe size of the saved object\n",
+    "        # As QuantizedCache stores values, as well as shift and scale, \n",
+    "        # we temporarily save the cache to disc and get the size of the saved object\n",
     "        temp_file = \"tmp.pickle\"\n",
     "        with open(temp_file, \"wb\") as f:\n",
     "            pickle.dump(cache, f)\n",
@@ -100,7 +100,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": null,
     "metadata": {},
    "outputs": [],
     "source": [
@@ -125,7 +125,7 @@
     "    if cache_implementation == \"dynamic\":\n",
     "        cache = DynamicCache()\n",
     "    elif cache_implementation == \"quantized\":\n",
-    "        cache = QuantoQuantizedCache(config=model.config, nbits=4)\n",
+    "        cache = QuantizedCache(backend=\"quanto\", config=model.config, nbits=4)\n",
     "    else:\n",
     "        raise NotImplementedError(f\"Cache {cache_implementation} not yet implemented\")\n",
     "\n",
diff --git a/pyproject.toml b/pyproject.toml
index 66efdae7..577a52e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,9 +13,7 @@ readme = "README.md"
 dependencies = [
     "numpy>=2.0.0,<3",
     "torch>=2.3.1,<3",
-    # transformers<4.54 is not supported due to refactoring of the transformers library.
-    # transformers 4.54-4.55.2 are not compatible with kvpress due to flash attention bugs in transformers
-    "transformers>=4.56,<5.0.0",
+    "transformers>=5.0.0",
     "sentencepiece>=0.2.0,<0.3",
     "protobuf>=5.27.2,<6",
     "datasets>=2.21.0,<3",
diff --git a/tests/integration/test_ruler.py b/tests/integration/test_ruler.py
index 155cf6dd..f7bc62bb 100644
--- a/tests/integration/test_ruler.py
+++ b/tests/integration/test_ruler.py
@@ -4,7 +4,7 @@
 import datasets
 import pytest
 import torch
-from transformers import DynamicCache, QuantoQuantizedCache
+from transformers import DynamicCache, QuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
 from kvpress import QFilterPress
@@ -44,7 +44,7 @@ def test_ruler_is_correct(
     if cache == "dynamic":
         cache = DynamicCache()
     elif cache == "quantized" and is_optimum_quanto_available():
-        cache = QuantoQuantizedCache(config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
+        cache = QuantizedCache(backend="quanto", config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
     elif cache == "quantized" and not is_optimum_quanto_available():
         pytest.skip("Quanto is not installed")
     else:
@@ -89,7 +89,7 @@ def test_ruler_is_correct_for_qfilter(
     if cache == "dynamic":
         cache = DynamicCache()
     elif cache == "quantized" and is_optimum_quanto_available():
-        cache = QuantoQuantizedCache(config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
+        cache = QuantizedCache(backend="quanto", config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
     elif cache == "quantized" and not is_optimum_quanto_available():
         pytest.skip("Quanto is not installed")
     else:
diff --git a/tests/presses/test_key_rerotation_press_rope.py b/tests/presses/test_key_rerotation_press_rope.py
index bf838080..74217b52 100644
--- a/tests/presses/test_key_rerotation_press_rope.py
+++ b/tests/presses/test_key_rerotation_press_rope.py
@@ -48,6 +48,7 @@ def test_rerotate_keys_is_matches_reference_implementation(
         "rope_type": "yarn",
     }
     cfg.max_position_embeddings = 131072
+    cfg.rope_theta = 500000.0
     try:
         unit_test_model.model.rotary_emb = LlamaRotaryEmbedding(cfg, device=unit_test_model.device)
     except KeyError:
@@ -63,6 +64,8 @@ def test_rerotate_keys_is_matches_reference_implementation(
         unit_test_model = unit_test_model.cuda().half()
     elif precision == "half":
         pytest.skip("Half-precision test skipped because CUDA is not available.")
+    elif precision == "full":
+        unit_test_model = unit_test_model.float()
 
     original_press = RandomPressStoreIndices(compression_ratio=0.5)
     key_rerotation_press = KeyRerotationPress(press=original_press)
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 6f7260c0..125da038 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -6,7 +6,7 @@
 
 import pytest
 import torch
-from transformers import AutoTokenizer, DynamicCache, QuantoQuantizedCache
+from transformers import AutoTokenizer, DynamicCache, QuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
 from kvpress import ExpectedAttentionPress
@@ -112,7 +112,7 @@ def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog):  # noq
     context = "This is a test article. It was written on 2022-01-01."
     questions = ["When was this article written?"]
     press = ExpectedAttentionPress(compression_ratio=0.4)
-    cache = QuantoQuantizedCache(config=kv_press_danube_pipeline.model.config, nbits=4)
+    cache = QuantizedCache(backend="quanto", config=kv_press_danube_pipeline.model.config, nbits=4)
     answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
 
     assert len(answers) == 1
diff --git a/tests/test_press_call.py b/tests/test_press_call.py
index 4a0b4100..80234aa0 100644
--- a/tests/test_press_call.py
+++ b/tests/test_press_call.py
@@ -28,13 +28,13 @@ def test_context_manager_applies_compression(unit_test_model):  # noqa: F811
 
     seq_len = input_ids.shape[-1]
-    for key, values in past_key_values:
-        assert key.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
-        assert values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
+    for layer in past_key_values.layers:
+        assert layer.keys.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
+        assert layer.values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
 
     input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
     past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
-    for key, values in past_key_values:
-        assert key.shape[2] == seq_len == past_key_values.get_seq_length()
-        assert values.shape[2] == seq_len == past_key_values.get_seq_length()
+    for layer in past_key_values.layers:
+        assert layer.keys.shape[2] == seq_len == past_key_values.get_seq_length()
+        assert layer.values.shape[2] == seq_len == past_key_values.get_seq_length()

From 7131ec60adcd37dc4593c4b39b27247fd540aa10 Mon Sep 17 00:00:00 2001
From: SimJeg
Date: Wed, 28 Jan 2026 11:31:01 +0000
Subject: [PATCH 2/2] Upgrade version

Signed-off-by: SimJeg
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 577a52e5..3f68ad88 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "kvpress"
-version = "0.4.3"
+version = "0.5.0"
 description = "Efficiently compress the KV cache of any pretrained transformer"
 authors = [
     { name = "Simon Jegou" },