5 changes: 2 additions & 3 deletions README.md
@@ -170,10 +170,9 @@ Below we report the average performance on the RULER dataset with 4k context len
We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:

```python
-from transformers import QuantizedCacheConfig, QuantoQuantizedCache
+from transformers import QuantizedCache

-config = QuantizedCacheConfig(nbits=4)
-cache = QuantoQuantizedCache(config)
+cache = QuantizedCache(backend="quanto", nbits=4)

pipe(..., cache=cache)
```
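For orientation, here is a fuller end-to-end sketch of the updated usage, modeled on the pipeline call in this PR's tests (`questions=...`, `press=...`, `cache=...`). The `kv-press-text-generation` task name and the checkpoint are assumptions rather than part of this diff, and quantization requires `optimum-quanto` to be installed:

```python
# Sketch only: the task name and checkpoint below are assumptions.
from transformers import QuantizedCache, pipeline

from kvpress import ExpectedAttentionPress

pipe = pipeline(
    "kv-press-text-generation",          # assumed kvpress pipeline task name
    model="Qwen/Qwen2.5-1.5B-Instruct",  # assumed small instruct checkpoint
    device_map="auto",
)

press = ExpectedAttentionPress(compression_ratio=0.4)
# New API: backend and nbits are passed directly to QuantizedCache
cache = QuantizedCache(backend="quanto", config=pipe.model.config, nbits=4)

context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
answers = pipe(context, questions=questions, press=press, cache=cache)["answers"]
print(answers[0])
```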
11 changes: 5 additions & 6 deletions kvpress/presses/kvzip_press.py
@@ -363,14 +363,13 @@ def compress_post(self, model: PreTrainedModel):

        # calculate the pruned KV pairs across layers
        if self.layerwise:
-            nl = int(num_key_value_heads * ctx_len * self.compression_ratio)
+            nl = int(bsz * num_key_value_heads * ctx_len * self.compression_ratio)
            n_pruned_layers = nl * torch.ones(n_layer, device=self.score_val.device, dtype=torch.int)
        else:
-            score_sort = torch.sort(self.score_val.reshape(-1)).values  # ascending order
-            n = max(int(len(score_sort) * self.compression_ratio) - 1, 0)
-            thres = score_sort[n].item()
-
-            n_pruned_layers = (self.score_val.reshape(n_layer, -1) <= thres).sum(-1)  # n_prune
+            n_pruned_indices = int(self.score_val.numel() * self.compression_ratio)
+            pruned_indices = torch.topk(-self.score_val.reshape(-1), n_pruned_indices).indices
+            n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
+            n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()

        for layer in model.model.layers:
            module = layer.self_attn
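The non-layerwise branch above swaps a score-threshold count for an exact global selection: take the `n_pruned_indices` lowest-scoring KV pairs across all layers with `topk`, then count how many fall in each layer with `bincount`. A minimal standalone illustration with dummy shapes (not the actual press internals):

```python
import torch

# Dummy dimensions standing in for (n_layer, bsz * num_key_value_heads * ctx_len)
n_layer, bsz, num_key_value_heads, ctx_len = 4, 1, 8, 16
compression_ratio = 0.25
score_val = torch.rand(n_layer, bsz * num_key_value_heads * ctx_len)

# Pick the globally lowest-scoring KV pairs (negate so topk returns the smallest scores)
n_pruned_indices = int(score_val.numel() * compression_ratio)
pruned_indices = torch.topk(-score_val.reshape(-1), n_pruned_indices).indices

# Map flat indices back to layers and count pruned entries per layer
n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()

assert int(n_pruned_layers.sum()) == n_pruned_indices
```

One effect of this selection is that the pruned total matches the requested compression ratio exactly, whereas counting entries below a threshold could overshoot when scores tie at the cutoff.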
16 changes: 8 additions & 8 deletions notebooks/speed_and_memory.ipynb
@@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -35,7 +35,7 @@
"import numpy as np\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, pipeline\n",
"from transformers import QuantizedCacheConfig, QuantoQuantizedCache, DynamicCache, QuantizedCache\n",
"from transformers import DynamicCache, QuantizedCache\n",
"from transformers.utils.logging import disable_progress_bar\n",
"import transformers\n",
"\n",
@@ -65,19 +65,19 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_size_of_cache(cache):\n",
" if isinstance(cache, QuantoQuantizedCache):\n",
" if isinstance(cache, QuantizedCache):\n",
" # We cannot use x.element_size() * x.nelement() as below to calculate the size of the cache, \n",
" # as cache._quantized_value_cache[0].element_size() triggers a call of __torch_dispatch__,\n",
" # which, in turn, unpacks the internally packed tensor; and thus does not report the correct internal storage size.\n",
" # See also https://github.com/huggingface/optimum-quanto/blob/main/optimum/quanto/tensor/packed.py#L144\n",
"\n",
" # As QuantoQuantizedCache stores values, as well as shift and scale, \n",
" # we temporarily save the cache to disc and getthe size of the saved object\n",
" # As QuantizedCache stores values, as well as shift and scale, \n",
" # we temporarily save the cache to disc and get the size of the saved object\n",
" temp_file = \"tmp.pickle\"\n",
" with open(temp_file, \"wb\") as f:\n",
" pickle.dump(cache, f)\n",
@@ -100,7 +100,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -125,7 +125,7 @@
" if cache_implementation == \"dynamic\":\n",
" cache = DynamicCache()\n",
" elif cache_implementation == \"quantized\":\n",
" cache = QuantoQuantizedCache(config=model.config, nbits=4)\n",
" cache = QuantizedCache(backend=\"quanto\", config=model.config, nbits=4)\n",
" else:\n",
" raise NotImplementedError(f\"Cache {cache_implementation} not yet implemented\")\n",
"\n",
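The notebook cell above sizes a quantized cache by serializing it, because packed quantized tensors unpack when inspected and so misreport `element_size() * nelement()`. A self-contained sketch of that measurement; the unquantized fallback iterating `cache.layers` is an assumption based on the cache API used elsewhere in this PR:

```python
import os
import pickle

from transformers import QuantizedCache


def get_size_of_cache_bytes(cache) -> int:
    """Approximate size of a KV cache in bytes."""
    if isinstance(cache, QuantizedCache):
        # Serialize the whole object and measure the file, mirroring the notebook,
        # since inspecting packed quantized tensors unpacks them.
        temp_file = "tmp.pickle"
        with open(temp_file, "wb") as f:
            pickle.dump(cache, f)
        size = os.path.getsize(temp_file)
        os.remove(temp_file)
        return size
    # Assumed fallback for an unquantized cache: sum tensor storage sizes per layer.
    return sum(
        t.element_size() * t.nelement()
        for layer in cache.layers
        for t in (layer.keys, layer.values)
    )
```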
6 changes: 2 additions & 4 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "kvpress"
version = "0.4.3"
version = "0.5.0"
description = "Efficiently compress the KV cache of any pretrained transformer"
authors = [
{ name = "Simon Jegou" },
@@ -13,9 +13,7 @@ readme = "README.md"
dependencies = [
"numpy>=2.0.0,<3",
"torch>=2.3.1,<3",
# transformers<4.54 is not supported due to refactoring of the transformers library.
# transformers 4.54-4.55.2 are not compatible with kvpress due to flash attention bugs in transformers
"transformers>=4.56,<5.0.0",
"transformers>=5.0.0",
"sentencepiece>=0.2.0,<0.3",
"protobuf>=5.27.2,<6",
"datasets>=2.21.0,<3",
6 changes: 3 additions & 3 deletions tests/integration/test_ruler.py
@@ -4,7 +4,7 @@
import datasets
import pytest
import torch
-from transformers import DynamicCache, QuantoQuantizedCache
+from transformers import DynamicCache, QuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available

from kvpress import QFilterPress
@@ -44,7 +44,7 @@ def test_ruler_is_correct(
if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():
cache = QuantoQuantizedCache(config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
cache = QuantizedCache(backend="quanto", config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
elif cache == "quantized" and not is_optimum_quanto_available():
pytest.skip("Quanto is not installed")
else:
@@ -89,7 +89,7 @@ def test_ruler_is_correct_for_qfilter(
if cache == "dynamic":
cache = DynamicCache()
elif cache == "quantized" and is_optimum_quanto_available():
cache = QuantoQuantizedCache(config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
cache = QuantizedCache(backend="quanto", config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
elif cache == "quantized" and not is_optimum_quanto_available():
pytest.skip("Quanto is not installed")
else:
3 changes: 3 additions & 0 deletions tests/presses/test_key_rerotation_press_rope.py
@@ -48,6 +48,7 @@ def test_rerotate_keys_is_matches_reference_implementation(
"rope_type": "yarn",
}
cfg.max_position_embeddings = 131072
cfg.rope_theta = 500000.0
try:
unit_test_model.model.rotary_emb = LlamaRotaryEmbedding(cfg, device=unit_test_model.device)
except KeyError:
@@ -63,6 +64,8 @@
        unit_test_model = unit_test_model.cuda().half()
    elif precision == "half":
        pytest.skip("Half-precision test skipped because CUDA is not available.")
+    elif precision == "full":
+        unit_test_model = unit_test_model.float()

    original_press = RandomPressStoreIndices(compression_ratio=0.5)
    key_rerotation_press = KeyRerotationPress(press=original_press)
4 changes: 2 additions & 2 deletions tests/test_pipeline.py
@@ -6,7 +6,7 @@

import pytest
import torch
-from transformers import AutoTokenizer, DynamicCache, QuantoQuantizedCache
+from transformers import AutoTokenizer, DynamicCache, QuantizedCache
from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available

from kvpress import ExpectedAttentionPress
@@ -112,7 +112,7 @@ def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog): # noq
context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
press = ExpectedAttentionPress(compression_ratio=0.4)
cache = QuantoQuantizedCache(config=kv_press_danube_pipeline.model.config, nbits=4)
cache = QuantizedCache(backend="quanto", config=kv_press_danube_pipeline.model.config, nbits=4)
answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]

assert len(answers) == 1
12 changes: 6 additions & 6 deletions tests/test_press_call.py
@@ -28,13 +28,13 @@ def test_context_manager_applies_compression(unit_test_model): # noqa: F811

    seq_len = input_ids.shape[-1]

-    for key, values in past_key_values:
-        assert key.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
-        assert values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
+    for layer in past_key_values.layers:
+        assert layer.keys.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
+        assert layer.values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()

    input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
    past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values

-    for key, values in past_key_values:
-        assert key.shape[2] == seq_len == past_key_values.get_seq_length()
-        assert values.shape[2] == seq_len == past_key_values.get_seq_length()
+    for layer in past_key_values.layers:
+        assert layer.keys.shape[2] == seq_len == past_key_values.get_seq_length()
+        assert layer.values.shape[2] == seq_len == past_key_values.get_seq_length()
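For anyone migrating similar assertions, a minimal standalone sketch of the cache access pattern this test moves to; the checkpoint name is an assumption, while `.layers`, `.keys`, `.values`, and `get_seq_length()` are taken from the diff above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "HuggingFaceTB/SmolLM-135M"  # assumed small checkpoint, illustration only
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("KV caches now expose per-layer tensors.", return_tensors="pt")
with torch.no_grad():
    cache = model(**inputs, past_key_values=DynamicCache()).past_key_values

# Per-layer keys/values are attributes of cache.layers[i] rather than (key, value) tuples
for layer in cache.layers:
    print(layer.keys.shape, layer.values.shape)  # (batch, kv_heads, seq_len, head_dim)
print(cache.get_seq_length())  # prompt length
```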