diff --git a/CHANGES/11268.feature.rst b/CHANGES/11268.feature.rst new file mode 100644 index 00000000000..e38de7c49e0 --- /dev/null +++ b/CHANGES/11268.feature.rst @@ -0,0 +1,2 @@ +Updated ``_TracingSignal`` to utilize a secondary generic variable for type hinting custom context variables +-- by :user:`Vizonex`. diff --git a/CHANGES/11271.contrib.rst b/CHANGES/11271.contrib.rst new file mode 100644 index 00000000000..6db394f1496 --- /dev/null +++ b/CHANGES/11271.contrib.rst @@ -0,0 +1 @@ +Updated a regex in `test_aiohttp_request_coroutine` for Python 3.14. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 32c0747adf6..eae7c357602 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -59,6 +59,7 @@ Arthur Darcet Austin Scola Bai Haoran Ben Bader +Ben Beasley Ben Greiner Ben Kallus Ben Timby diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index 8325ed1812c..b2cf433b6e8 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -11,9 +11,6 @@ if TYPE_CHECKING: from .client import ClientSession - _ParamT_contra = TypeVar("_ParamT_contra", contravariant=True) - _TracingSignal = Signal[ClientSession, SimpleNamespace, _ParamT_contra] - __all__ = ( "TraceConfig", @@ -36,6 +33,8 @@ ) _T = TypeVar("_T", covariant=True) +_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True) +_TracingSignal = Signal["ClientSession", _T, _ParamT_contra] class _Factory(Protocol[_T]): @@ -52,46 +51,52 @@ def __init__(self, trace_config_ctx_factory: _Factory[_T]) -> None: ... def __init__( self, trace_config_ctx_factory: _Factory[Any] = SimpleNamespace ) -> None: - self._on_request_start: _TracingSignal[TraceRequestStartParams] = Signal(self) - self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = ( + self._on_request_start: _TracingSignal[_T, TraceRequestStartParams] = Signal( + self + ) + self._on_request_chunk_sent: _TracingSignal[_T, TraceRequestChunkSentParams] = ( Signal(self) ) self._on_response_chunk_received: _TracingSignal[ - TraceResponseChunkReceivedParams + _T, TraceResponseChunkReceivedParams ] = Signal(self) - self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self) - self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = ( + self._on_request_end: _TracingSignal[_T, TraceRequestEndParams] = Signal(self) + self._on_request_exception: _TracingSignal[_T, TraceRequestExceptionParams] = ( Signal(self) ) - self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal( - self + self._on_request_redirect: _TracingSignal[_T, TraceRequestRedirectParams] = ( + Signal(self) ) self._on_connection_queued_start: _TracingSignal[ - TraceConnectionQueuedStartParams + _T, TraceConnectionQueuedStartParams ] = Signal(self) self._on_connection_queued_end: _TracingSignal[ - TraceConnectionQueuedEndParams + _T, TraceConnectionQueuedEndParams ] = Signal(self) self._on_connection_create_start: _TracingSignal[ - TraceConnectionCreateStartParams + _T, TraceConnectionCreateStartParams ] = Signal(self) self._on_connection_create_end: _TracingSignal[ - TraceConnectionCreateEndParams + _T, TraceConnectionCreateEndParams ] = Signal(self) self._on_connection_reuseconn: _TracingSignal[ - TraceConnectionReuseconnParams + _T, TraceConnectionReuseconnParams ] = Signal(self) self._on_dns_resolvehost_start: _TracingSignal[ - TraceDnsResolveHostStartParams + _T, TraceDnsResolveHostStartParams ] = Signal(self) - self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = ( - Signal(self) + self._on_dns_resolvehost_end: _TracingSignal[ + _T, TraceDnsResolveHostEndParams + ] = Signal(self) + self._on_dns_cache_hit: _TracingSignal[_T, TraceDnsCacheHitParams] = Signal( + self ) - self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = Signal(self) - self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = Signal(self) - self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = ( - Signal(self) + self._on_dns_cache_miss: _TracingSignal[_T, TraceDnsCacheMissParams] = Signal( + self ) + self._on_request_headers_sent: _TracingSignal[ + _T, TraceRequestHeadersSentParams + ] = Signal(self) self._trace_config_ctx_factory: _Factory[_T] = trace_config_ctx_factory @@ -118,89 +123,91 @@ def freeze(self) -> None: self._on_request_headers_sent.freeze() @property - def on_request_start(self) -> "_TracingSignal[TraceRequestStartParams]": + def on_request_start(self) -> "_TracingSignal[_T, TraceRequestStartParams]": return self._on_request_start @property - def on_request_chunk_sent(self) -> "_TracingSignal[TraceRequestChunkSentParams]": + def on_request_chunk_sent( + self, + ) -> "_TracingSignal[_T, TraceRequestChunkSentParams]": return self._on_request_chunk_sent @property def on_response_chunk_received( self, - ) -> "_TracingSignal[TraceResponseChunkReceivedParams]": + ) -> "_TracingSignal[_T, TraceResponseChunkReceivedParams]": return self._on_response_chunk_received @property - def on_request_end(self) -> "_TracingSignal[TraceRequestEndParams]": + def on_request_end(self) -> "_TracingSignal[_T, TraceRequestEndParams]": return self._on_request_end @property def on_request_exception( self, - ) -> "_TracingSignal[TraceRequestExceptionParams]": + ) -> "_TracingSignal[_T, TraceRequestExceptionParams]": return self._on_request_exception @property def on_request_redirect( self, - ) -> "_TracingSignal[TraceRequestRedirectParams]": + ) -> "_TracingSignal[_T, TraceRequestRedirectParams]": return self._on_request_redirect @property def on_connection_queued_start( self, - ) -> "_TracingSignal[TraceConnectionQueuedStartParams]": + ) -> "_TracingSignal[_T, TraceConnectionQueuedStartParams]": return self._on_connection_queued_start @property def on_connection_queued_end( self, - ) -> "_TracingSignal[TraceConnectionQueuedEndParams]": + ) -> "_TracingSignal[_T, TraceConnectionQueuedEndParams]": return self._on_connection_queued_end @property def on_connection_create_start( self, - ) -> "_TracingSignal[TraceConnectionCreateStartParams]": + ) -> "_TracingSignal[_T, TraceConnectionCreateStartParams]": return self._on_connection_create_start @property def on_connection_create_end( self, - ) -> "_TracingSignal[TraceConnectionCreateEndParams]": + ) -> "_TracingSignal[_T, TraceConnectionCreateEndParams]": return self._on_connection_create_end @property def on_connection_reuseconn( self, - ) -> "_TracingSignal[TraceConnectionReuseconnParams]": + ) -> "_TracingSignal[_T, TraceConnectionReuseconnParams]": return self._on_connection_reuseconn @property def on_dns_resolvehost_start( self, - ) -> "_TracingSignal[TraceDnsResolveHostStartParams]": + ) -> "_TracingSignal[_T, TraceDnsResolveHostStartParams]": return self._on_dns_resolvehost_start @property def on_dns_resolvehost_end( self, - ) -> "_TracingSignal[TraceDnsResolveHostEndParams]": + ) -> "_TracingSignal[_T, TraceDnsResolveHostEndParams]": return self._on_dns_resolvehost_end @property - def on_dns_cache_hit(self) -> "_TracingSignal[TraceDnsCacheHitParams]": + def on_dns_cache_hit(self) -> "_TracingSignal[_T, TraceDnsCacheHitParams]": return self._on_dns_cache_hit @property - def on_dns_cache_miss(self) -> "_TracingSignal[TraceDnsCacheMissParams]": + def on_dns_cache_miss(self) -> "_TracingSignal[_T, TraceDnsCacheMissParams]": return self._on_dns_cache_miss @property def on_request_headers_sent( self, - ) -> "_TracingSignal[TraceRequestHeadersSentParams]": + ) -> "_TracingSignal[_T, TraceRequestHeadersSentParams]": return self._on_request_headers_sent diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index d403aab107e..d22a749dda8 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -3670,8 +3670,12 @@ async def handler(request: web.Request) -> web.Response: not_an_awaitable = aiohttp.request("GET", server.make_url("/")) with pytest.raises( TypeError, - match="^object _SessionRequestContextManager " - "can't be used in 'await' expression$", + match=( + "^'_SessionRequestContextManager' object can't be awaited$" + if sys.version_info >= (3, 14) + else "^object _SessionRequestContextManager " + "can't be used in 'await' expression$" + ), ): await not_an_awaitable # type: ignore[misc] diff --git a/tests/test_tracing.py b/tests/test_tracing.py index d989dacf57f..e50d908b2ae 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -1,10 +1,13 @@ +import sys from types import SimpleNamespace from typing import Any, Tuple from unittest import mock from unittest.mock import Mock import pytest +from aiosignal import Signal +from aiohttp import ClientSession from aiohttp.tracing import ( Trace, TraceConfig, @@ -25,15 +28,28 @@ TraceResponseChunkReceivedParams, ) +if sys.version_info >= (3, 11): + from typing import assert_type + class TestTraceConfig: def test_trace_config_ctx_default(self) -> None: trace_config = TraceConfig() assert isinstance(trace_config.trace_config_ctx(), SimpleNamespace) + if sys.version_info >= (3, 11): + assert_type( + trace_config.on_request_chunk_sent, + Signal[ClientSession, SimpleNamespace, TraceRequestChunkSentParams], + ) def test_trace_config_ctx_factory(self) -> None: trace_config = TraceConfig(trace_config_ctx_factory=dict) assert isinstance(trace_config.trace_config_ctx(), dict) + if sys.version_info >= (3, 11): + assert_type( + trace_config.on_request_start, + Signal[ClientSession, dict[str, Any], TraceRequestStartParams], + ) def test_trace_config_ctx_request_ctx(self) -> None: trace_request_ctx = Mock()