diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index df41c1a607d..aac59671aa3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -106,7 +106,7 @@ repos:
   - id: pyupgrade
     args: ['--py37-plus']
 - repo: https://github.com/PyCQA/flake8
-  rev: '7.2.0'
+  rev: '7.3.0'
   hooks:
   - id: flake8
     additional_dependencies:
diff --git a/CHANGES/11269.feature.rst b/CHANGES/11269.feature.rst
new file mode 100644
index 00000000000..92cf173be14
--- /dev/null
+++ b/CHANGES/11269.feature.rst
@@ -0,0 +1 @@
+Added initial trailer parsing logic to Python HTTP parser -- by :user:`Dreamsorcerer`.
diff --git a/CHANGES/11273.bugfix.rst b/CHANGES/11273.bugfix.rst
new file mode 100644
index 00000000000..b4d9948fbcd
--- /dev/null
+++ b/CHANGES/11273.bugfix.rst
@@ -0,0 +1 @@
+Fixed :py:meth:`ClientSession.close() <aiohttp.ClientSession.close>` hanging indefinitely when using HTTPS requests through HTTP proxies -- by :user:`bdraco`.
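Reviewer note (annotation, not part of the patch): the hang described in 11273.bugfix.rst reduces to session.close() awaiting a 'closed' future that nothing ever resolves. A minimal, aiohttp-independent sketch of that failure mode; every name in it is illustrative:

    import asyncio


    async def main() -> None:
        loop = asyncio.get_running_loop()
        # Stand-in for the 'closed' future of a pooled CONNECT-tunnel
        # connection: nothing ever resolves it.
        never_closed: "asyncio.Future[None]" = loop.create_future()
        try:
            # Stand-in for session.close() waiting on pooled connections.
            await asyncio.wait_for(never_closed, timeout=0.1)
        except asyncio.TimeoutError:
            print("without the fix, close() would wait here forever")


    asyncio.run(main())
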
+ """ + + class _TransportPlaceholder: """placeholder for BaseConnector.connect function""" @@ -1496,7 +1516,7 @@ async def _create_proxy_connection( key = req.connection_key._replace( proxy=None, proxy_auth=None, proxy_headers_hash=None ) - conn = Connection(self, key, proto, self._loop) + conn = _ConnectTunnelConnection(self, key, proto, self._loop) proxy_resp = await proxy_req.send(conn) try: protocol = conn._protocol diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index fa0328830a2..84b59afc486 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -123,7 +123,6 @@ class ChunkState(IntEnum): PARSE_CHUNKED_SIZE = 0 PARSE_CHUNKED_CHUNK = 1 PARSE_CHUNKED_CHUNK_EOF = 2 - PARSE_MAYBE_TRAILERS = 3 PARSE_TRAILERS = 4 @@ -142,8 +141,8 @@ def parse_headers( # note: "raw" does not mean inclusion of OWS before/after the field value raw_headers = [] - lines_idx = 1 - line = lines[1] + lines_idx = 0 + line = lines[lines_idx] line_count = len(lines) while line: @@ -394,6 +393,7 @@ def get_content_length() -> Optional[int]: response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, + headers_parser=self._headers_parser, ) if not payload_parser.done: self._payload_parser = payload_parser @@ -412,6 +412,7 @@ def get_content_length() -> Optional[int]: compression=msg.compression, auto_decompress=self._auto_decompress, lax=self.lax, + headers_parser=self._headers_parser, ) elif not empty_body and length is None and self.read_until_eof: payload = StreamReader( @@ -430,6 +431,7 @@ def get_content_length() -> Optional[int]: response_with_body=self.response_with_body, auto_decompress=self._auto_decompress, lax=self.lax, + headers_parser=self._headers_parser, ) if not payload_parser.done: self._payload_parser = payload_parser @@ -467,6 +469,10 @@ def get_content_length() -> Optional[int]: eof = True data = b"" + if isinstance( + underlying_exc, (InvalidHeader, TransferEncodingError) + ): + raise if eof: start_pos = 0 @@ -629,7 +635,7 @@ def parse_message(self, lines: List[bytes]) -> RawRequestMessage: compression, upgrade, chunked, - ) = self.parse_headers(lines) + ) = self.parse_headers(lines[1:]) if close is None: # then the headers weren't set in the request if version_o <= HttpVersion10: # HTTP 1.0 must asks to not close @@ -715,7 +721,7 @@ def parse_message(self, lines: List[bytes]) -> RawResponseMessage: compression, upgrade, chunked, - ) = self.parse_headers(lines) + ) = self.parse_headers(lines[1:]) if close is None: if version_o <= HttpVersion10: @@ -758,6 +764,8 @@ def __init__( response_with_body: bool = True, auto_decompress: bool = True, lax: bool = False, + *, + headers_parser: HeadersParser, ) -> None: self._length = 0 self._type = ParseState.PARSE_UNTIL_EOF @@ -766,6 +774,8 @@ def __init__( self._chunk_tail = b"" self._auto_decompress = auto_decompress self._lax = lax + self._headers_parser = headers_parser + self._trailer_lines: list[bytes] = [] self.done = False # payload decompression wrapper @@ -833,7 +843,7 @@ def feed_data( size_b = chunk[:i] # strip chunk-extensions # Verify no LF in the chunk-extension if b"\n" in (ext := chunk[i:pos]): - exc = BadHttpMessage( + exc = TransferEncodingError( f"Unexpected LF in chunk-extension: {ext!r}" ) set_exception(self.payload, exc) @@ -854,7 +864,7 @@ def feed_data( chunk = chunk[pos + len(SEP) :] if size == 0: # eof marker - self._chunk = ChunkState.PARSE_MAYBE_TRAILERS + self._chunk = ChunkState.PARSE_TRAILERS if self._lax and chunk.startswith(b"\r"): chunk = chunk[1:] else: @@ 
diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py
index fa0328830a2..84b59afc486 100644
--- a/aiohttp/http_parser.py
+++ b/aiohttp/http_parser.py
@@ -123,7 +123,6 @@ class ChunkState(IntEnum):
     PARSE_CHUNKED_SIZE = 0
     PARSE_CHUNKED_CHUNK = 1
     PARSE_CHUNKED_CHUNK_EOF = 2
-    PARSE_MAYBE_TRAILERS = 3
     PARSE_TRAILERS = 4
 
 
@@ -142,8 +141,8 @@ def parse_headers(
         # note: "raw" does not mean inclusion of OWS before/after the field value
         raw_headers = []
 
-        lines_idx = 1
-        line = lines[1]
+        lines_idx = 0
+        line = lines[lines_idx]
         line_count = len(lines)
 
         while line:
@@ -394,6 +393,7 @@ def get_content_length() -> Optional[int]:
                     response_with_body=self.response_with_body,
                     auto_decompress=self._auto_decompress,
                     lax=self.lax,
+                    headers_parser=self._headers_parser,
                 )
                 if not payload_parser.done:
                     self._payload_parser = payload_parser
@@ -412,6 +412,7 @@ def get_content_length() -> Optional[int]:
                     compression=msg.compression,
                     auto_decompress=self._auto_decompress,
                     lax=self.lax,
+                    headers_parser=self._headers_parser,
                 )
             elif not empty_body and length is None and self.read_until_eof:
                 payload = StreamReader(
@@ -430,6 +431,7 @@ def get_content_length() -> Optional[int]:
                     response_with_body=self.response_with_body,
                     auto_decompress=self._auto_decompress,
                     lax=self.lax,
+                    headers_parser=self._headers_parser,
                 )
                 if not payload_parser.done:
                     self._payload_parser = payload_parser
@@ -467,6 +469,10 @@ def get_content_length() -> Optional[int]:
                         eof = True
                         data = b""
 
+                    if isinstance(
+                        underlying_exc, (InvalidHeader, TransferEncodingError)
+                    ):
+                        raise
 
         if eof:
             start_pos = 0
@@ -629,7 +635,7 @@ def parse_message(self, lines: List[bytes]) -> RawRequestMessage:
             compression,
             upgrade,
             chunked,
-        ) = self.parse_headers(lines)
+        ) = self.parse_headers(lines[1:])
 
         if close is None:  # then the headers weren't set in the request
             if version_o <= HttpVersion10:  # HTTP 1.0 must asks to not close
@@ -715,7 +721,7 @@ def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
             compression,
             upgrade,
             chunked,
-        ) = self.parse_headers(lines)
+        ) = self.parse_headers(lines[1:])
 
         if close is None:
             if version_o <= HttpVersion10:
@@ -758,6 +764,8 @@ def __init__(
         response_with_body: bool = True,
         auto_decompress: bool = True,
         lax: bool = False,
+        *,
+        headers_parser: HeadersParser,
     ) -> None:
         self._length = 0
         self._type = ParseState.PARSE_UNTIL_EOF
@@ -766,6 +774,8 @@ def __init__(
         self._chunk_tail = b""
         self._auto_decompress = auto_decompress
         self._lax = lax
+        self._headers_parser = headers_parser
+        self._trailer_lines: list[bytes] = []
         self.done = False
 
         # payload decompression wrapper
@@ -833,7 +843,7 @@ def feed_data(
                     size_b = chunk[:i]  # strip chunk-extensions
                     # Verify no LF in the chunk-extension
                     if b"\n" in (ext := chunk[i:pos]):
-                        exc = BadHttpMessage(
+                        exc = TransferEncodingError(
                             f"Unexpected LF in chunk-extension: {ext!r}"
                         )
                         set_exception(self.payload, exc)
@@ -854,7 +864,7 @@ def feed_data(
                     chunk = chunk[pos + len(SEP) :]
                     if size == 0:  # eof marker
-                        self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
+                        self._chunk = ChunkState.PARSE_TRAILERS
                         if self._lax and chunk.startswith(b"\r"):
                             chunk = chunk[1:]
                     else:
@@ -888,38 +898,31 @@ def feed_data(
                     self._chunk_tail = chunk
                     return False, b""
 
-            # if stream does not contain trailer, after 0\r\n
-            # we should get another \r\n otherwise
-            # trailers needs to be skipped until \r\n\r\n
-            if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS:
-                head = chunk[: len(SEP)]
-                if head == SEP:
-                    # end of stream
-                    self.payload.feed_eof()
-                    return True, chunk[len(SEP) :]
-                # Both CR and LF, or only LF may not be received yet. It is
-                # expected that CRLF or LF will be shown at the very first
-                # byte next time, otherwise trailers should come. The last
-                # CRLF which marks the end of response might not be
-                # contained in the same TCP segment which delivered the
-                # size indicator.
-                if not head:
-                    return False, b""
-                if head == SEP[:1]:
-                    self._chunk_tail = head
-                    return False, b""
-                self._chunk = ChunkState.PARSE_TRAILERS
-
-            # read and discard trailer up to the CRLF terminator
             if self._chunk == ChunkState.PARSE_TRAILERS:
                 pos = chunk.find(SEP)
-                if pos >= 0:
-                    chunk = chunk[pos + len(SEP) :]
-                    self._chunk = ChunkState.PARSE_MAYBE_TRAILERS
-                else:
+                if pos < 0:  # No line found
                     self._chunk_tail = chunk
                     return False, b""
+
+                line = chunk[:pos]
+                chunk = chunk[pos + len(SEP) :]
+                if SEP == b"\n":  # For lax response parsing
+                    line = line.rstrip(b"\r")
+                self._trailer_lines.append(line)
+
+                # \r\n\r\n found, end of stream
+                if self._trailer_lines[-1] == b"":
+                    # Headers and trailers are defined the same way,
+                    # so we reuse the HeadersParser here.
+                    try:
+                        trailers, raw_trailers = self._headers_parser.parse_headers(
+                            self._trailer_lines
+                        )
+                    finally:
+                        self._trailer_lines.clear()
+                    self.payload.feed_eof()
+                    return True, chunk
 
         # Read all bytes until eof
         elif self._type == ParseState.PARSE_UNTIL_EOF:
             self.payload.feed_data(chunk)
diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py
index 0c18ba05a54..befb1e8c373 100644
--- a/aiohttp/multipart.py
+++ b/aiohttp/multipart.py
@@ -781,7 +781,7 @@ async def _read_boundary(self) -> None:
             raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}")
 
     async def _read_headers(self) -> "CIMultiDictProxy[str]":
-        lines = [b""]
+        lines = []
         while True:
             chunk = await self._content.readline()
             chunk = chunk.strip()
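Reviewer note (annotation, not part of the patch): a worked example of the framing the new PARSE_TRAILERS branch consumes. Assuming the strict separator SEP == b"\r\n", the bytes after the final "0" chunk-size line split into trailer lines, and an empty line terminates the message:

    raw = b"test: trailer\r\nsecond: test trailer\r\n\r\n"

    trailer_lines = []
    while raw:
        line, _, raw = raw.partition(b"\r\n")
        trailer_lines.append(line)
        if line == b"":  # blank line: end of trailers, end of message
            break

    assert trailer_lines == [b"test: trailer", b"second: test trailer", b""]
    # The accumulated lines are then handed to HeadersParser.parse_headers(),
    # since trailer fields share the syntax of header fields.
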
diff --git a/tests/test_connector.py b/tests/test_connector.py
index 1a739674ce3..53704fc5ea5 100644
--- a/tests/test_connector.py
+++ b/tests/test_connector.py
@@ -47,6 +47,7 @@
     AddrInfoType,
     Connection,
     TCPConnector,
+    _ConnectTunnelConnection,
     _DNSCacheTable,
 )
 from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
@@ -4454,3 +4455,31 @@ async def test_available_connections_no_limits(
     connection1.close()
     assert conn._available_connections(key) == 1
     assert conn._available_connections(other_host_key2) == 1
+
+
+async def test_connect_tunnel_connection_release(
+    loop: asyncio.AbstractEventLoop,
+) -> None:
+    """Test _ConnectTunnelConnection.release() does not pool the connection."""
+    connector = mock.create_autospec(
+        aiohttp.BaseConnector, spec_set=True, instance=True
+    )
+    key = mock.create_autospec(ConnectionKey, spec_set=True, instance=True)
+    protocol = mock.create_autospec(ResponseHandler, spec_set=True, instance=True)
+
+    # Create a connect tunnel connection
+    conn = _ConnectTunnelConnection(connector, key, protocol, loop)
+
+    # Verify protocol is set
+    assert conn._protocol is protocol
+
+    # Release should do nothing (not pool the connection)
+    conn.release()
+
+    # Protocol should still be there (not released to pool)
+    assert conn._protocol is protocol
+    # Connector._release should NOT have been called
+    connector._release.assert_not_called()
+
+    # Clean up to avoid resource warning
+    conn.close()
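Reviewer note (annotation, not part of the patch): the test above leans on strict autospec mocks, so any call that does not match the real class fails fast. A minimal illustration of that unittest.mock pattern, with a toy class standing in for BaseConnector:

    from unittest import mock


    class Connector:
        def _release(self) -> None: ...


    connector = mock.create_autospec(Connector, spec_set=True, instance=True)
    connector._release()  # allowed: the attribute exists on the spec
    connector._release.assert_called_once()
    # connector.bogus_attr would raise AttributeError because the mock is spec'd.
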
diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py
index 4a1e196f8eb..38c6971b037 100644
--- a/tests/test_http_parser.py
+++ b/tests/test_http_parser.py
@@ -18,6 +18,7 @@
 from aiohttp.helpers import NO_EXTENSIONS
 from aiohttp.http_parser import (
     DeflateBuffer,
+    HeadersParser,
     HttpParser,
     HttpPayloadParser,
     HttpRequestParser,
@@ -266,43 +267,13 @@ def test_content_length_transfer_encoding(parser: HttpRequestParser) -> None:
         parser.feed_data(text)
 
 
-def test_bad_chunked_py(
-    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
-) -> None:
+def test_bad_chunked(parser: HttpRequestParser) -> None:
     """Test that invalid chunked encoding doesn't allow content-length to be used."""
-    parser = HttpRequestParserPy(
-        protocol,
-        loop,
-        2**16,
-        max_line_size=8190,
-        max_field_size=8190,
-    )
-    text = (
-        b"GET / HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n0_2e\r\n\r\n"
-        + b"GET / HTTP/1.1\r\nHost: a\r\nContent-Length: 5\r\n\r\n0\r\n\r\n"
-    )
-    messages, upgrade, tail = parser.feed_data(text)
-    assert isinstance(messages[0][1].exception(), http_exceptions.TransferEncodingError)
-
-
-@pytest.mark.skipif(
-    "HttpRequestParserC" not in dir(aiohttp.http_parser),
-    reason="C based HTTP parser not available",
-)
-def test_bad_chunked_c(loop: asyncio.AbstractEventLoop, protocol: BaseProtocol) -> None:
-    """C parser behaves differently. Maybe we should align them later."""
-    parser = HttpRequestParserC(
-        protocol,
-        loop,
-        2**16,
-        max_line_size=8190,
-        max_field_size=8190,
-    )
     text = (
         b"GET / HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n0_2e\r\n\r\n"
         + b"GET / HTTP/1.1\r\nHost: a\r\nContent-Length: 5\r\n\r\n0\r\n\r\n"
     )
-    with pytest.raises(http_exceptions.BadHttpMessage):
+    with pytest.raises(http_exceptions.BadHttpMessage, match="0_2e"):
         parser.feed_data(text)
 
 
@@ -1207,8 +1178,8 @@ async def test_http_response_parser_bad_chunked_strict_py(
     text = (
         b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5 \r\nabcde\r\n0\r\n\r\n"
     )
-    messages, upgrade, tail = response.feed_data(text)
-    assert isinstance(messages[0][1].exception(), http_exceptions.TransferEncodingError)
+    with pytest.raises(http_exceptions.TransferEncodingError, match="5"):
+        response.feed_data(text)
 
 
 @pytest.mark.dev_mode
@@ -1354,6 +1325,22 @@ def test_parse_chunked_payload_chunk_extension(parser: HttpRequestParser) -> Non
     assert payload.is_eof()
 
 
+async def test_request_chunked_with_trailer(parser: HttpRequestParser) -> None:
+    text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\ntest: trailer\r\nsecond: test trailer\r\n\r\n"
+    messages, upgraded, tail = parser.feed_data(text)
+    assert not tail
+    msg, payload = messages[0]
+    assert await payload.read() == b"test"
+
+    # TODO: Add assertion of trailers when API added.
+
+
+async def test_request_chunked_reject_bad_trailer(parser: HttpRequestParser) -> None:
+    text = b"GET /test HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nbad\ntrailer\r\n\r\n"
+    with pytest.raises(http_exceptions.BadHttpMessage, match=r"b'bad\\ntrailer'"):
+        parser.feed_data(text)
+
+
 def test_parse_no_length_or_te_on_post(
     loop: asyncio.AbstractEventLoop,
     protocol: BaseProtocol,
@@ -1547,19 +1534,10 @@ async def test_parse_chunked_payload_split_chunks(response: HttpResponseParser)
     assert await reader.read() == b"firstsecond"
 
 
-@pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.")
-async def test_parse_chunked_payload_with_lf_in_extensions_c_parser(
-    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
+async def test_parse_chunked_payload_with_lf_in_extensions(
+    parser: HttpRequestParser,
 ) -> None:
-    """Test the C-parser with a chunked payload that has a LF in the chunk extensions."""
-    # The C parser will raise a BadHttpMessage from feed_data
-    parser = HttpRequestParserC(
-        protocol,
-        loop,
-        2**16,
-        max_line_size=8190,
-        max_field_size=8190,
-    )
+    """Test chunked payload that has a LF in the chunk extensions."""
     payload = (
         b"GET / HTTP/1.1\r\nHost: localhost:5001\r\n"
         b"Transfer-Encoding: chunked\r\n\r\n2;\nxx\r\n4c\r\n0\r\n\r\n"
@@ -1570,31 +1548,6 @@ async def test_parse_chunked_payload_with_lf_in_extensions_py_parser(
         parser.feed_data(payload)
 
 
-async def test_parse_chunked_payload_with_lf_in_extensions_py_parser(
-    loop: asyncio.AbstractEventLoop, protocol: BaseProtocol
-) -> None:
-    """Test the py-parser with a chunked payload that has a LF in the chunk extensions."""
-    # The py parser will not raise the BadHttpMessage directly, but instead
-    # it will set the exception on the StreamReader.
-    parser = HttpRequestParserPy(
-        protocol,
-        loop,
-        2**16,
-        max_line_size=8190,
-        max_field_size=8190,
-    )
-    payload = (
-        b"GET / HTTP/1.1\r\nHost: localhost:5001\r\n"
-        b"Transfer-Encoding: chunked\r\n\r\n2;\nxx\r\n4c\r\n0\r\n\r\n"
-        b"GET /admin HTTP/1.1\r\nHost: localhost:5001\r\n"
-        b"Transfer-Encoding: chunked\r\n\r\n0\r\n\r\n"
-    )
-    messages, _, _ = parser.feed_data(payload)
-    reader = messages[0][1]
-    assert isinstance(reader.exception(), http_exceptions.BadHttpMessage)
-    assert "\\nxx" in str(reader.exception())
-
-
 def test_partial_url(parser: HttpRequestParser) -> None:
     messages, upgrade, tail = parser.feed_data(b"GET /te")
     assert len(messages) == 0
@@ -1684,7 +1637,7 @@ def test_parse_bad_method_for_c_parser_raises(
 class TestParsePayload:
     async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out)
+        p = HttpPayloadParser(out, headers_parser=HeadersParser())
         p.feed_data(b"data")
         p.feed_eof()
 
@@ -1694,7 +1647,7 @@ async def test_parse_eof_payload(self, protocol: BaseProtocol) -> None:
     async def test_parse_length_payload_eof(self, protocol: BaseProtocol) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=4)
+        p = HttpPayloadParser(out, length=4, headers_parser=HeadersParser())
 
         p.feed_data(b"da")
         with pytest.raises(http_exceptions.ContentLengthError):
@@ -1704,7 +1657,7 @@ async def test_parse_chunked_payload_size_error(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
         with pytest.raises(http_exceptions.TransferEncodingError):
             p.feed_data(b"blah\r\n")
         assert isinstance(out.exception(), http_exceptions.TransferEncodingError)
 
@@ -1713,7 +1666,7 @@ async def test_parse_chunked_payload_split_end(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
 
         p.feed_data(b"4\r\nasdf\r\n0\r\n")
         p.feed_data(b"\r\n")
@@ -1724,7 +1677,7 @@ async def test_parse_chunked_payload_split_end2(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
 
         p.feed_data(b"4\r\nasdf\r\n0\r\n\r")
         p.feed_data(b"\n")
@@ -1735,7 +1688,7 @@ async def test_parse_chunked_payload_split_end_trailers(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
 
         p.feed_data(b"4\r\nasdf\r\n0\r\n")
         p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n")
         p.feed_data(b"\r\n")
@@ -1747,7 +1700,7 @@ async def test_parse_chunked_payload_split_end_trailers2(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
 
         p.feed_data(b"4\r\nasdf\r\n0\r\n")
         p.feed_data(b"Content-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r")
         p.feed_data(b"\n")
@@ -1759,7 +1712,7 @@ async def test_parse_chunked_payload_split_end_trailers3(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
 
         p.feed_data(b"4\r\nasdf\r\n0\r\nContent-MD5: ")
         p.feed_data(b"912ec803b2ce49e4a541068d495ab570\r\n\r\n")
@@ -1770,7 +1723,7 @@ async def test_parse_chunked_payload_split_end_trailers4(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, chunked=True)
+        p = HttpPayloadParser(out, chunked=True, headers_parser=HeadersParser())
 
         p.feed_data(b"4\r\nasdf\r\n0\r\nC")
         p.feed_data(b"ontent-MD5: 912ec803b2ce49e4a541068d495ab570\r\n\r\n")
@@ -1779,7 +1732,7 @@ async def test_parse_chunked_payload_split_end_trailers4(
     async def test_http_payload_parser_length(self, protocol: BaseProtocol) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=2)
+        p = HttpPayloadParser(out, length=2, headers_parser=HeadersParser())
         eof, tail = p.feed_data(b"1245")
         assert eof
 
@@ -1792,7 +1745,9 @@ async def test_http_payload_parser_deflate(self, protocol: BaseProtocol) -> None
         length = len(COMPRESSED)
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=length, compression="deflate")
+        p = HttpPayloadParser(
+            out, length=length, compression="deflate", headers_parser=HeadersParser()
+        )
         p.feed_data(COMPRESSED)
         assert b"data" == out._buffer[0]
         assert out.is_eof()
@@ -1806,7 +1761,9 @@ async def test_http_payload_parser_deflate_no_hdrs(
         self, protocol: BaseProtocol
     ) -> None:
         length = len(COMPRESSED)
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=length, compression="deflate")
+        p = HttpPayloadParser(
+            out, length=length, compression="deflate", headers_parser=HeadersParser()
+        )
         p.feed_data(COMPRESSED)
         assert b"data" == out._buffer[0]
         assert out.is_eof()
@@ -1819,7 +1776,9 @@ async def test_http_payload_parser_deflate_light(
         self, protocol: BaseProtocol
     ) -> None:
         length = len(COMPRESSED)
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=length, compression="deflate")
+        p = HttpPayloadParser(
+            out, length=length, compression="deflate", headers_parser=HeadersParser()
+        )
         p.feed_data(COMPRESSED)
         assert b"data" == out._buffer[0]
@@ -1829,7 +1788,9 @@ async def test_http_payload_parser_deflate_split(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, compression="deflate")
+        p = HttpPayloadParser(
+            out, compression="deflate", headers_parser=HeadersParser()
+        )
         # Feeding one correct byte should be enough to choose exact
         # deflate decompressor
         p.feed_data(b"x")
@@ -1841,7 +1802,9 @@ async def test_http_payload_parser_deflate_split_err(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, compression="deflate")
+        p = HttpPayloadParser(
+            out, compression="deflate", headers_parser=HeadersParser()
+        )
         # Feeding one wrong byte should be enough to choose exact
         # deflate decompressor
         p.feed_data(b"K")
@@ -1853,7 +1816,7 @@ async def test_http_payload_parser_length_zero(
         self, protocol: BaseProtocol
     ) -> None:
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=0)
+        p = HttpPayloadParser(out, length=0, headers_parser=HeadersParser())
         assert p.done
         assert out.is_eof()
 
@@ -1861,7 +1824,12 @@ async def test_http_payload_parser_length_zero(
     async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None:
         compressed = brotli.compress(b"brotli data")
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=len(compressed), compression="br")
+        p = HttpPayloadParser(
+            out,
+            length=len(compressed),
+            compression="br",
+            headers_parser=HeadersParser(),
+        )
         p.feed_data(compressed)
         assert b"brotli data" == out._buffer[0]
         assert out.is_eof()
@@ -1870,7 +1838,12 @@ async def test_http_payload_brotli(self, protocol: BaseProtocol) -> None:
     async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None:
         compressed = zstandard.compress(b"zstd data")
         out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
-        p = HttpPayloadParser(out, length=len(compressed), compression="zstd")
+        p = HttpPayloadParser(
+            out,
+            length=len(compressed),
+            compression="zstd",
+            headers_parser=HeadersParser(),
+        )
         p.feed_data(compressed)
         assert b"zstd data" == out._buffer[0]
         assert out.is_eof()
diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py
index 0a375cf61c7..89aba898c27 100644
--- a/tests/test_proxy_functional.py
+++ b/tests/test_proxy_functional.py
@@ -4,6 +4,7 @@
 import platform
 import ssl
 import sys
+from contextlib import suppress
 from re import match as match_regex
 from typing import (
     TYPE_CHECKING,
@@ -973,3 +974,46 @@ async def test_proxy_auth() -> None:
             proxy_auth=("user", "pass"),  # type: ignore[arg-type]
         ):
             pass
+
+
+async def test_https_proxy_connect_tunnel_session_close_no_hang(
+    aiohttp_server: AiohttpServer,
+) -> None:
+    """Test that CONNECT tunnel connections are not pooled."""
+    # Regression test for issue #11273.
+
+    # Create a minimal proxy server
+    # The CONNECT method is handled at the protocol level, not by the handler
+    proxy_app = web.Application()
+    proxy_server = await aiohttp_server(proxy_app)
+    proxy_url = f"http://{proxy_server.host}:{proxy_server.port}"
+
+    # Create session and make HTTPS request through proxy
+    session = aiohttp.ClientSession()
+
+    try:
+        # This will fail during TLS upgrade because proxy doesn't establish tunnel
+        with suppress(aiohttp.ClientError):
+            async with session.get("https://example.com/test", proxy=proxy_url) as resp:
+                await resp.read()
+
+        # The critical test: Check if any connections were pooled with proxy=None
+        # This is the root cause of the hang - CONNECT tunnel connections
+        # should NOT be pooled
+        connector = session.connector
+        assert connector is not None
+
+        # Count connections with proxy=None in the pool
+        proxy_none_keys = [key for key in connector._conns if key.proxy is None]
+        proxy_none_count = len(proxy_none_keys)
+
+        # Before the fix, there would be a connection with proxy=None
+        # After the fix, CONNECT tunnel connections are not pooled
+        assert proxy_none_count == 0, (
+            f"Found {proxy_none_count} connections with proxy=None in pool. "
+            f"CONNECT tunnel connections should not be pooled - this is bug #11273"
+        )
+
+    finally:
+        # Clean close
+        await session.close()
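Reviewer note (annotation, not part of the patch): on a checkout with these changes applied, the new coverage can be exercised in isolation via pytest's -k selection, matching the test names added above:

    pytest tests/test_proxy_functional.py -k connect_tunnel_session_close_no_hang
    pytest tests/test_http_parser.py -k "trailer or bad_chunked"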