From a5f520552d95fd0b348c9183cd51b38869a49334 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 28 Feb 2025 14:46:50 +0000 Subject: [PATCH] Fixes for strict-bytes (#10454) --- .mypy.ini | 1 + aiohttp/_websocket/reader_py.py | 4 ++-- aiohttp/_websocket/writer.py | 6 +++--- aiohttp/abc.py | 4 +++- aiohttp/compression_utils.py | 18 +++++++++++++----- aiohttp/http_writer.py | 24 ++++++++++++++++++------ aiohttp/web_response.py | 8 +++++--- aiohttp/web_ws.py | 4 +++- 8 files changed, 48 insertions(+), 21 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 26971cf2bda..2167434fff4 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -20,6 +20,7 @@ pretty = True show_column_numbers = True show_error_codes = True show_error_code_links = True +strict_bytes = True strict_equality = True warn_incomplete_stub = True warn_redundant_casts = True diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index aa4b6ba2704..711c9c8bbbe 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -309,11 +309,11 @@ def _feed_data(self, data: bytes) -> None: self.queue.feed_data(msg) elif opcode == OP_CODE_PING: self.queue.feed_data( - WSMessagePing(data=payload, size=len(payload), extra="") + WSMessagePing(data=bytes(payload), size=len(payload), extra="") ) elif opcode == OP_CODE_PONG: self.queue.feed_data( - WSMessagePong(data=payload, size=len(payload), extra="") + WSMessagePong(data=bytes(payload), size=len(payload), extra="") ) else: raise WebSocketError( diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index fc2cf32b934..ea4962e86fa 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -133,9 +133,9 @@ async def send_frame( # when aiohttp is acting as a client. Servers do not use a mask. if use_mask: mask = PACK_RANDBITS(self.get_random_bits()) - message = bytearray(message) - websocket_mask(mask, message) - self.transport.write(header + mask + message) + message_arr = bytearray(message) + websocket_mask(mask, message_arr) + self.transport.write(header + mask + message_arr) self._output_size += MASK_LEN elif msg_length > MSG_SIZE: self.transport.write(header) diff --git a/aiohttp/abc.py b/aiohttp/abc.py index 498eace04d7..7ff3fee73e8 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -202,7 +202,9 @@ class AbstractStreamWriter(ABC): length: Optional[int] = 0 @abstractmethod - async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None: + async def write( + self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ) -> None: """Write chunk into stream.""" @abstractmethod diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 7efba7600a3..43460af5a27 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -1,8 +1,16 @@ import asyncio +import sys import zlib from concurrent.futures import Executor from typing import Optional, cast +if sys.version_info >= (3, 12): + from collections.abc import Buffer +else: + from typing import Union + + Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + try: try: import brotlicffi as brotli @@ -66,10 +74,10 @@ def __init__( ) self._compress_lock = asyncio.Lock() - def compress_sync(self, data: bytes) -> bytes: + def compress_sync(self, data: Buffer) -> bytes: return self._compressor.compress(data) - async def compress(self, data: bytes) -> bytes: + async def compress(self, data: Buffer) -> bytes: """Compress the data and returned the compressed bytes. Note that flush() must be called after the last call to compress() @@ -111,10 +119,10 @@ def __init__( ) self._decompressor = zlib.decompressobj(wbits=self._mode) - def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes: + def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes: return self._decompressor.decompress(data, max_length) - async def decompress(self, data: bytes, max_length: int = 0) -> bytes: + async def decompress(self, data: Buffer, max_length: int = 0) -> bytes: """Decompress the data and return the decompressed bytes. If the data size is large than the max_sync_chunk_size, the decompression @@ -162,7 +170,7 @@ def __init__(self) -> None: ) self._obj = brotli.Decompressor() - def decompress_sync(self, data: bytes) -> bytes: + def decompress_sync(self, data: Buffer) -> bytes: if hasattr(self._obj, "decompress"): return cast(bytes, self._obj.decompress(data)) return cast(bytes, self._obj.process(data)) diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index e031a97708d..f9b0e9b2268 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -45,7 +45,12 @@ class HttpVersion(NamedTuple): HttpVersion11 = HttpVersion(1, 1) -_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] +_T_OnChunkSent = Optional[ + Callable[ + [Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]], + Awaitable[None], + ] +] _T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]] @@ -84,16 +89,23 @@ def enable_compression( ) -> None: self._compress = ZLibCompressor(encoding=encoding, strategy=strategy) - def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None: + def _write( + self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ) -> None: size = len(chunk) self.buffer_size += size self.output_size += size transport = self._protocol.transport if transport is None or transport.is_closing(): raise ClientConnectionResetError("Cannot write to closing transport") - transport.write(chunk) + transport.write(chunk) # type: ignore[arg-type] - def _writelines(self, chunks: Iterable[bytes]) -> None: + def _writelines( + self, + chunks: Iterable[ + Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ], + ) -> None: size = 0 for chunk in chunks: size += len(chunk) @@ -105,11 +117,11 @@ def _writelines(self, chunks: Iterable[bytes]) -> None: if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES: transport.write(b"".join(chunks)) else: - transport.writelines(chunks) + transport.writelines(chunks) # type: ignore[arg-type] async def write( self, - chunk: Union[bytes, bytearray, memoryview], + chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"], *, drain: bool = True, LIMIT: int = 0x10000, diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 88a1d145da0..d1bb401a5e6 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -435,7 +435,9 @@ async def _write_headers(self) -> None: status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}" await writer.write_headers(status_line, self._headers) - async def write(self, data: Union[bytes, bytearray, memoryview]) -> None: + async def write( + self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ) -> None: assert isinstance( data, (bytes, bytearray, memoryview) ), "data argument must be byte-ish (%r)" % type(data) @@ -580,7 +582,7 @@ def __init__( self._zlib_executor = zlib_executor @property - def body(self) -> Optional[Union[bytes, Payload]]: + def body(self) -> Optional[Union[bytes, bytearray, Payload]]: return self._body @body.setter @@ -654,7 +656,7 @@ async def write_eof(self, data: bytes = b"") -> None: if self._eof_sent: return if self._compressed_body is None: - body: Optional[Union[bytes, Payload]] = self._body + body = self._body else: body = self._compressed_body assert not data, f"data arg is not supported, got {data!r}" diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 813002e7871..78c130179f5 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -606,7 +606,9 @@ async def receive_json( data = await self.receive_str(timeout=timeout) return loads(data) - async def write(self, data: bytes) -> None: + async def write( + self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ) -> None: raise RuntimeError("Cannot call .write() for websocket") def __aiter__(self) -> "WebSocketResponse":