diff --git a/CHANGES/10988.bugfix.rst b/CHANGES/10988.bugfix.rst new file mode 120000 index 00000000000..6e737bb336c --- /dev/null +++ b/CHANGES/10988.bugfix.rst @@ -0,0 +1 @@ +6009.bugfix.rst \ No newline at end of file diff --git a/CHANGES/2914.doc.rst b/CHANGES/2914.doc.rst new file mode 100644 index 00000000000..25592bf79bc --- /dev/null +++ b/CHANGES/2914.doc.rst @@ -0,0 +1,4 @@ +Improved documentation for middleware by adding warnings and examples about +request body stream consumption. The documentation now clearly explains that +request body streams can only be read once and provides best practices for +sharing parsed request data between middleware and handlers -- by :user:`bdraco`. diff --git a/CHANGES/6009.bugfix.rst b/CHANGES/6009.bugfix.rst new file mode 100644 index 00000000000..a530832c8a9 --- /dev/null +++ b/CHANGES/6009.bugfix.rst @@ -0,0 +1 @@ +Fixed :py:attr:`~aiohttp.web.WebSocketResponse.prepared` property to correctly reflect the prepared state, especially during timeout scenarios -- by :user:`bdraco` diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 9421dc2ac76..a1fd12a1e97 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -358,6 +358,10 @@ def can_prepare(self, request: BaseRequest) -> WebSocketReady: else: return WebSocketReady(True, protocol) + @property + def prepared(self) -> bool: + return self._writer is not None + @property def closed(self) -> bool: return self._closed diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index 76fc3ea57f1..f182acf11c1 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -568,9 +568,13 @@ A *middleware* is a coroutine that can modify either the request or response. For example, here's a simple *middleware* which appends ``' wink'`` to the response:: - from aiohttp.web import middleware + from aiohttp import web + from typing import Callable, Awaitable - async def middleware(request, handler): + async def middleware( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]] + ) -> web.StreamResponse: resp = await handler(request) resp.text = resp.text + ' wink' return resp @@ -619,18 +623,25 @@ post-processing like handling *CORS* and so on. The following code demonstrates middlewares execution order:: from aiohttp import web + from typing import Callable, Awaitable - async def test(request): + async def test(request: web.Request) -> web.Response: print('Handler function called') return web.Response(text="Hello") - async def middleware1(request, handler): + async def middleware1( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]] + ) -> web.StreamResponse: print('Middleware 1 called') response = await handler(request) print('Middleware 1 finished') return response - async def middleware2(request, handler): + async def middleware2( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]] + ) -> web.StreamResponse: print('Middleware 2 called') response = await handler(request) print('Middleware 2 finished') @@ -649,6 +660,82 @@ Produced output:: Middleware 2 finished Middleware 1 finished +Request Body Stream Consumption +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. warning:: + + When middleware reads the request body (using :meth:`~aiohttp.web.BaseRequest.read`, + :meth:`~aiohttp.web.BaseRequest.text`, :meth:`~aiohttp.web.BaseRequest.json`, or + :meth:`~aiohttp.web.BaseRequest.post`), the body stream is consumed. However, these + high-level methods cache their result, so subsequent calls from the handler or other + middleware will return the same cached value. + + The important distinction is: + + - High-level methods (:meth:`~aiohttp.web.BaseRequest.read`, :meth:`~aiohttp.web.BaseRequest.text`, + :meth:`~aiohttp.web.BaseRequest.json`, :meth:`~aiohttp.web.BaseRequest.post`) cache their + results internally, so they can be called multiple times and will return the same value. + - Direct stream access via :attr:`~aiohttp.web.BaseRequest.content` does NOT have this + caching behavior. Once you read from ``request.content`` directly (e.g., using + ``await request.content.read()``), subsequent reads will return empty bytes. + +Consider this middleware that logs request bodies:: + + from aiohttp import web + from typing import Callable, Awaitable + + async def logging_middleware( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]] + ) -> web.StreamResponse: + # This consumes the request body stream + body = await request.text() + print(f"Request body: {body}") + return await handler(request) + + async def handler(request: web.Request) -> web.Response: + # This will return the same value that was read in the middleware + # (i.e., the cached result, not an empty string) + body = await request.text() + return web.Response(text=f"Received: {body}") + +In contrast, when accessing the stream directly (not recommended in middleware):: + + async def stream_middleware( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]] + ) -> web.StreamResponse: + # Reading directly from the stream - this consumes it! + data = await request.content.read() + print(f"Stream data: {data}") + return await handler(request) + + async def handler(request: web.Request) -> web.Response: + # This will return empty bytes because the stream was already consumed + data = await request.content.read() + # data will be b'' (empty bytes) + + # However, high-level methods would still work if called for the first time: + # body = await request.text() # This would read from internal cache if available + return web.Response(text=f"Received: {data}") + +When working with raw stream data that needs to be shared between middleware and handlers:: + + async def stream_parsing_middleware( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.StreamResponse]] + ) -> web.StreamResponse: + # Read stream once and store the data + raw_data = await request.content.read() + request['raw_body'] = raw_data + return await handler(request) + + async def handler(request: web.Request) -> web.Response: + # Access the stored data instead of reading the stream again + raw_data = request.get('raw_body', b'') + return web.Response(body=raw_data) + Example ^^^^^^^ diff --git a/docs/web_reference.rst b/docs/web_reference.rst index d1c14409a28..74e7f9cd3d1 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -1045,6 +1045,11 @@ and :ref:`aiohttp-web-signals` handlers:: of closing. :const:`~aiohttp.WSMsgType.CLOSE` message has been received from peer. + .. attribute:: prepared + + Read-only :class:`bool` property, ``True`` if :meth:`prepare` has + been called, ``False`` otherwise. + .. attribute:: close_code Read-only property, close code from peer. It is set to ``None`` on diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 70c9efb98dc..0fdea98cdd0 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3408,6 +3408,9 @@ async def handler(request: web.Request) -> NoReturn: pass assert cm._session.closed + # Allow event loop to process transport cleanup + # on Python < 3.11 + await asyncio.sleep(0) async def test_aiohttp_request_ctx_manager_not_found() -> None: diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py index 9d49b750333..e698e8ee825 100644 --- a/tests/test_client_middleware.py +++ b/tests/test_client_middleware.py @@ -863,8 +863,13 @@ async def test_client_middleware_retry_reuses_connection( aiohttp_server: AiohttpServer, ) -> None: """Test that connections are reused when middleware performs retries.""" + request_count = 0 async def handler(request: web.Request) -> web.Response: + nonlocal request_count + request_count += 1 + if request_count == 1: + return web.Response(status=400) # First request returns 400 with no body return web.Response(text="OK") class TrackingConnector(TCPConnector): @@ -891,7 +896,7 @@ async def __call__( while True: self.attempt_count += 1 response = await handler(request) - if retry_count == 0: + if response.status == 400 and retry_count == 0: retry_count += 1 continue return response diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 45719ce6012..d082f7d5d47 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -670,4 +670,4 @@ async def test_get_extra_info( await ws.prepare(req) ws._writer = ws_transport - assert ws.get_extra_info(valid_key, default_value) == expected_result + assert expected_result == ws.get_extra_info(valid_key, default_value) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 2ca95f71d9e..faee34cf811 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -1332,3 +1332,116 @@ async def handler(request: web.Request) -> web.WebSocketResponse: ) await client.server.close() assert close_code == WSCloseCode.OK + + +async def test_websocket_prepare_timeout_close_issue( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + """Test that WebSocket can handle prepare with early returns. + + This is a regression test for issue #6009 where the prepared property + incorrectly checked _payload_writer instead of _writer. + """ + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + assert ws.can_prepare(request) + await ws.prepare(request) + await ws.send_str("test") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/ws", handler) + client = await aiohttp_client(app) + + # Connect via websocket + ws = await client.ws_connect("/ws") + msg = await ws.receive() + assert msg.type is WSMsgType.TEXT + assert msg.data == "test" + await ws.close() + + +async def test_websocket_prepare_timeout_from_issue_reproducer( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + """Test websocket behavior when prepare is interrupted. + + This test verifies the fix for issue #6009 where close() would + fail after prepare() was interrupted. + """ + prepare_complete = asyncio.Event() + close_complete = asyncio.Event() + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + + # Prepare the websocket + await ws.prepare(request) + prepare_complete.set() + + # Send a message to confirm connection works + await ws.send_str("connected") + + # Wait for client to close + msg = await ws.receive() + assert msg.type is WSMsgType.CLOSE + await ws.close() + close_complete.set() + + return ws + + app = web.Application() + app.router.add_route("GET", "/ws", handler) + client = await aiohttp_client(app) + + # Connect and verify the connection works + ws = await client.ws_connect("/ws") + await prepare_complete.wait() + + msg = await ws.receive() + assert msg.type is WSMsgType.TEXT + assert msg.data == "connected" + + # Close the connection + await ws.close() + await close_complete.wait() + + +async def test_websocket_prepared_property( + loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient +) -> None: + """Test that WebSocketResponse.prepared property correctly reflects state.""" + prepare_called = asyncio.Event() + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + + # Initially not prepared + initial_state = ws.prepared + assert not initial_state + + # After prepare() is called, should be prepared + await ws.prepare(request) + prepare_called.set() + + # Check prepared state + prepared_state = ws.prepared + assert prepared_state + + # Send a message to verify the connection works + await ws.send_str("test") + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + + ws = await client.ws_connect("/") + await prepare_called.wait() + msg = await ws.receive() + assert msg.type is WSMsgType.TEXT + assert msg.data == "test" + await ws.close()