Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/10474.feature.rst
1 change: 1 addition & 0 deletions CHANGES/10601.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved performance of WebSocket buffer handling -- by :user:`bdraco`.
1 change: 1 addition & 0 deletions CHANGES/10625.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved performance of serializing headers -- by :user:`bdraco`.
2 changes: 2 additions & 0 deletions CHANGES/10634.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Replaced deprecated ``asyncio.iscoroutinefunction`` with its counterpart from ``inspect``
-- by :user:`layday`.
38 changes: 18 additions & 20 deletions aiohttp/_http_writer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,34 @@ cdef inline int _write_str(Writer* writer, str s):
return -1


# --------------- _serialize_headers ----------------------

cdef str to_str(object s):
cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s):
cdef Py_UCS4 ch
cdef str out_str
if type(s) is str:
return <str>s
out_str = <str>s
elif type(s) is _istr:
return PyObject_Str(s)
out_str = PyObject_Str(s)
elif not isinstance(s, str):
raise TypeError("Cannot serialize non-str key {!r}".format(s))
else:
return str(s)
out_str = str(s)

for ch in out_str:
if ch == 0x0D or ch == 0x0A:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
if _write_utf8(writer, ch) < 0:
return -1


# --------------- _serialize_headers ----------------------

def _serialize_headers(str status_line, headers):
cdef Writer writer
cdef object key
cdef object val
cdef bytes ret
cdef str key_str
cdef str val_str

_init_writer(&writer)

Expand All @@ -130,22 +137,13 @@ def _serialize_headers(str status_line, headers):
raise

for key, val in headers.items():
key_str = to_str(key)
val_str = to_str(val)

if "\r" in key_str or "\n" in key_str or "\r" in val_str or "\n" in val_str:
raise ValueError(
"Newline or carriage return character detected in HTTP status message or "
"header. This is a potential security issue."
)

if _write_str(&writer, key_str) < 0:
if _write_str_raise_on_nlcr(&writer, key) < 0:
raise
if _write_byte(&writer, b':') < 0:
raise
if _write_byte(&writer, b' ') < 0:
raise
if _write_str(&writer, val_str) < 0:
if _write_str_raise_on_nlcr(&writer, val) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
Expand Down
1 change: 1 addition & 0 deletions aiohttp/_websocket/reader_c.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ cdef class WebSocketReader:
chunk_size="unsigned int",
chunk_len="unsigned int",
buf_length="unsigned int",
buf_cstr="const unsigned char *",
first_byte="unsigned char",
second_byte="unsigned char",
end_pos="unsigned int",
Expand Down
37 changes: 18 additions & 19 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,8 @@ def _feed_data(self, data: bytes) -> None:
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
f"Message size {len(self._partial)} "
f"exceeds limit {self._max_msg_size}",
)
continue

Expand All @@ -220,7 +219,7 @@ def _feed_data(self, data: bytes) -> None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
f"to be zero, got {opcode!r}",
)

assembled_payload: Union[bytes, bytearray]
Expand All @@ -233,9 +232,8 @@ def _feed_data(self, data: bytes) -> None:
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(assembled_payload), self._max_msg_size
),
f"Message size {len(assembled_payload)} "
f"exceeds limit {self._max_msg_size}",
)

