From 4872fce3426119e63e1a892c39b474786dafddac Mon Sep 17 00:00:00 2001 From: KGuillaume-chaps Date: Fri, 20 Jun 2025 00:36:51 +0200 Subject: [PATCH] Add support for ZSTD compression (#11161) --- .pre-commit-config.yaml | 4 ++++ CHANGES/11161.feature.rst | 2 ++ CONTRIBUTORS.txt | 1 + aiohttp/_http_parser.pyx | 2 +- aiohttp/client_reqrep.py | 12 +++++++++-- aiohttp/compression_utils.py | 31 ++++++++++++++++++++++++++++ aiohttp/http_parser.py | 19 +++++++++++++++--- docs/client_quickstart.rst | 4 ++++ docs/spelling_wordlist.txt | 2 ++ requirements/lint.in | 1 + requirements/lint.txt | 15 +++----------- requirements/runtime-deps.in | 1 + requirements/runtime-deps.txt | 6 +++--- setup.cfg | 1 + tests/test_client_request.py | 17 ++++++++++------ tests/test_http_parser.py | 38 +++++++++++++++++++++++++++++++++++ 16 files changed, 129 insertions(+), 27 deletions(-) create mode 100644 CHANGES/11161.feature.rst diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a0820e12283..df41c1a607d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,6 +55,10 @@ repos: rev: v1.5.0 hooks: - id: yesqa + additional_dependencies: + - flake8-docstrings==1.6.0 + - flake8-no-implicit-concat==0.3.4 + - flake8-requirements==1.7.8 - repo: https://github.com/PyCQA/isort rev: '6.0.1' hooks: diff --git a/CHANGES/11161.feature.rst b/CHANGES/11161.feature.rst new file mode 100644 index 00000000000..617c4147a38 --- /dev/null +++ b/CHANGES/11161.feature.rst @@ -0,0 +1,2 @@ +Add support for Zstandard (aka Zstd) compression +-- by :user:`KGuillaume-chaps`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index aa1f744b7a3..32c0747adf6 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -218,6 +218,7 @@ Justin Foo Justin Turner Arthur Kay Zheng Kevin Samuel +Kilian Guillaume Kimmo Parviainen-Jalanko Kirill Klenov Kirill Malovitsa diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 6bf125c06e9..f5015b297b0 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -437,7 +437,7 @@ cdef class HttpParser: if enc is not None: self._content_encoding = None enc = enc.lower() - if enc in ('gzip', 'deflate', 'br'): + if enc in ('gzip', 'deflate', 'br', 'zstd'): encoding = enc if self._cparser.type == cparser.HTTP_REQUEST: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index ff5ec9432de..e4d20f136cc 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -44,7 +44,7 @@ InvalidURL, ServerFingerprintMismatch, ) -from .compression_utils import HAS_BROTLI +from .compression_utils import HAS_BROTLI, HAS_ZSTD from .formdata import FormData from .hdrs import CONTENT_TYPE from .helpers import ( @@ -105,7 +105,15 @@ def _gen_default_accept_encoding() -> str: - return "gzip, deflate, br" if HAS_BROTLI else "gzip, deflate" + encodings = [ + "gzip", + "deflate", + ] + if HAS_BROTLI: + encodings.append("br") + if HAS_ZSTD: + encodings.append("zstd") + return ", ".join(encodings) @frozen_dataclass_decorator diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 918b764baf5..9f5562ea1cb 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -21,6 +21,18 @@ except ImportError: HAS_BROTLI = False +if sys.version_info >= (3, 14): + import compression.zstd # noqa: I900 + + HAS_ZSTD = True +else: + try: + import zstandard + + HAS_ZSTD = True + except ImportError: + HAS_ZSTD = False + MAX_SYNC_CHUNK_SIZE = 1024 @@ -276,3 +288,22 @@ def flush(self) -> bytes: if hasattr(self._obj, "flush"): return cast(bytes, self._obj.flush()) return b"" + + +class ZSTDDecompressor: + def __init__(self) -> None: + if not HAS_ZSTD: + raise RuntimeError( + "The zstd decompression is not available. " + "Please install `zstandard` module" + ) + if sys.version_info >= (3, 14): + self._obj = compression.zstd.ZstdDecompressor() + else: + self._obj = zstandard.ZstdDecompressor() + + def decompress_sync(self, data: bytes) -> bytes: + return self._obj.decompress(data) + + def flush(self) -> bytes: + return b"" diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 8805319e309..fa0328830a2 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -26,7 +26,13 @@ from . import hdrs from .base_protocol import BaseProtocol -from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor +from .compression_utils import ( + HAS_BROTLI, + HAS_ZSTD, + BrotliDecompressor, + ZLibDecompressor, + ZSTDDecompressor, +) from .helpers import ( _EXC_SENTINEL, DEBUG, @@ -527,7 +533,7 @@ def parse_headers( enc = headers.get(hdrs.CONTENT_ENCODING) if enc: enc = enc.lower() - if enc in ("gzip", "deflate", "br"): + if enc in ("gzip", "deflate", "br", "zstd"): encoding = enc # chunking @@ -930,7 +936,7 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: self.encoding = encoding self._started_decoding = False - self.decompressor: Union[BrotliDecompressor, ZLibDecompressor] + self.decompressor: Union[BrotliDecompressor, ZLibDecompressor, ZSTDDecompressor] if encoding == "br": if not HAS_BROTLI: raise ContentEncodingError( @@ -938,6 +944,13 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: "Please install `Brotli`" ) self.decompressor = BrotliDecompressor() + elif encoding == "zstd": + if not HAS_ZSTD: + raise ContentEncodingError( + "Can not decode content-encoding: zstandard (zstd). " + "Please install `zstandard`" + ) + self.decompressor = ZSTDDecompressor() else: self.decompressor = ZLibDecompressor(encoding=encoding) diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst index 95d5b6bf2c4..48f123b94bd 100644 --- a/docs/client_quickstart.rst +++ b/docs/client_quickstart.rst @@ -190,6 +190,10 @@ You can enable ``brotli`` transfer-encodings support, just install `Brotli `_ or `brotlicffi `_. +You can enable ``zstd`` transfer-encodings support, +install `zstandard `_. +If you are using Python >= 3.14, no dependency should be required. + JSON Request ============ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 79a2b5075f5..c3863fb41c5 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -382,3 +382,5 @@ www xxx yarl zlib +zstandard +zstd diff --git a/requirements/lint.in b/requirements/lint.in index 21a9fb4e0f4..e36b5f28cd6 100644 --- a/requirements/lint.in +++ b/requirements/lint.in @@ -14,3 +14,4 @@ trustme uvloop; platform_system != "Windows" valkey zlib_ng +zstandard; implementation_name == "cpython" diff --git a/requirements/lint.txt b/requirements/lint.txt index bf329ee199a..02a73831cb7 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.10 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/lint.txt --strip-extras requirements/lint.in @@ -8,8 +8,6 @@ aiodns==3.5.0 # via -r requirements/lint.in annotated-types==0.7.0 # via pydantic -async-timeout==5.0.1 - # via valkey blockbuster==1.5.24 # via -r requirements/lint.in cffi==1.17.1 @@ -25,8 +23,6 @@ cryptography==45.0.4 # via trustme distlib==0.3.9 # via virtualenv -exceptiongroup==1.3.0 - # via pytest filelock==3.18.0 # via virtualenv forbiddenfruit==0.1.4 @@ -94,21 +90,14 @@ six==1.17.0 # via python-dateutil slotscheck==0.19.1 # via -r requirements/lint.in -tomli==2.2.1 - # via - # mypy - # pytest - # slotscheck trustme==1.2.1 # via -r requirements/lint.in typing-extensions==4.14.0 # via - # exceptiongroup # mypy # pydantic # pydantic-core # python-on-whales - # rich # typing-inspection typing-inspection==0.4.1 # via pydantic @@ -120,3 +109,5 @@ virtualenv==20.31.2 # via pre-commit zlib-ng==0.5.1 # via -r requirements/lint.in +zstandard==0.23.0 ; implementation_name == "cpython" + # via -r requirements/lint.in diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index 268ace3c9c7..cf0951e7276 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -10,3 +10,4 @@ frozenlist >= 1.1.1 multidict >=4.5, < 7.0 propcache >= 0.2.0 yarl >= 1.17.0, < 2.0 +zstandard; platform_python_implementation == 'CPython' and python_version < "3.14" diff --git a/requirements/runtime-deps.txt b/requirements/runtime-deps.txt index a9e91e0bf9e..5fb79a21aa3 100644 --- a/requirements/runtime-deps.txt +++ b/requirements/runtime-deps.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.10 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements/runtime-deps.txt --strip-extras requirements/runtime-deps.in @@ -10,8 +10,6 @@ aiohappyeyeballs==2.6.1 # via -r requirements/runtime-deps.in aiosignal==1.3.2 # via -r requirements/runtime-deps.in -async-timeout==5.0.1 ; python_version < "3.11" - # via -r requirements/runtime-deps.in brotli==1.1.0 ; platform_python_implementation == "CPython" # via -r requirements/runtime-deps.in cffi==1.17.1 @@ -38,3 +36,5 @@ typing-extensions==4.14.0 # via multidict yarl==1.20.1 # via -r requirements/runtime-deps.in +zstandard==0.23.0 ; platform_python_implementation == "CPython" and python_version < "3.14" + # via -r requirements/runtime-deps.in diff --git a/setup.cfg b/setup.cfg index c4ab069f396..21a8ca2e44f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,7 @@ speedups = aiodns >= 3.3.0 Brotli; platform_python_implementation == 'CPython' brotlicffi; platform_python_implementation != 'CPython' + zstandard; platform_python_implementation == 'CPython' and python_version < "3.14" [options.packages.find] exclude = diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 23b27556ab1..1852ebb1c81 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -358,7 +358,7 @@ def test_headers(make_request: _RequestMaker) -> None: assert hdrs.CONTENT_TYPE in req.headers assert req.headers[hdrs.CONTENT_TYPE] == "text/plain" - assert req.headers[hdrs.ACCEPT_ENCODING] == "gzip, deflate, br" + assert "gzip" in req.headers[hdrs.ACCEPT_ENCODING] def test_headers_list(make_request: _RequestMaker) -> None: @@ -1568,15 +1568,20 @@ def test_loose_cookies_types(loop: asyncio.AbstractEventLoop) -> None: @pytest.mark.parametrize( - "has_brotli,expected", + "has_brotli,has_zstd,expected", [ - (False, "gzip, deflate"), - (True, "gzip, deflate, br"), + (False, False, "gzip, deflate"), + (True, False, "gzip, deflate, br"), + (False, True, "gzip, deflate, zstd"), + (True, True, "gzip, deflate, br, zstd"), ], ) -def test_gen_default_accept_encoding(has_brotli: bool, expected: str) -> None: +def test_gen_default_accept_encoding( + has_brotli: bool, has_zstd: bool, expected: str +) -> None: with mock.patch("aiohttp.client_reqrep.HAS_BROTLI", has_brotli): - assert _gen_default_accept_encoding() == expected + with mock.patch("aiohttp.client_reqrep.HAS_ZSTD", has_zstd): + assert _gen_default_accept_encoding() == expected @pytest.mark.parametrize( diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 6bb06159f21..fd7a52f0b88 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -2,6 +2,7 @@ import asyncio import re +import sys from contextlib import suppress from typing import Any, Dict, Iterable, List, Type from unittest import mock @@ -34,6 +35,14 @@ except ImportError: brotli = None +if sys.version_info >= (3, 14): + import compression.zstd as zstandard # noqa: I900 +else: + try: + import zstandard + except ImportError: + zstandard = None # type: ignore[assignment] + REQUEST_PARSERS = [HttpRequestParserPy] RESPONSE_PARSERS = [HttpResponseParserPy] @@ -600,6 +609,14 @@ def test_compression_brotli(parser: HttpRequestParser) -> None: assert msg.compression == "br" +@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") +def test_compression_zstd(parser: HttpRequestParser) -> None: + text = b"GET /test HTTP/1.1\r\ncontent-encoding: zstd\r\n\r\n" + messages, upgrade, tail = parser.feed_data(text) + msg = messages[0][0] + assert msg.compression == "zstd" + + def test_compression_unknown(parser: HttpRequestParser) -> None: text = b"GET /test HTTP/1.1\r\ncontent-encoding: compress\r\n\r\n" messages, upgrade, tail = parser.feed_data(text) @@ -1849,6 +1866,15 @@ async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None: assert b"brotli data" == out._buffer[0] assert out.is_eof() + @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") + async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None: + compressed = zstandard.compress(b"zstd data") + out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + p = HttpPayloadParser(out, length=len(compressed), compression="zstd") + p.feed_data(compressed) + assert b"zstd data" == out._buffer[0] + assert out.is_eof() + class TestDeflateBuffer: async def test_feed_data(self, protocol: BaseProtocol) -> None: @@ -1919,6 +1945,18 @@ async def test_feed_eof_no_err_brotli(self, protocol: BaseProtocol) -> None: dbuf.feed_eof() assert [b"line"] == list(buf._buffer) + @pytest.mark.skipif(zstandard is None, reason="zstandard is not installed") + async def test_feed_eof_no_err_zstandard(self, protocol: BaseProtocol) -> None: + buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) + dbuf = DeflateBuffer(buf, "zstd") + + dbuf.decompressor = mock.Mock() + dbuf.decompressor.flush.return_value = b"line" + dbuf.decompressor.eof = False + + dbuf.feed_eof() + assert [b"line"] == list(buf._buffer) + async def test_empty_body(self, protocol: BaseProtocol) -> None: buf = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop()) dbuf = DeflateBuffer(buf, "deflate")