From 996ad00037d9e729aa58b757099dd2fea5e7dcc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20Bo=C4=8Dek?= Date: Fri, 6 Jun 2025 13:36:45 +0200 Subject: [PATCH] fix: leak of aiodns.DNSResolver when ClientSession is closed (#11150) Co-authored-by: J. Nick Koston --- CHANGES/11150.bugfix.rst | 3 +++ CONTRIBUTORS.txt | 1 + aiohttp/connector.py | 18 +++++++++++++++--- aiohttp/resolver.py | 5 +++-- tests/test_connector.py | 19 +++++++++++++++++++ 5 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 CHANGES/11150.bugfix.rst diff --git a/CHANGES/11150.bugfix.rst b/CHANGES/11150.bugfix.rst new file mode 100644 index 00000000000..8a51b2e4f0c --- /dev/null +++ b/CHANGES/11150.bugfix.rst @@ -0,0 +1,3 @@ +Fixed leak of ``aiodns.DNSResolver`` when :py:class:`~aiohttp.TCPConnector` is closed and no resolver was passed when creating the connector -- by :user:`Tasssadar`. + +This was a regression introduced in version 3.12.0 (:pr:`10897`). diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 42328be3848..aa1f744b7a3 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -378,6 +378,7 @@ Vladimir Shulyak Vladimir Zakharov Vladyslav Bohaichuk Vladyslav Bondar +Vojtěch Boček W. Trevor King Wei Lin Weiwei Wang diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 27926d3de88..4ee0d570127 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -883,9 +883,14 @@ def __init__( "got {!r} instead.".format(ssl) ) self._ssl = ssl + + self._resolver: AbstractResolver if resolver is None: - resolver = DefaultResolver() - self._resolver: AbstractResolver = resolver + self._resolver = DefaultResolver() + self._resolver_owner = True + else: + self._resolver = resolver + self._resolver_owner = False self._use_dns_cache = use_dns_cache self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) @@ -912,6 +917,12 @@ def _close_immediately(self) -> List[Awaitable[object]]: return waiters + async def close(self) -> None: + """Close all opened transports.""" + if self._resolver_owner: + await self._resolver.close() + await super().close() + @property def family(self) -> int: """Socket family like AF_INET.""" @@ -1567,7 +1578,8 @@ def __init__( limit_per_host=limit_per_host, ) if not isinstance( - self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined] + self._loop, + asyncio.ProactorEventLoop, # type: ignore[attr-defined] ): raise RuntimeError( "Named Pipes only available in proactor loop under windows" diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index 9cdcdb0864b..b07bd9716f2 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -220,9 +220,10 @@ def release_resolver( loop: The event loop the resolver was using. """ # Remove client from its loop's tracking - if loop not in self._loop_data: + current_loop_data = self._loop_data.get(loop) + if current_loop_data is None: return - resolver, client_set = self._loop_data[loop] + resolver, client_set = current_loop_data client_set.discard(client) # If no more clients for this loop, cancel and remove its resolver if not client_set: diff --git a/tests/test_connector.py b/tests/test_connector.py index 4ae73ce9b9c..10ba4227a33 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1301,6 +1301,7 @@ async def test_tcp_connector_dns_cache_not_expired( with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() + m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080) await conn._resolve_host("localhost", 8080) m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) @@ -1314,6 +1315,7 @@ async def test_tcp_connector_dns_cache_forever( with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() + m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080) await conn._resolve_host("localhost", 8080) m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) @@ -1327,6 +1329,7 @@ async def test_tcp_connector_use_dns_cache_disabled( with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=False) m_resolver().resolve.side_effect = [dns_response(), dns_response()] + m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080) await conn._resolve_host("localhost", 8080) m_resolver().resolve.assert_has_calls( @@ -1345,6 +1348,7 @@ async def test_tcp_connector_dns_throttle_requests( with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() + m_resolver().close = mock.AsyncMock() t = loop.create_task(conn._resolve_host("localhost", 8080)) t2 = loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) @@ -1365,6 +1369,7 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread( conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) e = Exception() m_resolver().resolve.side_effect = e + m_resolver().close = mock.AsyncMock() r1 = loop.create_task(conn._resolve_host("localhost", 8080)) r2 = loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) @@ -1383,6 +1388,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() + m_resolver().close = mock.AsyncMock() t = loop.create_task(conn._resolve_host("localhost", 8080)) f = loop.create_task(conn._resolve_host("localhost", 8080)) @@ -1429,6 +1435,7 @@ def exception_handler(loop: asyncio.AbstractEventLoop, context: object) -> None: use_dns_cache=False, ) m_resolver().resolve.return_value = dns_response_error() + m_resolver().close = mock.AsyncMock() f = loop.create_task(conn._create_direct_connection(req, [], ClientTimeout(0))) await asyncio.sleep(0) @@ -1466,6 +1473,7 @@ async def test_tcp_connector_dns_tracing( conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() + m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080, traces=traces) on_dns_resolvehost_start.assert_called_once_with( @@ -1509,6 +1517,7 @@ async def test_tcp_connector_dns_tracing_cache_disabled( conn = aiohttp.TCPConnector(use_dns_cache=False) m_resolver().resolve.side_effect = [dns_response(), dns_response()] + m_resolver().close = mock.AsyncMock() await conn._resolve_host("localhost", 8080, traces=traces) @@ -1565,6 +1574,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests( with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() + m_resolver().close = mock.AsyncMock() t = loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) t1 = loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) await asyncio.sleep(0) @@ -1583,6 +1593,14 @@ async def test_tcp_connector_dns_tracing_throttle_requests( await conn.close() +async def test_tcp_connector_close_resolver() -> None: + m_resolver = mock.AsyncMock() + with mock.patch("aiohttp.connector.DefaultResolver", return_value=m_resolver): + conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) + await conn.close() + m_resolver.close.assert_awaited_once() + + async def test_dns_error(loop: asyncio.AbstractEventLoop) -> None: connector = aiohttp.TCPConnector() with mock.patch.object( @@ -3834,6 +3852,7 @@ async def resolve_response() -> List[ResolveResult]: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: m_resolver().resolve.return_value = resolve_response() + m_resolver().close = mock.AsyncMock() connector = TCPConnector() traces = [DummyTracer()]