diff --git a/CHANGES/10474.feature.rst b/CHANGES/10474.feature.rst new file mode 120000 index 00000000000..7c4f9a7b83b --- /dev/null +++ b/CHANGES/10474.feature.rst @@ -0,0 +1 @@ +10520.feature.rst \ No newline at end of file diff --git a/CHANGES/10601.misc.rst b/CHANGES/10601.misc.rst new file mode 100644 index 00000000000..c0d21082724 --- /dev/null +++ b/CHANGES/10601.misc.rst @@ -0,0 +1 @@ +Improved performance of WebSocket buffer handling -- by :user:`bdraco`. diff --git a/CHANGES/10625.misc.rst b/CHANGES/10625.misc.rst new file mode 100644 index 00000000000..30cd7f0f3a6 --- /dev/null +++ b/CHANGES/10625.misc.rst @@ -0,0 +1 @@ +Improved performance of serializing headers -- by :user:`bdraco`. diff --git a/CHANGES/10634.bugfix.rst b/CHANGES/10634.bugfix.rst new file mode 100644 index 00000000000..d6ec64a607e --- /dev/null +++ b/CHANGES/10634.bugfix.rst @@ -0,0 +1,2 @@ +Replaced deprecated ``asyncio.iscoroutinefunction`` with its counterpart from ``inspect`` +-- by :user:`layday`. diff --git a/aiohttp/_http_writer.pyx b/aiohttp/_http_writer.pyx index 287371334f8..4a3ae1f9e68 100644 --- a/aiohttp/_http_writer.pyx +++ b/aiohttp/_http_writer.pyx @@ -97,27 +97,34 @@ cdef inline int _write_str(Writer* writer, str s): return -1 -# --------------- _serialize_headers ---------------------- - -cdef str to_str(object s): +cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s): + cdef Py_UCS4 ch + cdef str out_str if type(s) is str: - return s + out_str = s elif type(s) is _istr: - return PyObject_Str(s) + out_str = PyObject_Str(s) elif not isinstance(s, str): raise TypeError("Cannot serialize non-str key {!r}".format(s)) else: - return str(s) + out_str = str(s) + + for ch in out_str: + if ch == 0x0D or ch == 0x0A: + raise ValueError( + "Newline or carriage return detected in headers. " + "Potential header injection attack." + ) + if _write_utf8(writer, ch) < 0: + return -1 +# --------------- _serialize_headers ---------------------- def _serialize_headers(str status_line, headers): cdef Writer writer cdef object key cdef object val - cdef bytes ret - cdef str key_str - cdef str val_str _init_writer(&writer) @@ -130,22 +137,13 @@ def _serialize_headers(str status_line, headers): raise for key, val in headers.items(): - key_str = to_str(key) - val_str = to_str(val) - - if "\r" in key_str or "\n" in key_str or "\r" in val_str or "\n" in val_str: - raise ValueError( - "Newline or carriage return character detected in HTTP status message or " - "header. This is a potential security issue." - ) - - if _write_str(&writer, key_str) < 0: + if _write_str_raise_on_nlcr(&writer, key) < 0: raise if _write_byte(&writer, b':') < 0: raise if _write_byte(&writer, b' ') < 0: raise - if _write_str(&writer, val_str) < 0: + if _write_str_raise_on_nlcr(&writer, val) < 0: raise if _write_byte(&writer, b'\r') < 0: raise diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 02b4efa6557..07a7d979553 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -99,6 +99,7 @@ cdef class WebSocketReader: chunk_size="unsigned int", chunk_len="unsigned int", buf_length="unsigned int", + buf_cstr="const unsigned char *", first_byte="unsigned char", second_byte="unsigned char", end_pos="unsigned int", diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 201b5e84de2..52d6b83925f 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -199,9 +199,8 @@ def _feed_data(self, data: bytes) -> None: if self._max_msg_size and len(self._partial) >= self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), + f"Message size {len(self._partial)} " + f"exceeds limit {self._max_msg_size}", ) continue @@ -220,7 +219,7 @@ def _feed_data(self, data: bytes) -> None: raise WebSocketError( WSCloseCode.PROTOCOL_ERROR, "The opcode in non-fin frame is expected " - "to be zero, got {!r}".format(opcode), + f"to be zero, got {opcode!r}", ) assembled_payload: Union[bytes, bytearray] @@ -233,9 +232,8 @@ def _feed_data(self, data: bytes) -> None: if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(assembled_payload), self._max_msg_size - ), + f"Message size {len(assembled_payload)} " + f"exceeds limit {self._max_msg_size}", ) # Decompress process must to be done after all packets @@ -252,9 +250,8 @@ def _feed_data(self, data: bytes) -> None: left = len(self._decompressobj.unconsumed_tail) raise WebSocketError( WSCloseCode.MESSAGE_TOO_BIG, - "Decompressed message size {} exceeds limit {}".format( - self._max_msg_size + left, self._max_msg_size - ), + f"Decompressed message size {self._max_msg_size + left}" + f" exceeds limit {self._max_msg_size}", ) elif type(assembled_payload) is bytes: payload_merged = assembled_payload @@ -333,14 +330,15 @@ def parse_frame( start_pos: int = 0 buf_length = len(buf) + buf_cstr = buf while True: # read header if self._state == READ_HEADER: if buf_length - start_pos < 2: break - first_byte = buf[start_pos] - second_byte = buf[start_pos + 1] + first_byte = buf_cstr[start_pos] + second_byte = buf_cstr[start_pos + 1] start_pos += 2 fin = (first_byte >> 7) & 1 @@ -405,14 +403,14 @@ def parse_frame( if length_flag == 126: if buf_length - start_pos < 2: break - first_byte = buf[start_pos] - second_byte = buf[start_pos + 1] + first_byte = buf_cstr[start_pos] + second_byte = buf_cstr[start_pos + 1] start_pos += 2 self._payload_length = first_byte << 8 | second_byte elif length_flag > 126: if buf_length - start_pos < 8: break - data = buf[start_pos : start_pos + 8] + data = buf_cstr[start_pos : start_pos + 8] start_pos += 8 self._payload_length = UNPACK_LEN3(data)[0] else: @@ -424,7 +422,7 @@ def parse_frame( if self._state == READ_PAYLOAD_MASK: if buf_length - start_pos < 4: break - self._frame_mask = buf[start_pos : start_pos + 4] + self._frame_mask = buf_cstr[start_pos : start_pos + 4] start_pos += 4 self._state = READ_PAYLOAD @@ -440,10 +438,10 @@ def parse_frame( if self._frame_payload_len: if type(self._frame_payload) is not bytearray: self._frame_payload = bytearray(self._frame_payload) - self._frame_payload += buf[start_pos:end_pos] + self._frame_payload += buf_cstr[start_pos:end_pos] else: # Fast path for the first frame - self._frame_payload = buf[start_pos:end_pos] + self._frame_payload = buf_cstr[start_pos:end_pos] self._frame_payload_len += end_pos - start_pos start_pos = end_pos @@ -469,6 +467,7 @@ def parse_frame( self._frame_payload_len = 0 self._state = READ_HEADER - self._tail = buf[start_pos:] if start_pos < buf_length else b"" + # XXX: Cython needs slices to be bounded, so we can't omit the slice end here. + self._tail = buf_cstr[start_pos:buf_length] if start_pos < buf_length else b"" return frames diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index dcac5296908..9d11231c6f4 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -106,7 +106,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def] if inspect.isasyncgenfunction(func): # async generator fixture is_async_gen = True - elif asyncio.iscoroutinefunction(func): + elif inspect.iscoroutinefunction(func): # regular async fixture is_async_gen = False else: @@ -216,14 +216,14 @@ def _passthrough_loop_context( def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def] """Fix pytest collecting for coroutines.""" - if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj): + if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj): return list(collector._genfunctions(name, obj)) def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def] """Run coroutines in an event loop instead of a normal function call.""" fast = pyfuncitem.config.getoption("--aiohttp-fast") - if asyncio.iscoroutinefunction(pyfuncitem.function): + if inspect.iscoroutinefunction(pyfuncitem.function): existing_loop = pyfuncitem.funcargs.get( "proactor_loop" ) or pyfuncitem.funcargs.get("loop", None) diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 0d3cbb26398..0ed199f7e59 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -4,6 +4,7 @@ import functools import hashlib import html +import inspect import keyword import os import re @@ -174,15 +175,17 @@ def __init__( if expect_handler is None: expect_handler = _default_expect_handler - assert asyncio.iscoroutinefunction( - expect_handler + assert inspect.iscoroutinefunction(expect_handler) or ( + sys.version_info < (3, 14) and asyncio.iscoroutinefunction(expect_handler) ), f"Coroutine is expected, got {expect_handler!r}" method = method.upper() if not HTTP_METHOD_RE.match(method): raise ValueError(f"{method} is not allowed HTTP method") - if asyncio.iscoroutinefunction(handler): + if inspect.iscoroutinefunction(handler) or ( + sys.version_info < (3, 14) and asyncio.iscoroutinefunction(handler) + ): pass elif isinstance(handler, type) and issubclass(handler, AbstractView): pass diff --git a/aiohttp/worker.py b/aiohttp/worker.py index cc100d59faa..d4f062062a5 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -1,6 +1,7 @@ """Async gunicorn worker for aiohttp.web""" import asyncio +import inspect import os import re import signal @@ -68,7 +69,9 @@ async def _run(self) -> None: runner = None if isinstance(self.wsgi, Application): app = self.wsgi - elif asyncio.iscoroutinefunction(self.wsgi): + elif inspect.iscoroutinefunction(self.wsgi) or ( + sys.version_info < (3, 14) and asyncio.iscoroutinefunction(self.wsgi) + ): wsgi = await self.wsgi() if isinstance(wsgi, web.AppRunner): runner = wsgi diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 4b0a878d715..0f6eb99974b 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -472,8 +472,8 @@ Custom socket creation ^^^^^^^^^^^^^^^^^^^^^^ If the default socket is insufficient for your use case, pass an optional -`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements -`SocketFactoryType`. This will be used to create all sockets for the +``socket_factory`` to the :class:`~aiohttp.TCPConnector`, which implements +:class:`SocketFactoryType`. This will be used to create all sockets for the lifetime of the class object. For example, we may want to change the conditions under which we consider a connection dead. The following would make all sockets respect 9*7200 = 18 hours:: diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 7dabfe1a6db..147ec7dce2d 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1122,9 +1122,7 @@ is controlled by *force_close* constructor's parameter). overridden in subclasses. -.. autodata:: AddrInfoType - -.. note:: +.. py:class:: AddrInfoType Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info. @@ -1132,13 +1130,11 @@ is controlled by *force_close* constructor's parameter). Be sure to use ``aiohttp.AddrInfoType`` rather than ``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as - it is likely to be removed from ``aiohappyeyeballs`` in the + it is likely to be removed from :mod:`aiohappyeyeballs` in the future. -.. autodata:: SocketFactoryType - -.. note:: +.. py:class:: SocketFactoryType Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info. @@ -1146,7 +1142,7 @@ is controlled by *force_close* constructor's parameter). Be sure to use ``aiohttp.SocketFactoryType`` rather than ``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage, - as it is likely to be removed from ``aiohappyeyeballs`` in the + as it is likely to be removed from :mod:`aiohappyeyeballs` in the future. @@ -1278,9 +1274,9 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.10 - :param :py:data:``SocketFactoryType`` socket_factory: This function takes an - :py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when - creating TCP connections. + :param SocketFactoryType socket_factory: This function takes an + :py:data:`AddrInfoType` and is used in lieu of + :py:func:`socket.socket` when creating TCP connections. .. versionadded:: 3.12 diff --git a/docs/conf.py b/docs/conf.py index eba93188b44..15de3598c7e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,7 +83,7 @@ "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), - "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None), + "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None), } # Add any paths that contain templates here, relative to this directory. @@ -419,6 +419,7 @@ ("py:class", "aiohttp.web.MatchedSubAppResource"), # undocumented ("py:attr", "body"), # undocumented ("py:class", "socket.socket"), # undocumented + ("py:func", "socket.socket"), # undocumented ("py:class", "socket.AddressFamily"), # undocumented ("py:obj", "logging.DEBUG"), # undocumented ("py:class", "aiohttp.abc.AbstractAsyncAccessLogger"), # undocumented diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 71a5bb93d13..7032e1417b5 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -8,8 +8,9 @@ import pytest from multidict import CIMultiDict -from aiohttp import ClientConnectionResetError, http +from aiohttp import ClientConnectionResetError, hdrs, http from aiohttp.base_protocol import BaseProtocol +from aiohttp.http_writer import _serialize_headers from aiohttp.test_utils import make_mocked_coro @@ -603,3 +604,29 @@ async def test_set_eof_after_write_headers( msg.set_eof() await msg.write_eof() assert not transport.write.called + + +@pytest.mark.parametrize( + "char", + [ + "\n", + "\r", + ], +) +def test_serialize_headers_raises_on_new_line_or_carriage_return(char: str) -> None: + """Verify serialize_headers raises on cr or nl in the headers.""" + status_line = "HTTP/1.1 200 OK" + headers = CIMultiDict( + { + hdrs.CONTENT_TYPE: f"text/plain{char}", + } + ) + + with pytest.raises( + ValueError, + match=( + "Newline or carriage return detected in headers. " + "Potential header injection attack." + ), + ): + _serialize_headers(status_line, headers)