diff --git a/.flake8 b/.flake8 index d9ef69c4..456a8107 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -exclude = .venv,venv,.git,__pycache__,build,dist, .mypy_cache +exclude = .venv,venv,.git,__pycache__,build,dist,.mypy_cache,.pytest_cache max-line-length = 120 per-file-ignores = __init__.py:F401 diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 6a67207a..b8819c88 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -18,12 +18,13 @@ jobs: uses: actions/setup-python@v3 with: python-version: 3.10.11 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build - - name: Build package - run: python -m build + + - name: Install uv + uses: astral-sh/setup-uv@v6 + + - name: Build + run: uv build --no-sources + - name: Publish package uses: pypa/gh-action-pypi-publish@release/v1 with: diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 19f6b008..8ba50a9c 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -15,13 +15,13 @@ jobs: with: python-version: 3.10.11 - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH # Add Poetry to the PATH + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true - name: Install dependencies - run: poetry install --with dev + run: uv sync --all-groups - name: Run style checks run: make style diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 07eafff1..289fe312 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,7 +2,7 @@ name: Test on: push: - branches: [ main ] + branches: [ main ] pull_request: jobs: @@ -15,12 +15,12 @@ jobs: with: python-version: 3.10.11 - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.local/bin" >> $GITHUB_PATH # Add Poetry to 
the PATH + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true - name: Install dependencies - run: poetry install --with dev + run: uv sync --all-groups - run: make test diff --git a/.gitignore b/.gitignore index c3c82406..0d05a95d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ dev_notebooks/ results/ reports/ .DS_Store -poetry.lock +uv.lock *.parquet # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/Makefile b/Makefile index 9015c401..fbd14aa9 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ SHELL := /bin/bash -POETRY ?= $(shell which poetry) +UV ?= $(shell which uv) BUILD_VERSION:=$(APP_VERSION) TESTS_FILTER:= @@ -7,11 +7,11 @@ PYTEST_LOG=--log-cli-level=debug --log-format="%(asctime)s %(levelname)s [%(name .PHONY: isort isort: - $(POETRY) run isort . + $(UV) run isort . .PHONY: black black: - $(POETRY) run black . + $(UV) run black . PHONY: format format: isort black @@ -24,10 +24,10 @@ style: reports @echo -n > reports/copyright_errors.log @echo - -$(POETRY) run flake8 | tee -a reports/flake8_errors.log + -$(UV) run flake8 | tee -a reports/flake8_errors.log @if [ -s reports/flake8_errors.log ]; then exit 1; fi - -$(POETRY) run mypy . --check-untyped-defs | tee -a reports/mypy.log + -$(UV) run mypy . --check-untyped-defs | tee -a reports/mypy.log @if ! grep -Eq "Success: no issues found in [0-9]+ source files" reports/mypy.log ; then exit 1; fi @echo "Checking for SPDX-FileCopyrightText headers in Python files..." @@ -42,7 +42,7 @@ reports: .PHONY: test test: reports PYTHONPATH=. 
\ - $(POETRY) run pytest \ + $(UV) run pytest \ --cov-report xml:reports/coverage.xml \ --cov=kvpress/ \ --junitxml=./reports/junit.xml \ diff --git a/README.md b/README.md index 8b1c99f8..4e412ae5 100644 --- a/README.md +++ b/README.md @@ -16,22 +16,17 @@ Deploying long-context LLMs is costly due to the linear growth of the key-value pip install kvpress ``` -If possible, install flash attention: -```bash -pip install flash-attn --no-build-isolation -``` - -For a local installation with all dev dependencies, use poetry: +For a local installation with all dev dependencies, use uv: ```bash git clone https://github.com/NVIDIA/kvpress.git cd kvpress -poetry install --with dev +uv sync --all-groups ``` ## Usage -kvpress provides a set of "presses" that compress the KV cache during the prefilling-phase. Each press is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you: +KVPress provides a set of "presses" that compress the KV cache during the prefilling-phase. Each press is associated with a `compression_ratio` attribute that measures the compression of the cache. The easiest way to use a press is through our custom `KVPressTextGenerationPipeline`. It is automatically registered as a transformers pipeline with the name "kv-press-text-generation" when kvpress is imported and handles chat templates and tokenization for you: ```python from transformers import pipeline @@ -208,4 +203,25 @@ with press(model): However, the `generate` method does not allow to exclude the question from the compression, which would artificially favors methods such as SnapKV. 
Ideally, we want a compression method that works whatever comes after the context (_e.g._ for use cases such as chat or document question answering). Finally the `generate` method does not allow to provide generation for multiple questions at once. - \ No newline at end of file + + + +## Advanced installation settings +To install optional packages, you can use [uv](https://docs.astral.sh/uv/). +To install with flash attention, just run: + +```bash +git clone https://github.com/NVIDIA/kvpress.git +cd kvpress +uv sync --extra flash-attn +``` + +To install with dependencies for evaluation, run + +```bash +git clone https://github.com/NVIDIA/kvpress.git +cd kvpress +uv sync --extra eval +``` + +Note that optional dependencies can be combined. \ No newline at end of file diff --git a/evaluation/README.md b/evaluation/README.md index d5ba00e1..bd78f400 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -5,6 +5,7 @@ We support evaluation for all the presses implemented in the library, on a variety of popular benchmarks. ### Quick Start 🚀 +> Evaluation requires some additional packages. You can install them with `uv sync --extra eval` Running evaluation is straightforward! 
Make sure you are in the `evaluation` directory, then: diff --git a/kvpress/pipeline.py b/kvpress/pipeline.py index c4da1e4a..606f8575 100644 --- a/kvpress/pipeline.py +++ b/kvpress/pipeline.py @@ -135,7 +135,10 @@ def preprocess( else: separator = "\n" + "#" * len(context) context = self.tokenizer.apply_chat_template( - [{"role": "user", "content": context + separator}], add_generation_prompt=True, tokenize=False + [{"role": "user", "content": context + separator}], + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, ) context, question_suffix = context.split(separator) diff --git a/kvpress/presses/block_press.py b/kvpress/presses/block_press.py index 6ae50ca5..6d8da788 100644 --- a/kvpress/presses/block_press.py +++ b/kvpress/presses/block_press.py @@ -16,10 +16,11 @@ class BlockPress(BasePress): BlockPress: Block-wise iterative KV cache compression. Applies compression in fixed-size blocks. Iteratively scores and prunes tokens block by block, maintaining - a buffer of previously kept tokens for context. Mathematically equivalent - to global compression when scoring uses only local information. + a buffer of previously kept tokens for context. Mathematically equivalent to global compression when + scoring uses only local information. It was introduced in the KeyDiff paper as part of the KeyDiff press, + but it can also work as a standalone press. - Based on BlockPress (https://arxiv.org/abs/2504.15364). + Based on the KeyDiff paper (https://arxiv.org/abs/2504.15364). Parameters ---------- diff --git a/kvpress/presses/keydiff_press.py b/kvpress/presses/keydiff_press.py index 3932eea9..b2bba83f 100644 --- a/kvpress/presses/keydiff_press.py +++ b/kvpress/presses/keydiff_press.py @@ -21,6 +21,12 @@ class KeyDiffPress(ScorerPress): Based on KeyDiff (https://arxiv.org/abs/2504.15364). + Note: The original press in the KeyDiff paper implements a block-wise iterative compression. 
+ In KVPress, the iterative compression is implemented in the BlockPress class. + Therefore, to replicate the paper's implementation, please use: + + `press = BlockPress(press=KeyDiffPress(compression_ratio=0.x), block_size=N)` + Parameters ---------- compression_ratio : float, default=0.0 diff --git a/kvpress/presses/kvzip_press.py b/kvpress/presses/kvzip_press.py index f41f7478..0c68fd79 100644 --- a/kvpress/presses/kvzip_press.py +++ b/kvpress/presses/kvzip_press.py @@ -101,7 +101,10 @@ def __call__(self, model: PreTrainedModel) -> Generator: dummy_context = "dummy context" separator = "\n" + "#" * len(dummy_context) temp_context = tokenizer.apply_chat_template( - [{"role": "user", "content": dummy_context + separator}], add_generation_prompt=True, tokenize=False + [{"role": "user", "content": dummy_context + separator}], + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, ) context, suffix_text = temp_context.split(separator) prefix_text = context.split(dummy_context)[0] diff --git a/pyproject.toml b/pyproject.toml index 31f9e909..bf9e226c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,50 +1,66 @@ -[tool.poetry] +[project] name = "kvpress" -authors = ["Simon Jegou", "Maximilian Jeblick", "Alessio Devoto", "Jiwei Liu", "David Austin"] +version = "0.2.10" description = "Efficiently compress the KV cache of any pretrained transformer" -version = "0.2.9" +authors = [ + { name = "Simon Jegou" }, + { name = "Maximilian Jeblick" }, + { name = "Alessio Devoto" }, + { name = "Jiwei Liu" }, + { name = "David Austin" }, +] +requires-python = ">=3.10" readme = "README.md" +dependencies = [ + "numpy>=2.0.0,<3", + "torch>=2.3.1,<3", + "transformers>=4.48.0, <4.54.0", + "sentencepiece>=0.2.0,<0.3", + "protobuf>=5.27.2,<6", + "datasets>=2.21.0,<3", + "pandas>=2.2.2,<3", + "accelerate>=1.0.0,<2", + "requests>=2.32.3,<3", + "cachetools>=5.5.2,<6", +] -[tool.poetry.dependencies] -python = ">=3.10" -numpy = "^2.0.0" -torch = "^2.3.1" -transformers = 
">=4.48.0, <4.54.0" -sentencepiece = "^0.2.0" -protobuf = "^5.27.2" -datasets = "^2.21.0" -pandas = "^2.2.2" -accelerate = "^1.0.0" -requests = "^2.32.3" -cachetools = "^5.5.2" +[project.optional-dependencies] +eval = [ + "rouge>=1.0.1,<2", + "nltk>=3.9.1,<4", + "tqdm>=4.66.4,<5", + "scipy>=1.13.1,<2", + "fire>=0.6.0,<0.7", + "bert-score>=0.3.13,<0.4", +] +flash-attn = [ + "flash-attn" +] -[tool.poetry.group.dev] -optional = true +[dependency-groups] +dev = [ + "pytest>=7.0.0,<8", + "flake8>=7.0.0,<8", + "isort>=5.13.2,<6", + "black>=24.8.0,<25", + "mypy>=1.13.0,<2", + "pytest-cov>=5.0.0,<6", + "pytest-dependency>=0.6.0,<0.7", + "pytest-html>=4.1.1, <5.0.0", + "types-pyyaml~=6.0", + "ipykernel>=6.29.4,<7", + "bs4>=0.0.2,<0.0.3", + "nvitop>=1.3.2,<2", + "matplotlib>=3.9.0,<4", +] -[tool.poetry.group.dev.dependencies] -pytest = "^7.0.0" -flake8 = "^7.0.0" -isort = "^5.13.2" -black = "^24.8.0" -mypy = "^1.13.0" -pytest-cov = "^5.0.0" -pytest-dependency = "^0.6.0" -pytest-html = ">=4.1.1, <5.0.0" -types-pyyaml = "^6.0" -ipykernel = "^6.29.4" -bs4 = "^0.0.2" -nvitop = "^1.3.2" -bert-score = "^0.3.13" -rouge = "^1.0.1" -nltk = "^3.9.1" -tqdm = "^4.66.4" -scipy = "^1.13.1" -matplotlib = "^3.9.0" -fire = "^0.6.0" [build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.uv] +no-build-isolation-package = ["flash-attn"] [tool.black] line-length = 120 @@ -64,7 +80,7 @@ skip = ["venv", ".venv"] ignore_missing_imports = true allow_redefinition = true strict_optional = false -exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles)" +exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|venv|.venv|doc-venv|.svn|_build|buck-out|build|dist|notebooks|tools|tmp|tests|bundles|.pytest_cache|reports)" disable_error_code = ["union-attr", "operator", "call-overload", "arg-type"] [[tool.mypy.overrides]]