# Decompress process must be done after all packets
Expand All @@ -252,9 +250,8 @@ def _feed_data(self, data: bytes) -> None:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
f"Decompressed message size {self._max_msg_size + left}"
f" exceeds limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
Expand Down Expand Up @@ -333,14 +330,15 @@ def parse_frame(

start_pos: int = 0
buf_length = len(buf)
buf_cstr = buf

while True:
# read header
if self._state == READ_HEADER:
if buf_length - start_pos < 2:
break
first_byte = buf[start_pos]
second_byte = buf[start_pos + 1]
first_byte = buf_cstr[start_pos]
second_byte = buf_cstr[start_pos + 1]
start_pos += 2

fin = (first_byte >> 7) & 1
Expand Down Expand Up @@ -405,14 +403,14 @@ def parse_frame(
if length_flag == 126:
if buf_length - start_pos < 2:
break
first_byte = buf[start_pos]
second_byte = buf[start_pos + 1]
first_byte = buf_cstr[start_pos]
second_byte = buf_cstr[start_pos + 1]
start_pos += 2
self._payload_length = first_byte << 8 | second_byte
elif length_flag > 126:
if buf_length - start_pos < 8:
break
data = buf[start_pos : start_pos + 8]
data = buf_cstr[start_pos : start_pos + 8]
start_pos += 8
self._payload_length = UNPACK_LEN3(data)[0]
else:
Expand All @@ -424,7 +422,7 @@ def parse_frame(
if self._state == READ_PAYLOAD_MASK:
if buf_length - start_pos < 4:
break
self._frame_mask = buf[start_pos : start_pos + 4]
self._frame_mask = buf_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD

Expand All @@ -440,10 +438,10 @@ def parse_frame(
if self._frame_payload_len:
if type(self._frame_payload) is not bytearray:
self._frame_payload = bytearray(self._frame_payload)
self._frame_payload += buf[start_pos:end_pos]
self._frame_payload += buf_cstr[start_pos:end_pos]
else:
# Fast path for the first frame
self._frame_payload = buf[start_pos:end_pos]
self._frame_payload = buf_cstr[start_pos:end_pos]

self._frame_payload_len += end_pos - start_pos
start_pos = end_pos
Expand All @@ -469,6 +467,7 @@ def parse_frame(
self._frame_payload_len = 0
self._state = READ_HEADER

self._tail = buf[start_pos:] if start_pos < buf_length else b""
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = buf_cstr[start_pos:buf_length] if start_pos < buf_length else b""

return frames
6 changes: 3 additions & 3 deletions aiohttp/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def]
if inspect.isasyncgenfunction(func):
# async generator fixture
is_async_gen = True
elif asyncio.iscoroutinefunction(func):
elif inspect.iscoroutinefunction(func):
# regular async fixture
is_async_gen = False
else:
Expand Down Expand Up @@ -216,14 +216,14 @@ def _passthrough_loop_context(

def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def]
"""Fix pytest collecting for coroutines."""
if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj):
if collector.funcnamefilter(name) and inspect.iscoroutinefunction(obj):
return list(collector._genfunctions(name, obj))


def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def]
"""Run coroutines in an event loop instead of a normal function call."""
fast = pyfuncitem.config.getoption("--aiohttp-fast")
if asyncio.iscoroutinefunction(pyfuncitem.function):
if inspect.iscoroutinefunction(pyfuncitem.function):
existing_loop = pyfuncitem.funcargs.get(
"proactor_loop"
) or pyfuncitem.funcargs.get("loop", None)
Expand Down
9 changes: 6 additions & 3 deletions aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import hashlib
import html
import inspect
import keyword
import os
import re
Expand Down Expand Up @@ -174,15 +175,17 @@ def __init__(
if expect_handler is None:
expect_handler = _default_expect_handler

assert asyncio.iscoroutinefunction(
expect_handler
assert inspect.iscoroutinefunction(expect_handler) or (
sys.version_info < (3, 14) and asyncio.iscoroutinefunction(expect_handler)
), f"Coroutine is expected, got {expect_handler!r}"

method = method.upper()
if not HTTP_METHOD_RE.match(method):
raise ValueError(f"{method} is not allowed HTTP method")

if asyncio.iscoroutinefunction(handler):
if inspect.iscoroutinefunction(handler) or (
sys.version_info < (3, 14) and asyncio.iscoroutinefunction(handler)
):
pass
elif isinstance(handler, type) and issubclass(handler, AbstractView):
pass
Expand Down
5 changes: 4 additions & 1 deletion aiohttp/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Async gunicorn worker for aiohttp.web"""

import asyncio
import inspect
import os
import re
import signal
Expand Down Expand Up @@ -68,7 +69,9 @@ async def _run(self) -> None:
runner = None
if isinstance(self.wsgi, Application):
app = self.wsgi
elif asyncio.iscoroutinefunction(self.wsgi):
elif inspect.iscoroutinefunction(self.wsgi) or (
sys.version_info < (3, 14) and asyncio.iscoroutinefunction(self.wsgi)
):
wsgi = await self.wsgi()
if isinstance(wsgi, web.AppRunner):
runner = wsgi
Expand Down
4 changes: 2 additions & 2 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ Custom socket creation
^^^^^^^^^^^^^^^^^^^^^^

If the default socket is insufficient for your use case, pass an optional
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
`SocketFactoryType`. This will be used to create all sockets for the
``socket_factory`` to the :class:`~aiohttp.TCPConnector`, which implements
:class:`SocketFactoryType`. This will be used to create all sockets for the
lifetime of the class object. For example, we may want to change the
conditions under which we consider a connection dead. The following would
make all sockets respect 9*7200 = 18 hours::
Expand Down
18 changes: 7 additions & 11 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1122,31 +1122,27 @@ is controlled by *force_close* constructor's parameter).
overridden in subclasses.


.. autodata:: AddrInfoType

.. note::
.. py:class:: AddrInfoType

Refer to :py:data:`aiohappyeyeballs.AddrInfoType` for more info.

.. warning::

Be sure to use ``aiohttp.AddrInfoType`` rather than
``aiohappyeyeballs.AddrInfoType`` to avoid import breakage, as
it is likely to be removed from ``aiohappyeyeballs`` in the
it is likely to be removed from :mod:`aiohappyeyeballs` in the
future.


.. autodata:: SocketFactoryType

.. note::
.. py:class:: SocketFactoryType

Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` for more info.

.. warning::

Be sure to use ``aiohttp.SocketFactoryType`` rather than
``aiohappyeyeballs.SocketFactoryType`` to avoid import breakage,
as it is likely to be removed from ``aiohappyeyeballs`` in the
as it is likely to be removed from :mod:`aiohappyeyeballs` in the
future.


Expand Down Expand Up @@ -1278,9 +1274,9 @@ is controlled by *force_close* constructor's parameter).

.. versionadded:: 3.10

:param :py:data:``SocketFactoryType`` socket_factory: This function takes an
:py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when
creating TCP connections.
:param SocketFactoryType socket_factory: This function takes an
:py:data:`AddrInfoType` and is used in lieu of
:py:func:`socket.socket` when creating TCP connections.

.. versionadded:: 3.12

Expand Down
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None),
}

# Add any paths that contain templates here, relative to this directory.
Expand Down Expand Up @@ -419,6 +419,7 @@
("py:class", "aiohttp.web.MatchedSubAppResource"), # undocumented
("py:attr", "body"), # undocumented
("py:class", "socket.socket"), # undocumented
("py:func", "socket.socket"), # undocumented
("py:class", "socket.AddressFamily"), # undocumented
("py:obj", "logging.DEBUG"), # undocumented
("py:class", "aiohttp.abc.AbstractAsyncAccessLogger"), # undocumented
Expand Down
29 changes: 28 additions & 1 deletion tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import pytest
from multidict import CIMultiDict

from aiohttp import ClientConnectionResetError, http
from aiohttp import ClientConnectionResetError, hdrs, http
from aiohttp.base_protocol import BaseProtocol
from aiohttp.http_writer import _serialize_headers
from aiohttp.test_utils import make_mocked_coro


Expand Down Expand Up @@ -603,3 +604,29 @@ async def test_set_eof_after_write_headers(
msg.set_eof()
await msg.write_eof()
assert not transport.write.called


@pytest.mark.parametrize("char", ["\n", "\r"])
def test_serialize_headers_raises_on_new_line_or_carriage_return(char: str) -> None:
    """Check that a CR or LF inside a header value makes serialization fail."""
    # pytest.raises treats ``match`` as a regular expression searched against
    # the string form of the raised exception.
    expected_message = (
        "Newline or carriage return detected in headers. "
        "Potential header injection attack."
    )
    tainted_headers = CIMultiDict({hdrs.CONTENT_TYPE: f"text/plain{char}"})

    with pytest.raises(ValueError, match=expected_message):
        _serialize_headers("HTTP/1.1 200 OK", tainted_headers)
Loading