From 77ad7d7ea173eda1306297d275b2d5f7348f9f60 Mon Sep 17 00:00:00 2001 From: layday Date: Sun, 30 Mar 2025 22:31:24 +0200 Subject: [PATCH 1/5] Replace deprecated `asyncio.iscoroutinefunction` with its counterpart from `inspect` (#10634) --- CHANGES/10634.bugfix.rst | 2 ++ aiohttp/pytest_plugin.py | 6 +++--- aiohttp/web_urldispatcher.py | 9 ++++++--- aiohttp/worker.py | 5 ++++- 4 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 CHANGES/10634.bugfix.rst diff --git a/CHANGES/10634.bugfix.rst b/CHANGES/10634.bugfix.rst new file mode 100644 index 00000000000..d6ec64a607e --- /dev/null +++ b/CHANGES/10634.bugfix.rst @@ -0,0 +1,2 @@ +Replaced deprecated ``asyncio.iscoroutinefunction`` with its counterpart from ``inspect`` +-- by :user:`layday`. diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index dcac5296908..9d11231c6f4 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -106,7 +106,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def] if inspect.isasyncgenfunction(func): # async generator fixture is_async_gen = True - elif asyncio.iscoroutinefunction(func): + elif inspect.iscoroutinefunction(func): # regular async fixture is_async_gen = False else: @@ -216,14 +216,14 @@ def _passthrough_loop_context( def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def] """Fix pytest collecting for coroutines.""" - if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj): + if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj): return list(collector._genfunctions(name, obj)) def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def] """Run coroutines in an event loop instead of a normal function call.""" fast = pyfuncitem.config.getoption("--aiohttp-fast") - if asyncio.iscoroutinefunction(pyfuncitem.function): + if inspect.iscoroutinefunction(pyfuncitem.function): existing_loop = pyfuncitem.funcargs.get( "proactor_loop" ) or pyfuncitem.funcargs.get("loop", None) diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 0d3cbb26398..0ed199f7e59 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -4,6 +4,7 @@ import functools import hashlib import html +import inspect import keyword import os import re @@ -174,15 +175,17 @@ def __init__( if expect_handler is None: expect_handler = _default_expect_handler - assert asyncio.iscoroutinefunction( - expect_handler + assert inspect.iscoroutinefunction(expect_handler) or ( + sys.version_info < (3, 14) and asyncio.iscoroutinefunction(expect_handler) ), f"Coroutine is expected, got {expect_handler!r}" method = method.upper() if not HTTP_METHOD_RE.match(method): raise ValueError(f"{method} is not allowed HTTP method") - if asyncio.iscoroutinefunction(handler): + if inspect.iscoroutinefunction(handler) or ( + sys.version_info < (3, 14) and asyncio.iscoroutinefunction(handler) + ): pass elif isinstance(handler, type) and issubclass(handler, AbstractView): pass diff --git a/aiohttp/worker.py b/aiohttp/worker.py index cc100d59faa..d4f062062a5 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -1,6 +1,7 @@ """Async gunicorn worker for aiohttp.web""" import asyncio +import inspect import os import re import signal @@ -68,7 +69,9 @@ async def _run(self) -> None: runner = None if isinstance(self.wsgi, Application): app = self.wsgi - elif asyncio.iscoroutinefunction(self.wsgi): + elif inspect.iscoroutinefunction(self.wsgi) or ( + sys.version_info < (3, 14) and asyncio.iscoroutinefunction(self.wsgi) + ): wsgi = await self.wsgi() if isinstance(wsgi, web.AppRunner): runner = wsgi From 4599b87f44569079942542c99c46779ca6e8bef7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 30 Mar 2025 11:03:47 -1000 Subject: [PATCH 2/5] Improve performance of serializing headers (#10625) Improve performance of serializing headers by moving the check for `\r` and `\n` into the write loop instead of making a separate call to check each disallowed character in the Python string. --- CHANGES/10625.misc.rst | 1 + aiohttp/_http_writer.pyx | 38 ++++++++++++++++++-------------------- tests/test_http_writer.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 21 deletions(-) create mode 100644 CHANGES/10625.misc.rst diff --git a/CHANGES/10625.misc.rst b/CHANGES/10625.misc.rst new file mode 100644 index 00000000000..30cd7f0f3a6 --- /dev/null +++ b/CHANGES/10625.misc.rst @@ -0,0 +1 @@ +Improved performance of serializing headers -- by :user:`bdraco`. diff --git a/aiohttp/_http_writer.pyx b/aiohttp/_http_writer.pyx index 287371334f8..4a3ae1f9e68 100644 --- a/aiohttp/_http_writer.pyx +++ b/aiohttp/_http_writer.pyx @@ -97,27 +97,34 @@ cdef inline int _write_str(Writer* writer, str s): return -1 -# --------------- _serialize_headers ---------------------- - -cdef str to_str(object s): +cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s): + cdef Py_UCS4 ch + cdef str out_str if type(s) is str: - return s + out_str = s elif type(s) is _istr: - return PyObject_Str(s) + out_str = PyObject_Str(s) elif not isinstance(s, str): raise TypeError("Cannot serialize non-str key {!r}".format(s)) else: - return str(s) + out_str = str(s) + + for ch in out_str: + if ch == 0x0D or ch == 0x0A: + raise ValueError( + "Newline or carriage return detected in headers. " + "Potential header injection attack." + ) + if _write_utf8(writer, ch) < 0: + return -1 +# --------------- _serialize_headers ---------------------- def _serialize_headers(str status_line, headers): cdef Writer writer cdef object key cdef object val - cdef bytes ret - cdef str key_str - cdef str val_str _init_writer(&writer) @@ -130,22 +137,13 @@ def _serialize_headers(str status_line, headers): raise for key, val in headers.items(): - key_str = to_str(key) - val_str = to_str(val) - - if "\r" in key_str or "\n" in key_str or "\r" in val_str or "\n" in val_str: - raise ValueError( - "Newline or carriage return character detected in HTTP status message or " - "header. This is a potential security issue." - ) - - if _write_str(&writer, key_str) < 0: + if _write_str_raise_on_nlcr(&writer, key) < 0: raise if _write_byte(&writer, b':') < 0: raise if _write_byte(&writer, b' ') < 0: raise - if _write_str(&writer, val_str) < 0: + if _write_str_raise_on_nlcr(&writer, val) < 0: raise if _write_byte(&writer, b'\r') < 0: raise diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 71a5bb93d13..7032e1417b5 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -8,8 +8,9 @@ import pytest from multidict import CIMultiDict -from aiohttp import ClientConnectionResetError, http +from aiohttp import ClientConnectionResetError, hdrs, http from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_writer import _serialize_headers from aiohttp.test_utils import make_mocked_coro @@ -603,3 +604,29 @@ async def test_set_eof_after_write_headers( msg.set_eof() await msg.write_eof() assert not transport.write.called + + +@pytest.mark.parametrize( + "char", + [ + "\n", + "\r", + ], +) +def test_serialize_headers_raises_on_new_line_or_carriage_return(char: str) -> None: + """Verify serialize_headers raises on cr or nl in the headers.""" + status_line = "HTTP/1.1 200 OK" + headers = CIMultiDict( + { + hdrs.CONTENT_TYPE: f"text/plain{char}", + } + ) + + with pytest.raises( + ValueError, + match=( + "Newline or carriage return detected in headers. " + "Potential header injection attack." + ), + ): + _serialize_headers(status_line, headers) From 8ac483068ea24f6a709b3ead51ec87e3660a3b24 Mon Sep 17 00:00:00 2001 From: Tim Menninger Date: Sun, 30 Mar 2025 14:05:09 -0700 Subject: [PATCH 3/5] Docs fixups following implement socket factory (#10534) (#10568) --- CHANGES/10474.feature.rst | 1 + docs/client_advanced.rst | 4 ++-- docs/client_reference.rst | 18 +++++++----------- docs/conf.py | 3 ++- 4 files changed, 12 insertions(+), 14 deletions(-) create mode 120000 CHANGES/10474.feature.rst diff --git a/CHANGES/10474.feature.rst b/CHANGES/10474.feature.rst new file mode 120000 index 00000000000..7c4f9a7b83b --- /dev/null +++ b/CHANGES/10474.feature.rst @@ -0,0 +1 @@ +10520.feature.rst \ No newline at end of file diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 4b0a878d715..0f6eb99974b 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -472,8 +472,8 @@ Custom socket creation ^^^^^^^^^^^^^^^^^^^^^^ If the default socket is insufficient for your use case, pass an optional -`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements -`SocketFactoryType`. This will be used to create all sockets for the +``socket_factory`` to the :class:`~aiohttp.TCPConnector`, which implements +:class:`SocketFactoryType`. This will be used to create all sockets for the lifetime of the class object. For example, we may want to change the conditions under which we consider a connection dead. The following would make all sockets respect 9*7200 = 18 hours:: diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 7dabfe1a6db..147ec7dce2d 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1122,9 +1122,7 @@ is controlled by *force_close* constructor's parameter). overridden in subclasses. -.. autodata:: AddrInfoType - -.. note:: +.. py:class:: AddrInfoType Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info. @@ -1132,13 +1130,11 @@ is controlled by *force_close* constructor's parameter). Be sure to use ``aiohttp.AddrInfoType`` rather than ``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as - it is likely to be removed from ``aiohappyeyeballs`` in the + it is likely to be removed from :mod:`aiohappyeyeballs` in the future. -.. autodata:: SocketFactoryType - -.. note:: +.. py:class:: SocketFactoryType Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info. @@ -1146,7 +1142,7 @@ is controlled by *force_close* constructor's parameter). Be sure to use ``aiohttp.SocketFactoryType`` rather than ``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage, - as it is likely to be removed from ``aiohappyeyeballs`` in the + as it is likely to be removed from :mod:`aiohappyeyeballs` in the future. @@ -1278,9 +1274,9 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.10 - :param :py:data:``SocketFactoryType`` socket_factory: This function takes an - :py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when - creating TCP connections. + :param SocketFactoryType socket_factory: This function takes an + :py:data:`AddrInfoType` and is used in lieu of + :py:func:`socket.socket` when creating TCP connections. .. versionadded:: 3.12 diff --git a/docs/conf.py b/docs/conf.py index eba93188b44..15de3598c7e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,7 +83,7 @@ "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), - "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None), + "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None), } # Add any paths that contain templates here, relative to this directory. @@ -419,6 +419,7 @@ ("py:class", "aiohttp.web.MatchedSubAppResource"), # undocumented ("py:attr", "body"), # undocumented ("py:class", "socket.socket"), # undocumented + ("py:func", "socket.socket"), # undocumented ("py:class", "socket.AddressFamily"), # undocumented ("py:obj", "logging.DEBUG"), # undocumented ("py:class", "aiohttp.abc.AbstractAsyncAccessLogger"), # undocumented From f7cac7e63f18691e4261af353e84f9073b16624a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 30 Mar 2025 11:11:10 -1000 Subject: [PATCH 4/5] Reduce WebSocket buffer slicing overhead (#10601) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What do these changes do? Use a `const unsigned char *` for the buffer (Cython will automatically extract is using `__Pyx_PyBytes_AsUString`) as its a lot faster than copying around `PyBytes` objects. We do need to be careful that all slices are bounded and we bound check everything to make sure we do not do an out of bounds read since Cython does not bounds check C strings. I checked that all accesses to `buf_cstr` are proceeded by a bounds check but it would be good to get another set of eyes on that to verify in the `self._state == READ_PAYLOAD` block that we will never try to read out of bounds. Screenshot 2025-03-19 at 10 21 54 AM ## Are there changes in behavior for the user? performance improvement ## Is it a substantial burden for the maintainers to support this? no There is a small risk that someone could remove a bounds check in the future and create a memory safety issue, however in this case its likely we would already be trying to read data that wasn't there if we are missing the bounds checking so the pure python version would throw if we are testing properly. --------- Co-authored-by: Sam Bull --- CHANGES/10601.misc.rst | 1 + aiohttp/_websocket/reader_c.pxd | 1 + aiohttp/_websocket/reader_py.py | 20 +++++++++++--------- 3 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 CHANGES/10601.misc.rst diff --git a/CHANGES/10601.misc.rst b/CHANGES/10601.misc.rst new file mode 100644 index 00000000000..c0d21082724 --- /dev/null +++ b/CHANGES/10601.misc.rst @@ -0,0 +1 @@ +Improved performance of WebSocket buffer handling -- by :user:`bdraco`. diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 02b4efa6557..07a7d979553 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -99,6 +99,7 @@ cdef class WebSocketReader: chunk_size="unsigned int", chunk_len="unsigned int", buf_length="unsigned int", + buf_cstr="const unsigned char *", first_byte="unsigned char", second_byte="unsigned char", end_pos="unsigned int", diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 201b5e84de2..8e7c7972995 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -333,14 +333,15 @@ def parse_frame( start_pos: int = 0 buf_length = len(buf) + buf_cstr = buf while True: # read header if self._state == READ_HEADER: if buf_length - start_pos < 2: break - first_byte = buf[start_pos] - second_byte = buf[start_pos + 1] + first_byte = buf_cstr[start_pos] + second_byte = buf_cstr[start_pos + 1] start_pos += 2 fin = (first_byte >> 7) & 1 @@ -405,14 +406,14 @@ def parse_frame( if length_flag == 126: if buf_length - start_pos < 2: break - first_byte = buf[start_pos] - second_byte = buf[start_pos + 1] + first_byte = buf_cstr[start_pos] + second_byte = buf_cstr[start_pos + 1] start_pos += 2 self._payload_length = first_byte << 8 | second_byte elif length_flag > 126: if buf_length - start_pos < 8: break - data = buf[start_pos : start_pos + 8] + data = buf_cstr[start_pos : start_pos + 8] start_pos += 8 self._payload_length = UNPACK_LEN3(data)[0] else: @@ -424,7 +425,7 @@ def parse_frame( if self._state == READ_PAYLOAD_MASK: if buf_length - start_pos < 4: break - self._frame_mask = buf[start_pos : start_pos + 4] + self._frame_mask = buf_cstr[start_pos : start_pos + 4] start_pos += 4 self._state = READ_PAYLOAD @@ -440,10 +441,10 @@ def parse_frame( if self._frame_payload_len: if type(self._frame_payload) is not bytearray: self._frame_payload = bytearray(self._frame_payload) - self._frame_payload += buf[start_pos:end_pos] + self._frame_payload += buf_cstr[start_pos:end_pos] else: # Fast path for the first frame - self._frame_payload = buf[start_pos:end_pos] + self._frame_payload = buf_cstr[start_pos:end_pos] self._frame_payload_len += end_pos - start_pos start_pos = end_pos @@ -469,6 +470,7 @@ def parse_frame( self._frame_payload_len = 0 self._state = READ_HEADER - self._tail = buf[start_pos:] if start_pos < buf_length else b"" + # XXX: Cython needs slices to be bounded, so we can't omit the slice end here. + self._tail = buf_cstr[start_pos:buf_length] if start_pos < buf_length else b"" return frames From caa5792a55e6a380cbb27d907d7d09e8785b7312 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 30 Mar 2025 11:57:38 -1000 Subject: [PATCH 5/5] Convert format calls to f-strings in WebSocket reader (#10638) Small code cleanup --- aiohttp/_websocket/reader_py.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 8e7c7972995..52d6b83925f 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -199,9 +199,8 @@ def _feed_data(self, data: bytes) -> None: if self._max_msg_size and len(self._partial) >= self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), + f"Message size {len(self._partial)} " + f"exceeds limit {self._max_msg_size}", ) continue @@ -220,7 +219,7 @@ def _feed_data(self, data: bytes) -> None: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, "The opcode in non-fin frame is expected " - "to be zero, got {!r}".format(opcode), + f"to be zero, got {opcode!r}", ) assembled_payload: Union[bytes, bytearray] @@ -233,9 +232,8 @@ def _feed_data(self, data: bytes) -> None: if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(assembled_payload), self._max_msg_size - ), + f"Message size {len(assembled_payload)} " + f"exceeds limit {self._max_msg_size}", ) # Decompress process must to be done after all packets @@ -252,9 +250,8 @@ def _feed_data(self, data: bytes) -> None: left = len(self._decompressobj.unconsumed_tail) raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - "Decompressed message size {} exceeds limit {}".format( - self._max_msg_size + left, self._max_msg_size - ), + f"Decompressed message size {self._max_msg_size + left}" + f" exceeds limit {self._max_msg_size}", ) elif type(assembled_payload) is bytes: payload_merged = assembled_payload