diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 2b8a38e02c0..698a34302c7 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -138,7 +138,7 @@ jobs: needs: gen_llhttp strategy: matrix: - pyver: [3.9, '3.10', '3.11', '3.12', '3.13', '3.14'] + pyver: ['3.10', '3.11', '3.12', '3.13', '3.14'] no-extensions: ['', 'Y'] os: [ubuntu, macos, windows] experimental: [false] @@ -148,7 +148,7 @@ jobs: - os: windows no-extensions: 'Y' include: - - pyver: pypy-3.9 + - pyver: pypy-3.10 no-extensions: 'Y' os: ubuntu experimental: false diff --git a/.github/workflows/update-pre-commit.yml b/.github/workflows/update-pre-commit.yml index 4616f7abc7d..80921b56615 100644 --- a/.github/workflows/update-pre-commit.yml +++ b/.github/workflows/update-pre-commit.yml @@ -12,7 +12,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: 3.9 + python-version: '3.10' - name: Install dependencies run: >- pip install -r requirements/lint.in -c requirements/lint.txt diff --git a/CHANGES.rst b/CHANGES.rst index 673985ab3fc..945ca25d569 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4546,7 +4546,7 @@ Bugfixes `#5853 `_ - Added ``params`` keyword argument to ``ClientSession.ws_connect``. -- :user:`hoh`. `#5868 `_ -- Uses :py:class:`~asyncio.ThreadedChildWatcher` under POSIX to allow setting up test loop in non-main thread. +- Uses ``asyncio.ThreadedChildWatcher`` under POSIX to allow setting up test loop in non-main thread. `#5877 `_ - Fix the error in handling the return value of `getaddrinfo`. `getaddrinfo` will return an `(int, bytes)` tuple, if CPython could not handle the address family. diff --git a/CHANGES/11601.breaking.rst b/CHANGES/11601.breaking.rst new file mode 100644 index 00000000000..c2eccbd9e1c --- /dev/null +++ b/CHANGES/11601.breaking.rst @@ -0,0 +1 @@ +Dropped support for Python 3.9 -- by :user:`Dreamsorcerer`. 
diff --git a/Makefile b/Makefile index 099d1cf0af1..6531d435f22 100644 --- a/Makefile +++ b/Makefile @@ -128,16 +128,6 @@ define run_tests_in_docker docker run --rm -ti -v `pwd`:/src -w /src "aiohttp-test-$(1)-$(2)" $(TEST_SPEC) endef -.PHONY: test-3.9-no-extensions test -test-3.9-no-extensions: - $(call run_tests_in_docker,3.9,y) -test-3.9: - $(call run_tests_in_docker,3.9,n) -test-3.10-no-extensions: - $(call run_tests_in_docker,3.10,y) -test-3.10: - $(call run_tests_in_docker,3.10,n) - .PHONY: clean clean: @rm -rf `find . -name __pycache__` diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 3f8a1cc62dc..e01204a5adb 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -1,6 +1,6 @@ __version__ = "4.0.0a2.dev0" -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING from . import hdrs from .client import ( @@ -113,7 +113,7 @@ # At runtime these are lazy-loaded at the bottom of the file. from .worker import GunicornUVLoopWebWorker, GunicornWebWorker -__all__: Tuple[str, ...] = ( +__all__: tuple[str, ...] = ( "hdrs", # client "AddrInfoType", @@ -237,7 +237,7 @@ ) -def __dir__() -> Tuple[str, ...]: +def __dir__() -> tuple[str, ...]: return __all__ + ("__doc__",) diff --git a/aiohttp/_cookie_helpers.py b/aiohttp/_cookie_helpers.py index 9e80b6065d7..7fe8f43d12b 100644 --- a/aiohttp/_cookie_helpers.py +++ b/aiohttp/_cookie_helpers.py @@ -6,8 +6,9 @@ """ import re +from collections.abc import Sequence from http.cookies import Morsel -from typing import List, Optional, Sequence, Tuple, cast +from typing import cast from .log import internal_logger @@ -156,7 +157,7 @@ def _unquote(value: str) -> str: return _unquote_sub(_unquote_replace, value) -def parse_cookie_header(header: str) -> List[Tuple[str, Morsel[str]]]: +def parse_cookie_header(header: str) -> list[tuple[str, Morsel[str]]]: """ Parse a Cookie header according to RFC 6265 Section 5.4. 
@@ -176,7 +177,7 @@ def parse_cookie_header(header: str) -> List[Tuple[str, Morsel[str]]]: if not header: return [] - cookies: List[Tuple[str, Morsel[str]]] = [] + cookies: list[tuple[str, Morsel[str]]] = [] i = 0 n = len(header) @@ -211,7 +212,7 @@ def parse_cookie_header(header: str) -> List[Tuple[str, Morsel[str]]]: return cookies -def parse_set_cookie_headers(headers: Sequence[str]) -> List[Tuple[str, Morsel[str]]]: +def parse_set_cookie_headers(headers: Sequence[str]) -> list[tuple[str, Morsel[str]]]: """ Parse cookie headers using a vendored version of SimpleCookie parsing. @@ -230,7 +231,7 @@ def parse_set_cookie_headers(headers: Sequence[str]) -> List[Tuple[str, Morsel[s This implementation handles unmatched quotes more gracefully to prevent cookie loss. See https://github.com/aio-libs/aiohttp/issues/7993 """ - parsed_cookies: List[Tuple[str, Morsel[str]]] = [] + parsed_cookies: list[tuple[str, Morsel[str]]] = [] for header in headers: if not header: @@ -239,7 +240,7 @@ def parse_set_cookie_headers(headers: Sequence[str]) -> List[Tuple[str, Morsel[s # Parse cookie string using SimpleCookie's algorithm i = 0 n = len(header) - current_morsel: Optional[Morsel[str]] = None + current_morsel: Morsel[str] | None = None morsel_seen = False while 0 <= i < n: diff --git a/aiohttp/_websocket/helpers.py b/aiohttp/_websocket/helpers.py index 06ced46cdab..f9a44cdd39b 100644 --- a/aiohttp/_websocket/helpers.py +++ b/aiohttp/_websocket/helpers.py @@ -2,8 +2,9 @@ import functools import re +from re import Pattern from struct import Struct -from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple +from typing import TYPE_CHECKING, Final from ..helpers import NO_EXTENSIONS from .models import WSHandshakeError @@ -23,7 +24,7 @@ # Used by _websocket_mask_python @functools.lru_cache -def _xor_table() -> List[bytes]: +def _xor_table() -> list[bytes]: return [bytes(a ^ b for a in range(256)) for b in range(256)] @@ -74,7 +75,7 @@ def _websocket_mask_python(mask: 
bytes, data: bytearray) -> None: _WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") -def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]: +def ws_ext_parse(extstr: str | None, isserver: bool = False) -> tuple[int, bool]: if not extstr: return 0, False diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index ee771d30df3..085fb460cb5 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -1,8 +1,9 @@ """Models for WebSocket protocol versions 13 and 8.""" import json +from collections.abc import Callable from enum import IntEnum -from typing import Any, Callable, Final, Literal, NamedTuple, Optional, Union, cast +from typing import Any, Final, Literal, NamedTuple, Union, cast WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF]) @@ -41,18 +42,18 @@ class WSMsgType(IntEnum): class WSMessageContinuation(NamedTuple): data: bytes size: int - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.CONTINUATION] = WSMsgType.CONTINUATION class WSMessageText(NamedTuple): data: str size: int - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT def json( - self, *, loads: Callable[[Union[str, bytes, bytearray]], Any] = json.loads + self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads ) -> Any: """Return parsed JSON data.""" return loads(self.data) @@ -61,11 +62,11 @@ def json( class WSMessageBinary(NamedTuple): data: bytes size: int - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.BINARY] = WSMsgType.BINARY def json( - self, *, loads: Callable[[Union[str, bytes, bytearray]], Any] = json.loads + self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads ) -> Any: """Return parsed JSON data.""" return loads(self.data) @@ -74,42 +75,42 @@ def json( class WSMessagePing(NamedTuple): data: bytes size: int - extra: Optional[str] = 
None + extra: str | None = None type: Literal[WSMsgType.PING] = WSMsgType.PING class WSMessagePong(NamedTuple): data: bytes size: int - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.PONG] = WSMsgType.PONG class WSMessageClose(NamedTuple): data: int size: int - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.CLOSE] = WSMsgType.CLOSE class WSMessageClosing(NamedTuple): data: None = None size: int = 0 - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.CLOSING] = WSMsgType.CLOSING class WSMessageClosed(NamedTuple): data: None = None size: int = 0 - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.CLOSED] = WSMsgType.CLOSED class WSMessageError(NamedTuple): data: BaseException size: int = 0 - extra: Optional[str] = None + extra: str | None = None type: Literal[WSMsgType.ERROR] = WSMsgType.ERROR diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index f022aa4d220..5bcc2ecfb78 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -3,7 +3,7 @@ import asyncio import builtins from collections import deque -from typing import Deque, Final, Optional, Set, Tuple, Type, Union +from typing import Final from ..base_protocol import BaseProtocol from ..compression_utils import ZLibDecompressor @@ -23,7 +23,7 @@ WSMsgType, ) -ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} +ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode} # States for the reader, used to parse the WebSocket frame # integer values are used so they can be cythonized @@ -70,21 +70,21 @@ def __init__( self._limit = limit * 2 self._loop = loop self._eof = False - self._waiter: Optional[asyncio.Future[None]] = None - self._exception: Union[Type[BaseException], BaseException, None] = None - self._buffer: Deque[WSMessage] = deque() + self._waiter: asyncio.Future[None] | None = None + self._exception: 
type[BaseException] | BaseException | None = None + self._buffer: deque[WSMessage] = deque() self._get_buffer = self._buffer.popleft self._put_buffer = self._buffer.append def is_eof(self) -> bool: return self._eof - def exception(self) -> Optional[Union[Type[BaseException], BaseException]]: + def exception(self) -> type[BaseException] | BaseException | None: return self._exception def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: builtins.BaseException = _EXC_SENTINEL, ) -> None: self._eof = True @@ -144,7 +144,7 @@ def __init__( self.queue = queue self._max_msg_size = max_msg_size - self._exc: Optional[Exception] = None + self._exc: Exception | None = None self._partial = bytearray() self._state = READ_HEADER @@ -156,11 +156,11 @@ def __init__( self._tail: bytes = b"" self._has_mask = False - self._frame_mask: Optional[bytes] = None + self._frame_mask: bytes | None = None self._payload_bytes_to_read = 0 self._payload_len_flag = 0 self._compressed: int = COMPRESSED_NOT_SET - self._decompressobj: Optional[ZLibDecompressor] = None + self._decompressobj: ZLibDecompressor | None = None self._compress = compress def feed_eof(self) -> None: @@ -169,9 +169,7 @@ def feed_eof(self) -> None: # data can be bytearray on Windows because proactor event loop uses bytearray # and asyncio types this to Union[bytes, bytearray, memoryview] so we need # coerce data to bytes if it is not - def feed_data( - self, data: Union[bytes, bytearray, memoryview] - ) -> Tuple[bool, bytes]: + def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]: if type(data) is not bytes: data = bytes(data) @@ -190,9 +188,9 @@ def feed_data( def _handle_frame( self, fin: bool, - opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int - payload: Union[bytes, bytearray], - compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int + opcode: int | cython_int, # Union intended: 
Cython pxd uses C int + payload: bytes | bytearray, + compressed: int | cython_int, # Union intended: Cython pxd uses C int ) -> None: msg: WSMessage if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}: @@ -228,7 +226,7 @@ def _handle_frame( f"to be zero, got {opcode!r}", ) - assembled_payload: Union[bytes, bytearray] + assembled_payload: bytes | bytearray if has_partial: assembled_payload = self._partial + payload self._partial.clear() @@ -452,7 +450,7 @@ def _feed_data(self, data: bytes) -> None: self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) break - payload: Union[bytes, bytearray] + payload: bytes | bytearray if had_fragments: # We have to join the payload fragments get the payload self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos]) diff --git a/aiohttp/_websocket/writer.py b/aiohttp/_websocket/writer.py index 8ba98280304..fdbcda45c3c 100644 --- a/aiohttp/_websocket/writer.py +++ b/aiohttp/_websocket/writer.py @@ -3,7 +3,7 @@ import asyncio import random from functools import partial -from typing import Any, Final, Optional, Union +from typing import Any, Final from ..base_protocol import BaseProtocol from ..client_exceptions import ClientConnectionResetError @@ -65,7 +65,7 @@ def __init__( self._compressobj: Any = None # actually compressobj async def send_frame( - self, message: bytes, opcode: int, compress: Optional[int] = None + self, message: bytes, opcode: int, compress: int | None = None ) -> None: """Send a frame over the websocket with message as its payload.""" if self._closing and not (opcode & WSMsgType.CLOSE): @@ -166,7 +166,7 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor: max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, ) - async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: + async def close(self, code: int = 1000, message: bytes | str = b"") -> None: """Close the websocket, sending the specified code and message.""" if isinstance(message, str): 
message = message.encode("utf-8") diff --git a/aiohttp/abc.py b/aiohttp/abc.py index f29828285b0..a21e6b5d0bd 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -1,23 +1,9 @@ import logging import socket from abc import ABC, abstractmethod -from collections.abc import Sized +from collections.abc import Awaitable, Callable, Generator, Iterable, Sequence, Sized from http.cookies import BaseCookie, Morsel -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Tuple, - TypedDict, - Union, -) +from typing import TYPE_CHECKING, Any, TypedDict from multidict import CIMultiDict from yarl import URL @@ -31,8 +17,8 @@ from .web_request import BaseRequest, Request from .web_response import StreamResponse else: - BaseRequest = Request = Application = StreamResponse = None - HTTPException = None + BaseRequest = Request = Application = StreamResponse = Any + HTTPException = Any class AbstractRouter(ABC): @@ -73,21 +59,21 @@ def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: @abstractmethod def expect_handler( self, - ) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]: + ) -> Callable[[Request], Awaitable[StreamResponse | None]]: """Expect handler for 100-continue processing""" @property # pragma: no branch @abstractmethod - def http_exception(self) -> Optional[HTTPException]: + def http_exception(self) -> HTTPException | None: """HTTPException instance raised on router's resolving, or None""" @abstractmethod # pragma: no branch - def get_info(self) -> Dict[str, Any]: + def get_info(self) -> dict[str, Any]: """Return a dict with additional info useful for introspection""" @property # pragma: no branch @abstractmethod - def apps(self) -> Tuple[Application, ...]: + def apps(self) -> tuple[Application, ...]: """Stack of nested applications. Top level application is left-most element. 
@@ -153,7 +139,7 @@ class AbstractResolver(ABC): @abstractmethod async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: """Return IP address for given hostname""" @abstractmethod @@ -179,7 +165,7 @@ def quote_cookie(self) -> bool: """Return True if cookies should be quoted.""" @abstractmethod - def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: + def clear(self, predicate: ClearCookiePredicate | None = None) -> None: """Clear all cookies if no predicate is passed.""" @abstractmethod @@ -198,7 +184,7 @@ def update_cookies_from_headers( self.update_cookies(cookies_to_update, response_url) @abstractmethod - def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": + def filter_cookies(self, request_url: URL) -> BaseCookie[str]: """Return the jar's cookies filtered by their attributes.""" @@ -207,11 +193,11 @@ class AbstractStreamWriter(ABC): buffer_size: int = 0 output_size: int = 0 - length: Optional[int] = 0 + length: int | None = 0 @abstractmethod async def write( - self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + self, chunk: "bytes | bytearray | memoryview[int] | memoryview[bytes]" ) -> None: """Write chunk into stream.""" @@ -225,7 +211,7 @@ async def drain(self) -> None: @abstractmethod def enable_compression( - self, encoding: str = "deflate", strategy: Optional[int] = None + self, encoding: str = "deflate", strategy: int | None = None ) -> None: """Enable HTTP body compression""" @@ -234,9 +220,7 @@ def enable_chunking(self) -> None: """Enable HTTP chunked mode""" @abstractmethod - async def write_headers( - self, status_line: str, headers: "CIMultiDict[str]" - ) -> None: + async def write_headers(self, status_line: str, headers: CIMultiDict[str]) -> None: """Write HTTP headers""" def send_headers(self) -> None: diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 
b0a67ed6ff6..7f01830f4e9 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,5 +1,5 @@ import asyncio -from typing import Optional, cast +from typing import cast from .client_exceptions import ClientConnectionResetError from .helpers import set_exception @@ -19,10 +19,10 @@ class BaseProtocol(asyncio.Protocol): def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False - self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_waiter: asyncio.Future[None] | None = None self._reading_paused = False - self.transport: Optional[asyncio.Transport] = None + self.transport: asyncio.Transport | None = None @property def connected(self) -> bool: @@ -68,7 +68,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: tcp_nodelay(tr, True) self.transport = tr - def connection_lost(self, exc: Optional[BaseException]) -> None: + def connection_lost(self, exc: BaseException | None) -> None: # Wake up the writer if currently paused. 
self.transport = None if not self._paused: diff --git a/aiohttp/client.py b/aiohttp/client.py index 6a8c667491f..a7da3ff0c57 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -9,32 +9,19 @@ import sys import traceback import warnings -from contextlib import suppress -from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, +from collections.abc import ( Awaitable, Callable, Collection, Coroutine, - Final, - FrozenSet, Generator, - Generic, Iterable, - List, Mapping, - Optional, Sequence, - Set, - Tuple, - Type, - TypedDict, - TypeVar, - Union, - final, ) +from contextlib import suppress +from types import TracebackType +from typing import TYPE_CHECKING, Any, Final, Generic, TypedDict, TypeVar, final from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr from yarl import URL @@ -173,37 +160,37 @@ class _RequestOptions(TypedDict, total=False): params: Query data: Any json: Any - cookies: Union[LooseCookies, None] - headers: Union[LooseHeaders, None] - skip_auto_headers: Union[Iterable[str], None] - auth: Union[BasicAuth, None] + cookies: LooseCookies | None + headers: LooseHeaders | None + skip_auto_headers: Iterable[str] | None + auth: BasicAuth | None allow_redirects: bool max_redirects: int - compress: Union[str, bool] - chunked: Union[bool, None] + compress: str | bool + chunked: bool | None expect100: bool - raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]] + raise_for_status: None | bool | Callable[[ClientResponse], Awaitable[None]] read_until_eof: bool - proxy: Union[StrOrURL, None] - proxy_auth: Union[BasicAuth, None] - timeout: "Union[ClientTimeout, _SENTINEL, None]" - ssl: Union[SSLContext, bool, Fingerprint] - server_hostname: Union[str, None] - proxy_headers: Union[LooseHeaders, None] - trace_request_ctx: Union[Mapping[str, Any], None] - read_bufsize: Union[int, None] - auto_decompress: Union[bool, None] - max_line_size: Union[int, None] - max_field_size: Union[int, None] - 
middlewares: Optional[Sequence[ClientMiddlewareType]] + proxy: StrOrURL | None + proxy_auth: BasicAuth | None + timeout: "ClientTimeout | _SENTINEL | None" + ssl: SSLContext | bool | Fingerprint + server_hostname: str | None + proxy_headers: LooseHeaders | None + trace_request_ctx: Mapping[str, Any] | None + read_bufsize: int | None + auto_decompress: bool | None + max_line_size: int | None + max_field_size: int | None + middlewares: Sequence[ClientMiddlewareType] | None @frozen_dataclass_decorator class ClientTimeout: - total: Optional[float] = None - connect: Optional[float] = None - sock_read: Optional[float] = None - sock_connect: Optional[float] = None + total: float | None = None + connect: float | None = None + sock_read: float | None = None + sock_connect: float | None = None ceil_threshold: float = 5 # pool_queue_timeout: Optional[float] = None @@ -268,42 +255,40 @@ class ClientSession: def __init__( self, - base_url: Optional[StrOrURL] = None, + base_url: StrOrURL | None = None, *, - connector: Optional[BaseConnector] = None, - cookies: Optional[LooseCookies] = None, - headers: Optional[LooseHeaders] = None, - proxy: Optional[StrOrURL] = None, - proxy_auth: Optional[BasicAuth] = None, - skip_auto_headers: Optional[Iterable[str]] = None, - auth: Optional[BasicAuth] = None, + connector: BaseConnector | None = None, + cookies: LooseCookies | None = None, + headers: LooseHeaders | None = None, + proxy: StrOrURL | None = None, + proxy_auth: BasicAuth | None = None, + skip_auto_headers: Iterable[str] | None = None, + auth: BasicAuth | None = None, json_serialize: JSONEncoder = json.dumps, - request_class: Type[ClientRequest] = ClientRequest, - response_class: Type[ClientResponse] = ClientResponse, - ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse, + request_class: type[ClientRequest] = ClientRequest, + response_class: type[ClientResponse] = ClientResponse, + ws_response_class: type[ClientWebSocketResponse] = ClientWebSocketResponse, 
version: HttpVersion = http.HttpVersion11, - cookie_jar: Optional[AbstractCookieJar] = None, + cookie_jar: AbstractCookieJar | None = None, connector_owner: bool = True, - raise_for_status: Union[ - bool, Callable[[ClientResponse], Awaitable[None]] - ] = False, - timeout: Union[_SENTINEL, ClientTimeout, None] = sentinel, + raise_for_status: bool | Callable[[ClientResponse], Awaitable[None]] = False, + timeout: _SENTINEL | ClientTimeout | None = sentinel, auto_decompress: bool = True, trust_env: bool = False, requote_redirect_url: bool = True, - trace_configs: Optional[List[TraceConfig[object]]] = None, + trace_configs: list[TraceConfig[object]] | None = None, read_bufsize: int = 2**16, max_line_size: int = 8190, max_field_size: int = 8190, fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8", middlewares: Sequence[ClientMiddlewareType] = (), - ssl_shutdown_timeout: Union[_SENTINEL, None, float] = sentinel, + ssl_shutdown_timeout: _SENTINEL | None | float = sentinel, ) -> None: # We initialise _connector to None immediately, as it's referenced in __del__() # and could cause issues if an exception occurs during initialisation. 
- self._connector: Optional[BaseConnector] = None + self._connector: BaseConnector | None = None if base_url is None or isinstance(base_url, URL): - self._base_url: Optional[URL] = base_url + self._base_url: URL | None = base_url self._base_url_origin = None if base_url is None else base_url.origin() else: self._base_url = URL(base_url) @@ -337,7 +322,7 @@ def __init__( self._connector = connector self._loop = loop if loop.get_debug(): - self._source_traceback: Optional[traceback.StackSummary] = ( + self._source_traceback: traceback.StackSummary | None = ( traceback.extract_stack(sys._getframe(1)) ) else: @@ -391,10 +376,9 @@ def __init__( self._retry_connection: bool = True self._middlewares = middlewares - def __init_subclass__(cls: Type["ClientSession"]) -> None: + def __init_subclass__(cls: type["ClientSession"]) -> None: raise TypeError( - "Inheritance class {} from ClientSession " - "is forbidden".format(cls.__name__) + f"Inheritance class {cls.__name__} from ClientSession " "is forbidden" ) def __del__(self, _warnings: Any = warnings) -> None: @@ -440,31 +424,31 @@ async def _request( params: Query = None, data: Any = None, json: Any = None, - cookies: Optional[LooseCookies] = None, - headers: Optional[LooseHeaders] = None, - skip_auto_headers: Optional[Iterable[str]] = None, - auth: Optional[BasicAuth] = None, + cookies: LooseCookies | None = None, + headers: LooseHeaders | None = None, + skip_auto_headers: Iterable[str] | None = None, + auth: BasicAuth | None = None, allow_redirects: bool = True, max_redirects: int = 10, - compress: Union[str, bool] = False, - chunked: Optional[bool] = None, + compress: str | bool = False, + chunked: bool | None = None, expect100: bool = False, - raise_for_status: Union[ - None, bool, Callable[[ClientResponse], Awaitable[None]] - ] = None, + raise_for_status: ( + None | bool | Callable[[ClientResponse], Awaitable[None]] + ) = None, read_until_eof: bool = True, - proxy: Optional[StrOrURL] = None, - proxy_auth: 
Optional[BasicAuth] = None, - timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel, - ssl: Union[SSLContext, bool, Fingerprint] = True, - server_hostname: Optional[str] = None, - proxy_headers: Optional[LooseHeaders] = None, - trace_request_ctx: Optional[Mapping[str, Any]] = None, - read_bufsize: Optional[int] = None, - auto_decompress: Optional[bool] = None, - max_line_size: Optional[int] = None, - max_field_size: Optional[int] = None, - middlewares: Optional[Sequence[ClientMiddlewareType]] = None, + proxy: StrOrURL | None = None, + proxy_auth: BasicAuth | None = None, + timeout: ClientTimeout | _SENTINEL | None = sentinel, + ssl: SSLContext | bool | Fingerprint = True, + server_hostname: str | None = None, + proxy_headers: LooseHeaders | None = None, + trace_request_ctx: Mapping[str, Any] | None = None, + read_bufsize: int | None = None, + auto_decompress: bool | None = None, + max_line_size: int | None = None, + max_field_size: int | None = None, + middlewares: Sequence[ClientMiddlewareType] | None = None, ) -> ClientResponse: # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants @@ -476,7 +460,7 @@ async def _request( if not isinstance(ssl, SSL_ALLOWED_TYPES): raise TypeError( "ssl should be SSLContext, Fingerprint, or bool, " - "got {!r} instead.".format(ssl) + f"got {ssl!r} instead." 
) if data is not None and json is not None: @@ -487,7 +471,7 @@ async def _request( data = payload.JsonPayload(json, dumps=self._json_serialize) redirects = 0 - history: List[ClientResponse] = [] + history: list[ClientResponse] = [] version = self._version params = params or {} @@ -503,7 +487,7 @@ async def _request( if url.scheme not in self._connector.allowed_protocol_schema_set: raise NonHttpUrlClientError(url) - skip_headers: Optional[Iterable[istr]] + skip_headers: Iterable[istr] | None if skip_auto_headers is not None: skip_headers = { istr(i) for i in skip_auto_headers @@ -622,7 +606,7 @@ async def _request( if req_cookies: all_cookies.load(req_cookies) - proxy_: Optional[URL] = None + proxy_: URL | None = None if proxy is not None: proxy_ = URL(proxy) elif self._trust_env: @@ -879,20 +863,20 @@ def ws_connect( *, method: str = hdrs.METH_GET, protocols: Collection[str] = (), - timeout: Union[ClientWSTimeout, _SENTINEL] = sentinel, - receive_timeout: Optional[float] = None, + timeout: ClientWSTimeout | _SENTINEL = sentinel, + receive_timeout: float | None = None, autoclose: bool = True, autoping: bool = True, - heartbeat: Optional[float] = None, - auth: Optional[BasicAuth] = None, - origin: Optional[str] = None, + heartbeat: float | None = None, + auth: BasicAuth | None = None, + origin: str | None = None, params: Query = None, - headers: Optional[LooseHeaders] = None, - proxy: Optional[StrOrURL] = None, - proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, Fingerprint] = True, - server_hostname: Optional[str] = None, - proxy_headers: Optional[LooseHeaders] = None, + headers: LooseHeaders | None = None, + proxy: StrOrURL | None = None, + proxy_auth: BasicAuth | None = None, + ssl: SSLContext | bool | Fingerprint = True, + server_hostname: str | None = None, + proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, ) -> "_WSRequestContextManager": @@ -927,20 +911,20 @@ async def _ws_connect( *, 
method: str = hdrs.METH_GET, protocols: Collection[str] = (), - timeout: Union[ClientWSTimeout, _SENTINEL] = sentinel, - receive_timeout: Optional[float] = None, + timeout: ClientWSTimeout | _SENTINEL = sentinel, + receive_timeout: float | None = None, autoclose: bool = True, autoping: bool = True, - heartbeat: Optional[float] = None, - auth: Optional[BasicAuth] = None, - origin: Optional[str] = None, + heartbeat: float | None = None, + auth: BasicAuth | None = None, + origin: str | None = None, params: Query = None, - headers: Optional[LooseHeaders] = None, - proxy: Optional[StrOrURL] = None, - proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, Fingerprint] = True, - server_hostname: Optional[str] = None, - proxy_headers: Optional[LooseHeaders] = None, + headers: LooseHeaders | None = None, + proxy: StrOrURL | None = None, + proxy_auth: BasicAuth | None = None, + ssl: SSLContext | bool | Fingerprint = True, + server_hostname: str | None = None, + proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, ) -> ClientWebSocketResponse: @@ -996,7 +980,7 @@ async def _ws_connect( if not isinstance(ssl, SSL_ALLOWED_TYPES): raise TypeError( "ssl should be SSLContext, Fingerprint, or bool, " - "got {!r} instead.".format(ssl) + f"got {ssl!r} instead." 
) # send request @@ -1131,14 +1115,14 @@ async def _ws_connect( client_notakeover=notakeover, ) - def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str]": + def _prepare_headers(self, headers: LooseHeaders | None) -> "CIMultiDict[str]": """Add default headers and transform it to CIMultiDict""" # Convert headers to MultiDict result = CIMultiDict(self._default_headers) if headers: if not isinstance(headers, (MultiDictProxy, MultiDict)): headers = CIMultiDict(headers) - added_names: Set[str] = set() + added_names: set[str] = set() for key, value in headers.items(): if key in added_names: result.add(key, value) @@ -1272,7 +1256,7 @@ def closed(self) -> bool: return self._connector is None or self._connector.closed @property - def connector(self) -> Optional[BaseConnector]: + def connector(self) -> BaseConnector | None: """Connector instance used for the session.""" return self._connector @@ -1282,7 +1266,7 @@ def cookie_jar(self) -> AbstractCookieJar: return self._cookie_jar @property - def version(self) -> Tuple[int, int]: + def version(self) -> tuple[int, int]: """The session HTTP protocol version.""" return self._version @@ -1302,12 +1286,12 @@ def headers(self) -> "CIMultiDict[str]": return self._default_headers @property - def skip_auto_headers(self) -> FrozenSet[istr]: + def skip_auto_headers(self) -> frozenset[istr]: """Headers for which autogeneration should be skipped""" return self._skip_auto_headers @property - def auth(self) -> Optional[BasicAuth]: + def auth(self) -> BasicAuth | None: """An object that represents HTTP Basic Authorization""" return self._default_auth @@ -1324,7 +1308,7 @@ def connector_owner(self) -> bool: @property def raise_for_status( self, - ) -> Union[bool, Callable[[ClientResponse], Awaitable[None]]]: + ) -> bool | Callable[[ClientResponse], Awaitable[None]]: """Should `ClientResponse.raise_for_status()` be called for each response.""" return self._raise_for_status @@ -1344,7 +1328,7 @@ def trust_env(self) -> 
bool: return self._trust_env @property - def trace_configs(self) -> List[TraceConfig[Any]]: + def trace_configs(self) -> list[TraceConfig[Any]]: """A list of TraceConfig instances used for client tracing""" return self._trace_configs @@ -1360,9 +1344,9 @@ async def __aenter__(self) -> "ClientSession": async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: await self.close() @@ -1371,7 +1355,7 @@ class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType __slots__ = ("_coro", "_resp") def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: - self._coro: Coroutine["asyncio.Future[Any]", None, _RetType] = coro + self._coro: Coroutine[asyncio.Future[Any], None, _RetType] = coro def send(self, arg: None) -> "asyncio.Future[Any]": return self._coro.send(arg) @@ -1395,9 +1379,9 @@ async def __aenter__(self) -> _RetType: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: await self._resp.__aexit__(exc_type, exc, tb) @@ -1415,7 +1399,7 @@ def __init__( session: ClientSession, ) -> None: self._coro = coro - self._resp: Optional[ClientResponse] = None + self._resp: ClientResponse | None = None self._session = session async def __aenter__(self) -> ClientResponse: @@ -1429,9 +1413,9 @@ async def __aenter__(self) -> ClientResponse: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: assert self._resp is not None self._resp.close() @@ -1445,7 +1429,7 @@ def 
request( url: StrOrURL, *, version: HttpVersion = http.HttpVersion11, - connector: Optional[BaseConnector] = None, + connector: BaseConnector | None = None, **kwargs: Unpack[_RequestOptions], ) -> _SessionRequestContextManager: ... @@ -1456,7 +1440,7 @@ def request( url: StrOrURL, *, version: HttpVersion = http.HttpVersion11, - connector: Optional[BaseConnector] = None, + connector: BaseConnector | None = None, **kwargs: Any, ) -> _SessionRequestContextManager: """Constructs and sends a request. diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index da159d0ae7d..7bff03171b2 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -1,7 +1,7 @@ """HTTP related errors.""" import asyncio -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Union from multidict import MultiMapping @@ -74,11 +74,11 @@ class ClientResponseError(ClientError): def __init__( self, request_info: RequestInfo, - history: Tuple[ClientResponse, ...], + history: tuple[ClientResponse, ...], *, - status: Optional[int] = None, + status: int | None = None, message: str = "", - headers: Optional[MultiMapping[str]] = None, + headers: MultiMapping[str] | None = None, ) -> None: self.request_info = request_info if status is not None: @@ -91,11 +91,7 @@ def __init__( self.args = (request_info, history) def __str__(self) -> str: - return "{}, message={!r}, url={!r}".format( - self.status, - self.message, - str(self.request_info.real_url), - ) + return f"{self.status}, message={self.message!r}, url={str(self.request_info.real_url)!r}" def __repr__(self) -> str: args = f"{self.request_info!r}, {self.history!r}" @@ -163,7 +159,7 @@ def host(self) -> str: return self._conn_key.host @property - def port(self) -> Optional[int]: + def port(self) -> int | None: return self._conn_key.port @property @@ -225,7 +221,7 @@ class ServerConnectionError(ClientConnectionError): class ServerDisconnectedError(ServerConnectionError): 
"""Server disconnected.""" - def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None: + def __init__(self, message: RawResponseMessage | str | None = None) -> None: if message is None: message = "Server disconnected" @@ -256,9 +252,7 @@ def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None: self.args = (expected, got, host, port) def __repr__(self) -> str: - return "<{} expected={!r} got={!r} host={!r} port={!r}>".format( - self.__class__.__name__, self.expected, self.got, self.host, self.port - ) + return f"<{self.__class__.__name__} expected={self.expected!r} got={self.got!r} host={self.host!r} port={self.port!r}>" class ClientPayloadError(ClientError): @@ -274,7 +268,7 @@ class InvalidURL(ClientError, ValueError): # Derive from ValueError for backward compatibility - def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None: + def __init__(self, url: StrOrURL, description: str | None = None) -> None: # The type of url is not yarl.URL because the exception can be raised # on URL(url) call self._url = url @@ -369,7 +363,7 @@ def host(self) -> str: return self._conn_key.host @property - def port(self) -> Optional[int]: + def port(self) -> int | None: return self._conn_key.port @property @@ -378,9 +372,9 @@ def ssl(self) -> bool: def __str__(self) -> str: return ( - "Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} " - "[{0.certificate_error.__class__.__name__}: " - "{0.certificate_error.args}]".format(self) + f"Cannot connect to host {self.host}:{self.port} ssl:{self.ssl} " + f"[{self.certificate_error.__class__.__name__}: " + f"{self.certificate_error.args}]" ) diff --git a/aiohttp/client_middleware_digest_auth.py b/aiohttp/client_middleware_digest_auth.py index c1ed7ca0fdd..18d47c96219 100644 --- a/aiohttp/client_middleware_digest_auth.py +++ b/aiohttp/client_middleware_digest_auth.py @@ -11,17 +11,8 @@ import os import re import time -from typing import ( - Callable, - Dict, - Final, - 
FrozenSet, - List, - Literal, - Tuple, - TypedDict, - Union, -) +from collections.abc import Callable +from typing import Final, Literal, TypedDict from yarl import URL @@ -42,7 +33,7 @@ class DigestAuthChallenge(TypedDict, total=False): stale: str -DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = { +DigestFunctions: dict[str, Callable[[bytes], "hashlib._Hash"]] = { "MD5": hashlib.md5, "MD5-SESS": hashlib.md5, "SHA": hashlib.sha1, @@ -83,7 +74,7 @@ class DigestAuthChallenge(TypedDict, total=False): # RFC 7616: Challenge parameters to extract CHALLENGE_FIELDS: Final[ - Tuple[ + tuple[ Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ... ] ] = ( @@ -98,14 +89,14 @@ class DigestAuthChallenge(TypedDict, total=False): # Supported digest authentication algorithms # Use a tuple of sorted keys for predictable documentation and error messages -SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions.keys())) +SUPPORTED_ALGORITHMS: Final[tuple[str, ...]] = tuple(sorted(DigestFunctions.keys())) # RFC 7616: Fields that require quoting in the Digest auth header # These fields must be enclosed in double quotes in the Authorization header. # Algorithm, qop, and nc are never quoted per RFC specifications. # This frozen set is used by the template-based header construction to # automatically determine which fields need quotes. -QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset( +QUOTED_AUTH_FIELDS: Final[frozenset[str]] = frozenset( {"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"} ) @@ -120,7 +111,7 @@ def unescape_quotes(value: str) -> str: return value.replace('\\"', '"') -def parse_header_pairs(header: str) -> Dict[str, str]: +def parse_header_pairs(header: str) -> dict[str, str]: """ Parse key-value pairs from WWW-Authenticate or similar HTTP headers. 
@@ -202,11 +193,9 @@ def __init__( self._challenge: DigestAuthChallenge = {} self._preemptive: bool = preemptive # Set of URLs defining the protection space - self._protection_space: List[str] = [] + self._protection_space: list[str] = [] - async def _encode( - self, method: str, url: URL, body: Union[Payload, Literal[b""]] - ) -> str: + async def _encode(self, method: str, url: URL, body: Payload | Literal[b""]) -> str: """ Build digest authorization header for the current challenge. @@ -358,7 +347,7 @@ def KD(s: bytes, d: bytes) -> bytes: header_fields["cnonce"] = cnonce # Build header using templates for each field type - pairs: List[str] = [] + pairs: list[str] = [] for field, value in header_fields.items(): if field in QUOTED_AUTH_FIELDS: pairs.append(f'{field}="{value}"') diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 93221faaa30..144cb42d52b 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -1,6 +1,6 @@ import asyncio from contextlib import suppress -from typing import Any, Optional, Tuple, Type, Union +from typing import Any from .base_protocol import BaseProtocol from .client_exceptions import ( @@ -22,7 +22,7 @@ from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader -class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]): +class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamReader]]): """Helper class to adapt between Protocol and StreamReader.""" def __init__(self, loop: asyncio.AbstractEventLoop) -> None: @@ -31,26 +31,26 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._should_close = False - self._payload: Optional[StreamReader] = None + self._payload: StreamReader | None = None self._skip_payload = False - self._payload_parser: Optional[WebSocketReader] = None + self._payload_parser: WebSocketReader | None = None self._timer = None self._tail = b"" self._upgraded = False - self._parser: Optional[HttpResponseParser] = None + 
self._parser: HttpResponseParser | None = None - self._read_timeout: Optional[float] = None - self._read_timeout_handle: Optional[asyncio.TimerHandle] = None + self._read_timeout: float | None = None + self._read_timeout_handle: asyncio.TimerHandle | None = None - self._timeout_ceil_threshold: Optional[float] = 5 + self._timeout_ceil_threshold: float | None = 5 - self._closed: Union[None, asyncio.Future[None]] = None + self._closed: None | asyncio.Future[None] = None self._connection_lost_called = False @property - def closed(self) -> Union[None, asyncio.Future[None]]: + def closed(self) -> None | asyncio.Future[None]: """Future that is set when the connection is closed. This property returns a Future that will be completed when the connection @@ -107,7 +107,7 @@ def abort(self) -> None: def is_connected(self) -> bool: return self.transport is not None and not self.transport.is_closing() - def connection_lost(self, exc: Optional[BaseException]) -> None: + def connection_lost(self, exc: BaseException | None) -> None: self._connection_lost_called = True self._drop_timeout() @@ -196,7 +196,7 @@ def resume_reading(self) -> None: def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: self._should_close = True @@ -221,11 +221,11 @@ def set_parser(self, parser: Any, payload: Any) -> None: def set_response_params( self, *, - timer: Optional[BaseTimerContext] = None, + timer: BaseTimerContext | None = None, skip_payload: bool = False, read_until_eof: bool = False, auto_decompress: bool = True, - read_timeout: Optional[float] = None, + read_timeout: float | None = None, read_bufsize: int = 2**16, timeout_ceil_threshold: float = 5, max_line_size: int = 8190, @@ -275,11 +275,11 @@ def start_timeout(self) -> None: self._reschedule_timeout() @property - def read_timeout(self) -> Optional[float]: + def read_timeout(self) -> float | None: return self._read_timeout 
@read_timeout.setter - def read_timeout(self, read_timeout: Optional[float]) -> None: + def read_timeout(self, read_timeout: float | None) -> None: self._read_timeout = read_timeout def _on_read_timeout(self) -> None: @@ -333,7 +333,7 @@ def data_received(self, data: bytes) -> None: self._upgraded = upgraded - payload: Optional[StreamReader] = None + payload: StreamReader | None = None for message, payload in messages: if message.should_close: self._should_close = True diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 56131263a68..880e1085bab 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -7,24 +7,11 @@ import sys import traceback import warnings -from collections.abc import Mapping +from collections.abc import Callable, Iterable, Mapping from hashlib import md5, sha1, sha256 from http.cookies import Morsel, SimpleCookie from types import MappingProxyType, TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Literal, - NamedTuple, - Optional, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Union from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy from yarl import URL @@ -119,9 +106,9 @@ def _gen_default_accept_encoding() -> str: @frozen_dataclass_decorator class ContentDisposition: - type: Optional[str] + type: str | None parameters: "MappingProxyType[str, str]" - filename: Optional[str] + filename: str | None class _RequestInfo(NamedTuple): @@ -138,7 +125,7 @@ def __new__( url: URL, method: str, headers: "CIMultiDictProxy[str]", - real_url: Union[URL, _SENTINEL] = sentinel, + real_url: URL | _SENTINEL = sentinel, ) -> "RequestInfo": """Create a new RequestInfo instance. 
@@ -197,12 +184,12 @@ class ConnectionKey(NamedTuple): # the key should contain an information about used proxy / TLS # to prevent reusing wrong connections from a pool host: str - port: Optional[int] + port: int | None is_ssl: bool - ssl: Union[SSLContext, bool, Fingerprint] - proxy: Optional[URL] - proxy_auth: Optional[BasicAuth] - proxy_headers_hash: Optional[int] # hash(CIMultiDict) + ssl: SSLContext | bool | Fingerprint + proxy: URL | None + proxy_auth: BasicAuth | None + proxy_headers_hash: int | None # hash(CIMultiDict) def _warn_if_unclosed_payload(payload: payload.Payload, stacklevel: int = 2) -> None: @@ -229,21 +216,21 @@ class ClientResponse(HeadersMixin): # but will be set by the start() method. # As the end user will likely never see the None values, we cheat the types below. # from the Status-Line of the response - version: Optional[HttpVersion] = None # HTTP-Version + version: HttpVersion | None = None # HTTP-Version status: int = None # type: ignore[assignment] # Status-Code - reason: Optional[str] = None # Reason-Phrase + reason: str | None = None # Reason-Phrase content: StreamReader = None # type: ignore[assignment] # Payload stream - _body: Optional[bytes] = None + _body: bytes | None = None _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] - _history: Tuple["ClientResponse", ...] = () + _history: tuple["ClientResponse", ...] = () _raw_headers: RawHeaders = None # type: ignore[assignment] _connection: Optional["Connection"] = None # current connection - _cookies: Optional[SimpleCookie] = None - _raw_cookie_headers: Optional[Tuple[str, ...]] = None + _cookies: SimpleCookie | None = None + _raw_cookie_headers: tuple[str, ...] 
| None = None _continue: Optional["asyncio.Future[bool]"] = None - _source_traceback: Optional[traceback.StackSummary] = None + _source_traceback: traceback.StackSummary | None = None _session: Optional["ClientSession"] = None # set up by ClientRequest after ClientResponse object creation # post-init stage allows to not change ctor signature @@ -260,11 +247,11 @@ def __init__( method: str, url: URL, *, - writer: "Optional[asyncio.Task[None]]", + writer: "asyncio.Task[None] | None", continue100: Optional["asyncio.Future[bool]"], - timer: Optional[BaseTimerContext], + timer: BaseTimerContext | None, request_info: RequestInfo, - traces: List["Trace"], + traces: list["Trace"], loop: asyncio.AbstractEventLoop, session: "ClientSession", ) -> None: @@ -281,7 +268,7 @@ def __init__( self._continue = continue100 self._request_info = request_info self._timer = timer if timer is not None else TimerNoop() - self._cache: Dict[str, Any] = {} + self._cache: dict[str, Any] = {} self._traces = traces self._loop = loop # Save reference to _resolve_charset, so that get_encoding() will still @@ -371,7 +358,7 @@ def request_info(self) -> RequestInfo: return self._request_info @reify - def content_disposition(self) -> Optional[ContentDisposition]: + def content_disposition(self) -> ContentDisposition | None: raw = self._headers.get(hdrs.CONTENT_DISPOSITION) if raw is None: return None @@ -407,9 +394,7 @@ def __repr__(self) -> str: else: ascii_encodable_reason = "None" print( - "".format( - ascii_encodable_url, self.status, ascii_encodable_reason - ), + f"", file=out, ) print(self.headers, file=out) @@ -420,18 +405,18 @@ def connection(self) -> Optional["Connection"]: return self._connection @reify - def history(self) -> Tuple["ClientResponse", ...]: + def history(self) -> tuple["ClientResponse", ...]: """A sequence of responses, if redirects occurred.""" return self._history @reify - def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": + def links(self) -> 
"MultiDictProxy[MultiDictProxy[str | URL]]": links_str = ", ".join(self.headers.getall("link", [])) if not links_str: return MultiDictProxy(MultiDict()) - links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict() + links: MultiDict[MultiDictProxy[str | URL]] = MultiDict() for val in re.split(r",(?=\s*<)", links_str): match = re.match(r"\s*<(.*)>(.*)", val) @@ -440,7 +425,7 @@ def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": url, params_str = match.groups() params = params_str.split(";")[1:] - link: MultiDict[Union[str, URL]] = MultiDict() + link: MultiDict[str | URL] = MultiDict() for param in params: match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M) @@ -662,7 +647,7 @@ def get_encoding(self) -> str: return self._resolve_charset(self, self._body) - async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str: + async def text(self, encoding: str | None = None, errors: str = "strict") -> str: """Read response payload and decode.""" await self.read() @@ -674,9 +659,9 @@ async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> async def json( self, *, - encoding: Optional[str] = None, + encoding: str | None = None, loads: JSONDecoder = DEFAULT_JSON_DECODER, - content_type: Optional[str] = "application/json", + content_type: str | None = "application/json", ) -> Any: """Read and decodes JSON response.""" await self.read() @@ -705,9 +690,9 @@ async def __aenter__(self) -> "ClientResponse": async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: self._in_context = False # similar to _RequestContextManager, we do not need to check @@ -733,7 +718,7 @@ class ClientRequest: } # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. 
- _body: Union[None, payload.Payload] = None + _body: None | payload.Payload = None auth = None response = None @@ -758,26 +743,26 @@ def __init__( url: URL, *, params: Query = None, - headers: Optional[LooseHeaders] = None, - skip_auto_headers: Optional[Iterable[str]] = None, + headers: LooseHeaders | None = None, + skip_auto_headers: Iterable[str] | None = None, data: Any = None, - cookies: Optional[LooseCookies] = None, - auth: Optional[BasicAuth] = None, + cookies: LooseCookies | None = None, + auth: BasicAuth | None = None, version: http.HttpVersion = http.HttpVersion11, - compress: Union[str, bool] = False, - chunked: Optional[bool] = None, + compress: str | bool = False, + chunked: bool | None = None, expect100: bool = False, loop: asyncio.AbstractEventLoop, - response_class: Optional[Type["ClientResponse"]] = None, - proxy: Optional[URL] = None, - proxy_auth: Optional[BasicAuth] = None, - timer: Optional[BaseTimerContext] = None, + response_class: type["ClientResponse"] | None = None, + proxy: URL | None = None, + proxy_auth: BasicAuth | None = None, + timer: BaseTimerContext | None = None, session: Optional["ClientSession"] = None, - ssl: Union[SSLContext, bool, Fingerprint] = True, - proxy_headers: Optional[LooseHeaders] = None, - traces: Optional[List["Trace"]] = None, + ssl: SSLContext | bool | Fingerprint = True, + proxy_headers: LooseHeaders | None = None, + traces: list["Trace"] | None = None, trust_env: bool = False, - server_hostname: Optional[str] = None, + server_hostname: str | None = None, ): if match := _CONTAINS_CONTROL_CHAR_RE.search(method): raise ValueError( @@ -805,7 +790,7 @@ def __init__( real_response_class = ClientResponse else: real_response_class = response_class - self.response_class: Type[ClientResponse] = real_response_class + self.response_class: type[ClientResponse] = real_response_class self._timer = timer if timer is not None else TimerNoop() self._ssl = ssl self.server_hostname = server_hostname @@ -831,7 +816,7 @@ def 
__init__( def __reset_writer(self, _: object = None) -> None: self.__writer = None - def _get_content_length(self) -> Optional[int]: + def _get_content_length(self) -> int | None: """Extract and validate Content-Length header value. Returns parsed Content-Length value or None if not set. @@ -873,7 +858,7 @@ def ssl(self) -> Union["SSLContext", bool, Fingerprint]: @property def connection_key(self) -> ConnectionKey: if proxy_headers := self.proxy_headers: - h: Optional[int] = hash(tuple(proxy_headers.items())) + h: int | None = hash(tuple(proxy_headers.items())) else: h = None url = self.url @@ -897,11 +882,11 @@ def host(self) -> str: return ret @property - def port(self) -> Optional[int]: + def port(self) -> int | None: return self.url.port @property - def body(self) -> Union[payload.Payload, Literal[b""]]: + def body(self) -> payload.Payload | Literal[b""]: """Request body.""" # empty body is represented as bytes for backwards compatibility return self._body or b"" @@ -961,7 +946,7 @@ def update_host(self, url: URL) -> None: if url.raw_user or url.raw_password: self.auth = helpers.BasicAuth(url.user or "", url.password or "") - def update_version(self, version: Union[http.HttpVersion, str]) -> None: + def update_version(self, version: http.HttpVersion | str) -> None: """Convert request version to two elements tuple. 
parser HTTP version '1.1' => (1, 1) @@ -976,7 +961,7 @@ def update_version(self, version: Union[http.HttpVersion, str]) -> None: ) from None self.version = version - def update_headers(self, headers: Optional[LooseHeaders]) -> None: + def update_headers(self, headers: LooseHeaders | None) -> None: """Update request headers.""" self.headers: CIMultiDict[str] = CIMultiDict() @@ -1001,7 +986,7 @@ def update_headers(self, headers: Optional[LooseHeaders]) -> None: else: self.headers.add(key, value) - def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> None: + def update_auto_headers(self, skip_auto_headers: Iterable[str] | None) -> None: if skip_auto_headers is not None: self._skip_auto_headers = CIMultiDict( (hdr, None) for hdr in sorted(skip_auto_headers) @@ -1020,7 +1005,7 @@ def update_auto_headers(self, skip_auto_headers: Optional[Iterable[str]]) -> Non if hdrs.USER_AGENT not in used_headers: self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE - def update_cookies(self, cookies: Optional[LooseCookies]) -> None: + def update_cookies(self, cookies: LooseCookies | None) -> None: """Update request cookies header.""" if not cookies: return @@ -1044,7 +1029,7 @@ def update_cookies(self, cookies: Optional[LooseCookies]) -> None: self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() - def update_content_encoding(self, data: Any, compress: Union[bool, str]) -> None: + def update_content_encoding(self, data: Any, compress: bool | str) -> None: """Set request content encoding.""" self.compress = None if not data: @@ -1079,7 +1064,7 @@ def update_transfer_encoding(self) -> None: self.headers[hdrs.TRANSFER_ENCODING] = "chunked" - def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: + def update_auth(self, auth: BasicAuth | None, trust_env: bool = False) -> None: """Set basic auth.""" if auth is None: auth = self.auth @@ -1117,7 +1102,7 @@ def update_body_from_data(self, body: Any, _stacklevel: int = 3) -> None: try: 
body_payload = payload.PAYLOAD_REGISTRY.get(maybe_payload, disposition=None) except payload.LookupError: - boundary: Optional[str] = None + boundary: str | None = None if CONTENT_TYPE in self.headers: boundary = parse_mimetype(self.headers[CONTENT_TYPE]).parameters.get( "boundary" @@ -1235,9 +1220,9 @@ def update_expect_continue(self, expect: bool = False) -> None: def update_proxy( self, - proxy: Optional[URL], - proxy_auth: Optional[BasicAuth], - proxy_headers: Optional[LooseHeaders], + proxy: URL | None, + proxy_auth: BasicAuth | None, + proxy_headers: LooseHeaders | None, ) -> None: self.proxy = proxy if proxy is None: @@ -1259,7 +1244,7 @@ async def write_bytes( self, writer: AbstractStreamWriter, conn: "Connection", - content_length: Optional[int], + content_length: int | None, ) -> None: """ Write the request body to the connection stream. @@ -1399,7 +1384,7 @@ async def send(self, conn: "Connection") -> "ClientResponse": # Buffer headers for potential coalescing with body await writer.write_headers(status_line, self.headers) - task: Optional["asyncio.Task[None]"] + task: asyncio.Task[None] | None if self._body or self._continue is not None or protocol.writing_paused: coro = self.write_bytes(writer, conn, self._get_content_length()) if sys.version_info >= (3, 12): diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index bfd53ea64bf..36959aae0c7 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -3,7 +3,7 @@ import asyncio import sys from types import TracebackType -from typing import Any, Final, Optional, Type +from typing import Any, Final from ._websocket.reader import WebSocketDataQueue from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError @@ -34,8 +34,8 @@ @frozen_dataclass_decorator class ClientWSTimeout: - ws_receive: Optional[float] = None - ws_close: Optional[float] = None + ws_receive: float | None = None + ws_close: float | None = None DEFAULT_WS_CLIENT_TIMEOUT: Final[ClientWSTimeout] = 
ClientWSTimeout( @@ -48,14 +48,14 @@ def __init__( self, reader: WebSocketDataQueue, writer: WebSocketWriter, - protocol: Optional[str], + protocol: str | None, response: ClientResponse, timeout: ClientWSTimeout, autoclose: bool, autoping: bool, loop: asyncio.AbstractEventLoop, *, - heartbeat: Optional[float] = None, + heartbeat: float | None = None, compress: int = 0, client_notakeover: bool = False, ) -> None: @@ -67,23 +67,23 @@ def __init__( self._protocol = protocol self._closed = False self._closing = False - self._close_code: Optional[int] = None + self._close_code: int | None = None self._timeout = timeout self._autoclose = autoclose self._autoping = autoping self._heartbeat = heartbeat - self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + self._heartbeat_cb: asyncio.TimerHandle | None = None self._heartbeat_when: float = 0.0 if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 - self._pong_response_cb: Optional[asyncio.TimerHandle] = None + self._pong_response_cb: asyncio.TimerHandle | None = None self._loop = loop self._waiting: bool = False - self._close_wait: Optional[asyncio.Future[None]] = None - self._exception: Optional[BaseException] = None + self._close_wait: asyncio.Future[None] | None = None + self._exception: BaseException | None = None self._compress = compress self._client_notakeover = client_notakeover - self._ping_task: Optional[asyncio.Task[None]] = None + self._ping_task: asyncio.Task[None] | None = None self._reset_heartbeat() @@ -199,11 +199,11 @@ def closed(self) -> bool: return self._closed @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: return self._close_code @property - def protocol(self) -> Optional[str]: + def protocol(self) -> str | None: return self._protocol @property @@ -224,7 +224,7 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: return default return transport.get_extra_info(name, default) - def exception(self) -> Optional[BaseException]: + def 
exception(self) -> BaseException | None: return self._exception async def ping(self, message: bytes = b"") -> None: @@ -234,19 +234,19 @@ async def pong(self, message: bytes = b"") -> None: await self._writer.send_frame(message, WSMsgType.PONG) async def send_frame( - self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None + self, message: bytes, opcode: WSMsgType, compress: int | None = None ) -> None: """Send a frame over the websocket.""" await self._writer.send_frame(message, opcode, compress) - async def send_str(self, data: str, compress: Optional[int] = None) -> None: + async def send_str(self, data: str, compress: int | None = None) -> None: if not isinstance(data, str): raise TypeError("data argument must be str (%r)" % type(data)) await self._writer.send_frame( data.encode("utf-8"), WSMsgType.TEXT, compress=compress ) - async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: + async def send_bytes(self, data: bytes, compress: int | None = None) -> None: if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError("data argument must be byte-ish (%r)" % type(data)) await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress) @@ -254,7 +254,7 @@ async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: async def send_json( self, data: Any, - compress: Optional[int] = None, + compress: int | None = None, *, dumps: JSONEncoder = DEFAULT_JSON_ENCODER, ) -> None: @@ -309,7 +309,7 @@ async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bo self._response.close() return True - async def receive(self, timeout: Optional[float] = None) -> WSMessage: + async def receive(self, timeout: float | None = None) -> WSMessage: receive_timeout = timeout or self._timeout.ws_receive while True: @@ -383,7 +383,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return msg - async def receive_str(self, *, timeout: Optional[float] = None) -> str: 
+ async def receive_str(self, *, timeout: float | None = None) -> str: msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( @@ -391,7 +391,7 @@ async def receive_str(self, *, timeout: Optional[float] = None) -> str: ) return msg.data - async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: + async def receive_bytes(self, *, timeout: float | None = None) -> bytes: msg = await self.receive(timeout) if msg.type is not WSMsgType.BINARY: raise WSMessageTypeError( @@ -403,7 +403,7 @@ async def receive_json( self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER, - timeout: Optional[float] = None, + timeout: float | None = None, ) -> Any: data = await self.receive_str(timeout=timeout) return loads(data) @@ -422,8 +422,8 @@ async def __aenter__(self) -> "ClientWebSocketResponse": async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: await self.close() diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py index 9f5562ea1cb..67aab49f91b 100644 --- a/aiohttp/compression_utils.py +++ b/aiohttp/compression_utils.py @@ -2,7 +2,7 @@ import sys import zlib from concurrent.futures import Executor -from typing import Any, Final, Optional, Protocol, TypedDict, cast +from typing import Any, Final, Protocol, TypedDict, cast if sys.version_info >= (3, 12): from collections.abc import Buffer @@ -63,7 +63,7 @@ def compressobj( wbits: int = ..., memLevel: int = ..., strategy: int = ..., - zdict: Optional[Buffer] = ..., + zdict: Buffer | None = ..., ) -> ZLibCompressObjProtocol: ... def decompressobj( self, wbits: int = ..., zdict: Buffer = ... 
@@ -136,7 +136,7 @@ def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None: def encoding_to_mode( - encoding: Optional[str] = None, + encoding: str | None = None, suppress_deflate_header: bool = False, ) -> int: if encoding == "gzip": @@ -149,8 +149,8 @@ class ZlibBaseHandler: def __init__( self, mode: int, - executor: Optional[Executor] = None, - max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + executor: Executor | None = None, + max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, ): self._mode = mode self._executor = executor @@ -160,13 +160,13 @@ def __init__( class ZLibCompressor(ZlibBaseHandler): def __init__( self, - encoding: Optional[str] = None, + encoding: str | None = None, suppress_deflate_header: bool = False, - level: Optional[int] = None, - wbits: Optional[int] = None, - strategy: Optional[int] = None, - executor: Optional[Executor] = None, - max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + level: int | None = None, + wbits: int | None = None, + strategy: int | None = None, + executor: Executor | None = None, + max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, ): super().__init__( mode=( @@ -214,7 +214,7 @@ async def compress(self, data: Buffer) -> bytes: ) return self.compress_sync(data) - def flush(self, mode: Optional[int] = None) -> bytes: + def flush(self, mode: int | None = None) -> bytes: return self._compressor.flush( mode if mode is not None else self._zlib_backend.Z_FINISH ) @@ -223,10 +223,10 @@ def flush(self, mode: Optional[int] = None) -> bytes: class ZLibDecompressor(ZlibBaseHandler): def __init__( self, - encoding: Optional[str] = None, + encoding: str | None = None, suppress_deflate_header: bool = False, - executor: Optional[Executor] = None, - max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + executor: Executor | None = None, + max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE, ): super().__init__( mode=encoding_to_mode(encoding, suppress_deflate_header), diff --git 
a/aiohttp/connector.py b/aiohttp/connector.py index e1288424bff..a6eec4d3d15 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -6,30 +6,13 @@ import traceback import warnings from collections import OrderedDict, defaultdict, deque +from collections.abc import Awaitable, Callable, Iterator, Sequence from contextlib import suppress from http import HTTPStatus from itertools import chain, cycle, islice from time import monotonic from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - DefaultDict, - Deque, - Dict, - Iterator, - List, - Literal, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import aiohappyeyeballs from aiohappyeyeballs import AddrInfoType, SocketFactoryType @@ -134,8 +117,8 @@ def __init__( self._key = key self._connector = connector self._loop = loop - self._protocol: Optional[ResponseHandler] = protocol - self._callbacks: List[Callable[[], None]] = [] + self._protocol: ResponseHandler | None = protocol + self._callbacks: list[Callable[[], None]] = [] self._source_traceback = ( traceback.extract_stack(sys._getframe(1)) if loop.get_debug() else None ) @@ -163,13 +146,13 @@ def __bool__(self) -> Literal[True]: return True @property - def transport(self) -> Optional[asyncio.Transport]: + def transport(self) -> asyncio.Transport | None: if self._protocol is None: return None return self._protocol.transport @property - def protocol(self) -> Optional[ResponseHandler]: + def protocol(self) -> ResponseHandler | None: return self._protocol def add_callback(self, callback: Callable[[], None]) -> None: @@ -227,7 +210,7 @@ class _TransportPlaceholder: __slots__ = ("closed", "transport") - def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None: + def __init__(self, closed_future: asyncio.Future[Exception | None]) -> None: """Initialize a placeholder for a transport.""" self.closed = 
closed_future self.transport = None @@ -265,7 +248,7 @@ class BaseConnector: def __init__( self, *, - keepalive_timeout: Union[_SENTINEL, None, float] = sentinel, + keepalive_timeout: _SENTINEL | None | float = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, @@ -292,13 +275,13 @@ def __init__( # Connection pool of reusable connections. # We use a deque to store connections because it has O(1) popleft() # and O(1) append() operations to implement a FIFO queue. - self._conns: DefaultDict[ - ConnectionKey, Deque[Tuple[ResponseHandler, float]] + self._conns: defaultdict[ + ConnectionKey, deque[tuple[ResponseHandler, float]] ] = defaultdict(deque) self._limit = limit self._limit_per_host = limit_per_host - self._acquired: Set[ResponseHandler] = set() - self._acquired_per_host: DefaultDict[ConnectionKey, Set[ResponseHandler]] = ( + self._acquired: set[ResponseHandler] = set() + self._acquired_per_host: defaultdict[ConnectionKey, set[ResponseHandler]] = ( defaultdict(set) ) self._keepalive_timeout = cast(float, keepalive_timeout) @@ -307,7 +290,7 @@ def __init__( # {host_key: FIFO list of waiters} # The FIFO is implemented with an OrderedDict with None keys because # python does not have an ordered set. 
- self._waiters: DefaultDict[ + self._waiters: defaultdict[ ConnectionKey, OrderedDict[asyncio.Future[None], None] ] = defaultdict(OrderedDict) @@ -315,10 +298,10 @@ def __init__( self._factory = functools.partial(ResponseHandler, loop=loop) # start keep-alive connection cleanup task - self._cleanup_handle: Optional[asyncio.TimerHandle] = None + self._cleanup_handle: asyncio.TimerHandle | None = None # start cleanup closed transports task - self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None + self._cleanup_closed_handle: asyncio.TimerHandle | None = None if enable_cleanup_closed and not NEEDS_CLEANUP_CLOSED: warnings.warn( @@ -331,9 +314,9 @@ def __init__( enable_cleanup_closed = False self._cleanup_closed_disabled = not enable_cleanup_closed - self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = [] + self._cleanup_closed_transports: list[asyncio.Transport | None] = [] - self._placeholder_future: asyncio.Future[Optional[Exception]] = ( + self._placeholder_future: asyncio.Future[Exception | None] = ( loop.create_future() ) self._placeholder_future.set_result(None) @@ -364,9 +347,9 @@ async def __aenter__(self) -> "BaseConnector": async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + exc_traceback: TracebackType | None = None, ) -> None: await self.close() @@ -408,7 +391,7 @@ def _cleanup(self) -> None: connections = defaultdict(deque) deadline = now - timeout for key, conns in self._conns.items(): - alive: Deque[Tuple[ResponseHandler, float]] = deque() + alive: deque[tuple[ResponseHandler, float]] = deque() for proto, use_time in conns: if proto.is_connected() and use_time - deadline >= 0: alive.append((proto, use_time)) @@ -470,8 +453,8 @@ async def close(self, *, abort_ssl: bool = False) -> None: err_msg = "Error while closing 
connector: " + repr(res) client_logger.debug(err_msg) - def _close_immediately(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: - waiters: List[Awaitable[object]] = [] + def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: + waiters: list[Awaitable[object]] = [] if self._closed: return waiters @@ -567,7 +550,7 @@ def _available_connections(self, key: "ConnectionKey") -> int: return total_remain async def connect( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> Connection: """Get from pool or create new connection.""" key = req.connection_key @@ -622,7 +605,7 @@ async def connect( return Connection(self, key, proto, self._loop) async def _wait_for_available_connection( - self, key: "ConnectionKey", traces: List["Trace"] + self, key: "ConnectionKey", traces: list["Trace"] ) -> None: """Wait for an available connection slot.""" # We loop here because there is a race between @@ -664,8 +647,8 @@ async def _wait_for_available_connection( attempts += 1 async def _get( - self, key: "ConnectionKey", traces: List["Trace"] - ) -> Optional[Connection]: + self, key: "ConnectionKey", traces: list["Trace"] + ) -> Connection | None: """Get next reusable connection for the key or None. The connection will be marked as acquired. 
@@ -776,27 +759,27 @@ def _release( ) async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: raise NotImplementedError() class _DNSCacheTable: - def __init__(self, ttl: Optional[float] = None) -> None: - self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {} - self._timestamps: Dict[Tuple[str, int], float] = {} + def __init__(self, ttl: float | None = None) -> None: + self._addrs_rr: dict[tuple[str, int], tuple[Iterator[ResolveResult], int]] = {} + self._timestamps: dict[tuple[str, int], float] = {} self._ttl = ttl def __contains__(self, host: object) -> bool: return host in self._addrs_rr - def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None: + def add(self, key: tuple[str, int], addrs: list[ResolveResult]) -> None: self._addrs_rr[key] = (cycle(addrs), len(addrs)) if self._ttl is not None: self._timestamps[key] = monotonic() - def remove(self, key: Tuple[str, int]) -> None: + def remove(self, key: tuple[str, int]) -> None: self._addrs_rr.pop(key, None) if self._ttl is not None: @@ -806,14 +789,14 @@ def clear(self) -> None: self._addrs_rr.clear() self._timestamps.clear() - def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]: + def next_addrs(self, key: tuple[str, int]) -> list[ResolveResult]: loop, length = self._addrs_rr[key] addrs = list(islice(loop, length)) # Consume one more element to shift internal state of `cycle` next(loop) return addrs - def expired(self, key: Tuple[str, int]) -> bool: + def expired(self, key: tuple[str, int]) -> bool: if self._ttl is None: return False @@ -896,21 +879,21 @@ def __init__( self, *, use_dns_cache: bool = True, - ttl_dns_cache: Optional[int] = 10, + ttl_dns_cache: int | None = 10, family: socket.AddressFamily = socket.AddressFamily.AF_UNSPEC, - ssl: Union[bool, Fingerprint, SSLContext] = True, - local_addr: 
Optional[Tuple[str, int]] = None, - resolver: Optional[AbstractResolver] = None, - keepalive_timeout: Union[None, float, _SENTINEL] = sentinel, + ssl: bool | Fingerprint | SSLContext = True, + local_addr: tuple[str, int] | None = None, + resolver: AbstractResolver | None = None, + keepalive_timeout: None | float | _SENTINEL = sentinel, force_close: bool = False, limit: int = 100, limit_per_host: int = 0, enable_cleanup_closed: bool = False, timeout_ceil_threshold: float = 5, - happy_eyeballs_delay: Optional[float] = 0.25, - interleave: Optional[int] = None, - socket_factory: Optional[SocketFactoryType] = None, - ssl_shutdown_timeout: Union[_SENTINEL, None, float] = sentinel, + happy_eyeballs_delay: float | None = 0.25, + interleave: int | None = None, + socket_factory: SocketFactoryType | None = None, + ssl_shutdown_timeout: _SENTINEL | None | float = sentinel, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -924,7 +907,7 @@ def __init__( if not isinstance(ssl, SSL_ALLOWED_TYPES): raise TypeError( "ssl should be SSLContext, Fingerprint, or bool, " - "got {!r} instead.".format(ssl) + f"got {ssl!r} instead." 
) self._ssl = ssl @@ -938,16 +921,16 @@ def __init__( self._use_dns_cache = use_dns_cache self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) - self._throttle_dns_futures: Dict[Tuple[str, int], Set[asyncio.Future[None]]] = ( + self._throttle_dns_futures: dict[tuple[str, int], set[asyncio.Future[None]]] = ( {} ) self._family = family self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr) self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave - self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() + self._resolve_host_tasks: set[asyncio.Task[list[ResolveResult]]] = set() self._socket_factory = socket_factory - self._ssl_shutdown_timeout: Optional[float] + self._ssl_shutdown_timeout: float | None # Handle ssl_shutdown_timeout with warning for Python < 3.11 if ssl_shutdown_timeout is sentinel: @@ -986,7 +969,7 @@ async def close(self, *, abort_ssl: bool = False) -> None: # Use abort_ssl param if explicitly set, otherwise use ssl_shutdown_timeout default await super().close(abort_ssl=abort_ssl or self._ssl_shutdown_timeout == 0) - def _close_immediately(self, *, abort_ssl: bool = False) -> List[Awaitable[object]]: + def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]: for fut in chain.from_iterable(self._throttle_dns_futures.values()): fut.cancel() @@ -1008,9 +991,7 @@ def use_dns_cache(self) -> bool: """True if local DNS caching is enabled.""" return self._use_dns_cache - def clear_dns_cache( - self, host: Optional[str] = None, port: Optional[int] = None - ) -> None: + def clear_dns_cache(self, host: str | None = None, port: int | None = None) -> None: """Remove specified host/port or clear all dns local cache.""" if host is not None and port is not None: self._cached_hosts.remove((host, port)) @@ -1020,8 +1001,8 @@ def clear_dns_cache( self._cached_hosts.clear() async def _resolve_host( - self, host: str, port: int, traces: Optional[Sequence["Trace"]] = None - ) -> 
List[ResolveResult]: + self, host: str, port: int, traces: Sequence["Trace"] | None = None + ) -> list[ResolveResult]: """Resolve host and return list of addresses.""" if is_ip_address(host): return [ @@ -1058,7 +1039,7 @@ async def _resolve_host( await trace.send_dns_cache_hit(host) return result - futures: Set[asyncio.Future[None]] + futures: set[asyncio.Future[None]] # # If multiple connectors are resolving the same host, we wait # for the first one to resolve and then use the result for all of them. @@ -1102,7 +1083,7 @@ async def _resolve_host( return await asyncio.shield(resolved_host_task) except asyncio.CancelledError: - def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: + def drop_exception(fut: "asyncio.Future[list[ResolveResult]]") -> None: with suppress(Exception, asyncio.CancelledError): fut.result() @@ -1111,12 +1092,12 @@ def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None: async def _resolve_host_with_throttle( self, - key: Tuple[str, int], + key: tuple[str, int], host: str, port: int, - futures: Set[asyncio.Future[None]], - traces: Optional[Sequence["Trace"]], - ) -> List[ResolveResult]: + futures: set[asyncio.Future[None]], + traces: Sequence["Trace"] | None, + ) -> list[ResolveResult]: """Resolve host and set result for all waiters. This method must be run in a task and shielded from cancellation @@ -1151,7 +1132,7 @@ async def _resolve_host_with_throttle( return self._cached_hosts.next_addrs(key) async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: """Create connection. @@ -1164,7 +1145,7 @@ async def _create_connection( return proto - def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequest) -> SSLContext | None: """Logic to get the correct SSL context 0. 
if req.ssl is false, return None @@ -1209,12 +1190,12 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, - addr_infos: List[AddrInfoType], + addr_infos: list[AddrInfoType], req: ClientRequest, timeout: "ClientTimeout", - client_error: Type[Exception] = ClientConnectorError, + client_error: type[Exception] = ClientConnectorError, **kwargs: Any, - ) -> Tuple[asyncio.Transport, ResponseHandler]: + ) -> tuple[asyncio.Transport, ResponseHandler]: try: async with ceil_timeout( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold @@ -1292,8 +1273,8 @@ async def _start_tls_connection( underlying_transport: asyncio.Transport, req: ClientRequest, timeout: "ClientTimeout", - client_error: Type[Exception] = ClientConnectorError, - ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: + client_error: type[Exception] = ClientConnectorError, + ) -> tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS sslcontext = self._get_ssl_context(req) @@ -1374,14 +1355,14 @@ async def _start_tls_connection( return tls_transport, tls_proto def _convert_hosts_to_addr_infos( - self, hosts: List[ResolveResult] - ) -> List[AddrInfoType]: + self, hosts: list[ResolveResult] + ) -> list[AddrInfoType]: """Converts the list of hosts to a list of addr_infos. The list of hosts is the result of a DNS lookup. The list of addr_infos is the result of a call to `socket.getaddrinfo()`. 
""" - addr_infos: List[AddrInfoType] = [] + addr_infos: list[AddrInfoType] = [] for hinfo in hosts: host = hinfo["host"] is_ipv6 = ":" in host @@ -1397,11 +1378,11 @@ def _convert_hosts_to_addr_infos( async def _create_direct_connection( self, req: ClientRequest, - traces: List["Trace"], + traces: list["Trace"], timeout: "ClientTimeout", *, - client_error: Type[Exception] = ClientConnectorError, - ) -> Tuple[asyncio.Transport, ResponseHandler]: + client_error: type[Exception] = ClientConnectorError, + ) -> tuple[asyncio.Transport, ResponseHandler]: sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) @@ -1426,7 +1407,7 @@ async def _create_direct_connection( # it is problem of resolving proxy ip itself raise ClientConnectorDNSError(req.connection_key, exc) from exc - last_exc: Optional[Exception] = None + last_exc: Exception | None = None addr_infos = self._convert_hosts_to_addr_infos(hosts) while addr_infos: # Strip trailing dots, certificates contain FQDN without dots. 
@@ -1469,9 +1450,9 @@ async def _create_direct_connection( raise last_exc async def _create_proxy_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" - ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: - headers: Dict[str, str] = {} + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" + ) -> tuple[asyncio.BaseTransport, ResponseHandler]: + headers: dict[str, str] = {} if req.proxy_headers is not None: headers = req.proxy_headers # type: ignore[assignment] headers[hdrs.HOST] = req.headers[hdrs.HOST] @@ -1586,7 +1567,7 @@ def __init__( self, path: str, force_close: bool = False, - keepalive_timeout: Union[_SENTINEL, float, None] = sentinel, + keepalive_timeout: _SENTINEL | float | None = sentinel, limit: int = 100, limit_per_host: int = 0, ) -> None: @@ -1604,7 +1585,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( @@ -1642,7 +1623,7 @@ def __init__( self, path: str, force_close: bool = False, - keepalive_timeout: Union[_SENTINEL, float, None] = sentinel, + keepalive_timeout: _SENTINEL | float | None = sentinel, limit: int = 100, limit_per_host: int = 0, ) -> None: @@ -1667,7 +1648,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index 7e8b206076e..792583938f2 100644 --- a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -10,20 +10,9 @@ import time import warnings from collections import defaultdict -from collections.abc import Mapping +from collections.abc import Iterable, Iterator, Mapping from 
http.cookies import BaseCookie, Morsel, SimpleCookie -from typing import ( - DefaultDict, - Dict, - FrozenSet, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import Union from yarl import URL @@ -74,10 +63,9 @@ class CookieJar(AbstractCookieJar): ) try: calendar.timegm(time.gmtime(MAX_TIME)) - except (OSError, ValueError): + except OSError: # Hit the maximum representable time on Windows # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64 - # Throws ValueError on PyPy 3.9, OSError elsewhere MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1)) except OverflowError: # #4515: datetime.max may not be representable on 32-bit platforms @@ -90,19 +78,19 @@ def __init__( *, unsafe: bool = False, quote_cookie: bool = True, - treat_as_secure_origin: Union[StrOrURL, Iterable[StrOrURL], None] = None, + treat_as_secure_origin: StrOrURL | Iterable[StrOrURL] | None = None, ) -> None: - self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict( + self._cookies: defaultdict[tuple[str, str], SimpleCookie] = defaultdict( SimpleCookie ) - self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = ( + self._morsel_cache: defaultdict[tuple[str, str], dict[str, Morsel[str]]] = ( defaultdict(dict) ) - self._host_only_cookies: Set[Tuple[str, str]] = set() + self._host_only_cookies: set[tuple[str, str]] = set() self._unsafe = unsafe self._quote_cookie = quote_cookie if treat_as_secure_origin is None: - self._treat_as_secure_origin: FrozenSet[URL] = frozenset() + self._treat_as_secure_origin: frozenset[URL] = frozenset() elif isinstance(treat_as_secure_origin, URL): self._treat_as_secure_origin = frozenset({treat_as_secure_origin.origin()}) elif isinstance(treat_as_secure_origin, str): @@ -116,8 +104,8 @@ def __init__( for url in treat_as_secure_origin } ) - self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = [] - self._expirations: Dict[Tuple[str, 
str, str], float] = {} + self._expire_heap: list[tuple[float, tuple[str, str, str]]] = [] + self._expirations: dict[tuple[str, str, str], float] = {} @property def quote_cookie(self) -> bool: @@ -133,7 +121,7 @@ def load(self, file_path: PathLike) -> None: with file_path.open(mode="rb") as f: self._cookies = pickle.load(f) - def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: + def clear(self, predicate: ClearCookiePredicate | None = None) -> None: if predicate is None: self._expire_heap.clear() self._cookies.clear() @@ -198,7 +186,7 @@ def _do_expiration(self) -> None: heapq.heapify(self._expire_heap) now = time.time() - to_del: List[Tuple[str, str, str]] = [] + to_del: list[tuple[str, str, str]] = [] # Find any expired cookies and add them to the to-delete list while self._expire_heap: when, cookie_key = self._expire_heap[0] @@ -215,7 +203,7 @@ def _do_expiration(self) -> None: if to_del: self._delete_cookies(to_del) - def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None: + def _delete_cookies(self, to_del: list[tuple[str, str, str]]) -> None: for domain, path, name in to_del: self._host_only_cookies.discard((domain, name)) self._cookies[(domain, path)].pop(name, None) @@ -308,9 +296,7 @@ def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": """Returns this jar's cookies filtered by their attributes.""" if not isinstance(request_url, URL): warnings.warn( # type: ignore[unreachable] - "The method accepts yarl.URL instances only, got {}".format( - type(request_url) - ), + f"The method accepts yarl.URL instances only, got {type(request_url)}", DeprecationWarning, ) request_url = URL(request_url) @@ -425,7 +411,7 @@ def _is_domain_match(domain: str, hostname: str) -> bool: return not is_ip_address(hostname) @classmethod - def _parse_date(cls, date_str: str) -> Optional[int]: + def _parse_date(cls, date_str: str) -> int | None: """Implements date string parsing adhering to RFC 6265.""" if not date_str: return None @@ 
-506,7 +492,7 @@ def __len__(self) -> int: def quote_cookie(self) -> bool: return True - def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: + def clear(self, predicate: ClearCookiePredicate | None = None) -> None: pass def clear_domain(self, domain: str) -> None: diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py index 27d395a2a04..d49d693a4ec 100644 --- a/aiohttp/formdata.py +++ b/aiohttp/formdata.py @@ -1,5 +1,6 @@ import io -from typing import Any, Iterable, List, Optional +from collections.abc import Iterable +from typing import Any from urllib.parse import urlencode from multidict import MultiDict, MultiDictProxy @@ -21,14 +22,14 @@ def __init__( self, fields: Iterable[Any] = (), quote_fields: bool = True, - charset: Optional[str] = None, - boundary: Optional[str] = None, + charset: str | None = None, + boundary: str | None = None, *, default_to_multipart: bool = False, ) -> None: self._boundary = boundary self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary) - self._fields: List[Any] = [] + self._fields: list[Any] = [] self._is_multipart = default_to_multipart self._quote_fields = quote_fields self._charset = charset @@ -48,8 +49,8 @@ def add_field( name: str, value: Any, *, - content_type: Optional[str] = None, - filename: Optional[str] = None, + content_type: str | None = None, + filename: str | None = None, ) -> None: if isinstance(value, (io.IOBase, bytes, bytearray, memoryview)): self._is_multipart = True @@ -95,7 +96,7 @@ def add_fields(self, *fields: Any) -> None: raise TypeError( "Only io.IOBase, multidict and (name, file) " "pairs allowed, use .add_field() for passing " - "more complex parameters, got {!r}".format(rec) + f"more complex parameters, got {rec!r}" ) def _gen_form_urlencoded(self) -> payload.BytesPayload: diff --git a/aiohttp/hdrs.py b/aiohttp/hdrs.py index c8d6b35f33a..b64b62ee7f2 100644 --- a/aiohttp/hdrs.py +++ b/aiohttp/hdrs.py @@ -3,7 +3,7 @@ # After changing the file content call 
./tools/gen.py # to regenerate the headers parser import itertools -from typing import Final, Set +from typing import Final from multidict import istr @@ -18,7 +18,7 @@ METH_PUT: Final[str] = "PUT" METH_TRACE: Final[str] = "TRACE" -METH_ALL: Final[Set[str]] = { +METH_ALL: Final[set[str]] = { METH_CONNECT, METH_HEAD, METH_GET, diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 97a04b6d6bd..48389264186 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -18,6 +18,7 @@ import warnings import weakref from collections import namedtuple +from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress from email.parser import HeaderParser from email.utils import parsedate @@ -28,18 +29,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, ContextManager, - Dict, Generic, - Iterable, - Iterator, - List, - Mapping, Optional, Protocol, - Tuple, - Type, TypeVar, Union, final, @@ -64,8 +57,6 @@ if TYPE_CHECKING: from dataclasses import dataclass as frozen_dataclass_decorator -elif sys.version_info < (3, 10): - frozen_dataclass_decorator = functools.partial(dataclasses.dataclass, frozen=True) else: frozen_dataclass_decorator = functools.partial( dataclasses.dataclass, frozen=True, slots=True @@ -73,8 +64,6 @@ __all__ = ("BasicAuth", "ChainMapProxy", "ETag", "frozen_dataclass_decorator", "reify") -PY_310 = sys.version_info >= (3, 10) - COOKIE_MAX_LENGTH = 4096 _T = TypeVar("_T") @@ -190,7 +179,7 @@ def encode(self) -> str: return "Basic %s" % base64.b64encode(creds).decode(self.encoding) -def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: +def strip_auth_from_url(url: URL) -> tuple[URL, BasicAuth | None]: """Remove user and password from URL if present and return BasicAuth object.""" # Check raw_user and raw_password first as yarl is likely # to already have these values parsed from the netloc in the cache. 
@@ -199,7 +188,7 @@ def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: return url.with_user(None), BasicAuth(url.user or "", url.password or "") -def netrc_from_env() -> Optional[netrc.netrc]: +def netrc_from_env() -> netrc.netrc | None: """Load netrc from file. Attempt to load it from the path specified by the env-var @@ -247,10 +236,10 @@ def netrc_from_env() -> Optional[netrc.netrc]: @frozen_dataclass_decorator class ProxyInfo: proxy: URL - proxy_auth: Optional[BasicAuth] + proxy_auth: BasicAuth | None -def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth: +def basicauth_from_netrc(netrc_obj: netrc.netrc | None, host: str) -> BasicAuth: """ Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``. @@ -279,7 +268,7 @@ def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAu return BasicAuth(username, password) -def proxies_from_env() -> Dict[str, ProxyInfo]: +def proxies_from_env() -> dict[str, ProxyInfo]: proxy_urls = { k: URL(v) for k, v in getproxies().items() @@ -305,7 +294,7 @@ def proxies_from_env() -> Dict[str, ProxyInfo]: return ret -def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: +def get_env_proxy_for_url(url: URL) -> tuple[URL, BasicAuth | None]: """Get a permitted proxy for the given URL from the env.""" if url.host is not None and proxy_bypass(url.host): raise LookupError(f"Proxying is disallowed for `{url.host!r}`") @@ -368,7 +357,7 @@ def parse_mimetype(mimetype: str) -> MimeType: @functools.lru_cache(maxsize=56) -def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]: +def parse_content_type(raw: str) -> tuple[str, MappingProxyType[str, str]]: """Parse Content-Type header. 
Returns a tuple of the parsed content type and a @@ -381,7 +370,7 @@ def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]: return content_type, MappingProxyType(content_dict) -def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]: +def guess_filename(obj: Any, default: str | None = None) -> str | None: name = getattr(obj, "name", None) if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": return Path(name).name @@ -409,7 +398,7 @@ def content_disposition_header( disptype: str, quote_fields: bool = True, _charset: str = "utf-8", - params: Optional[Dict[str, str]] = None, + params: dict[str, str] | None = None, ) -> str: """Sets ``Content-Disposition`` header for MIME. @@ -471,7 +460,7 @@ def is_expected_content_type( return expected_content_type in response_content_type -def is_ip_address(host: Optional[str]) -> bool: +def is_ip_address(host: str | None) -> bool: """Check if host looks like an IP Address. This check is only meant as a heuristic to ensure that @@ -484,7 +473,7 @@ def is_ip_address(host: Optional[str]) -> bool: return ":" in host or host.replace(".", "").isdigit() -_cached_current_datetime: Optional[int] = None +_cached_current_datetime: int | None = None _cached_formatted_datetime = "" @@ -528,7 +517,7 @@ def rfc822_formatted_time() -> str: return _cached_formatted_datetime -def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None: +def _weakref_handle(info: "tuple[weakref.ref[object], str]") -> None: ref, name = info ob = ref() if ob is not None: @@ -539,10 +528,10 @@ def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None: def weakref_handle( ob: object, name: str, - timeout: Optional[float], + timeout: float | None, loop: asyncio.AbstractEventLoop, timeout_ceil_threshold: float = 5, -) -> Optional[asyncio.TimerHandle]: +) -> asyncio.TimerHandle | None: if timeout is not None and timeout > 0: when = loop.time() + timeout if timeout >= timeout_ceil_threshold: 
@@ -554,10 +543,10 @@ def weakref_handle( def call_later( cb: Callable[[], Any], - timeout: Optional[float], + timeout: float | None, loop: asyncio.AbstractEventLoop, timeout_ceil_threshold: float = 5, -) -> Optional[asyncio.TimerHandle]: +) -> asyncio.TimerHandle | None: if timeout is None or timeout <= 0: return None now = loop.time() @@ -585,14 +574,14 @@ class TimeoutHandle: def __init__( self, loop: asyncio.AbstractEventLoop, - timeout: Optional[float], + timeout: float | None, ceil_threshold: float = 5, ) -> None: self._timeout = timeout self._loop = loop self._ceil_threshold = ceil_threshold - self._callbacks: List[ - Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] + self._callbacks: list[ + tuple[Callable[..., None], tuple[Any, ...], dict[str, Any]] ] = [] def register( @@ -603,7 +592,7 @@ def register( def close(self) -> None: self._callbacks.clear() - def start(self) -> Optional[asyncio.TimerHandle]: + def start(self) -> asyncio.TimerHandle | None: timeout = self._timeout if timeout is not None and timeout > 0: when = self._loop.time() + timeout @@ -646,9 +635,9 @@ def __enter__(self) -> BaseTimerContext: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: return @@ -660,7 +649,7 @@ class TimerContext(BaseTimerContext): def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop - self._tasks: List[asyncio.Task[Any]] = [] + self._tasks: list[asyncio.Task[Any]] = [] self._cancelled = False self._cancelling = 0 @@ -688,11 +677,11 @@ def __enter__(self) -> BaseTimerContext: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - enter_task: Optional[asyncio.Task[Any]] = None + exc_type: type[BaseException] | None, + exc_val: 
BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + enter_task: asyncio.Task[Any] | None = None if self._tasks: enter_task = self._tasks.pop() @@ -719,7 +708,7 @@ def timeout(self) -> None: def ceil_timeout( - delay: Optional[float], ceil_threshold: float = 5 + delay: float | None, ceil_threshold: float = 5 ) -> async_timeout.Timeout: if delay is None or delay <= 0: return async_timeout.timeout(None) @@ -736,11 +725,11 @@ class HeadersMixin: """Mixin for handling headers.""" _headers: MultiMapping[str] - _content_type: Optional[str] = None - _content_dict: Optional[Dict[str, str]] = None - _stored_content_type: Union[str, None, _SENTINEL] = sentinel + _content_type: str | None = None + _content_dict: dict[str, str] | None = None + _stored_content_type: str | None | _SENTINEL = sentinel - def _parse_content_type(self, raw: Optional[str]) -> None: + def _parse_content_type(self, raw: str | None) -> None: self._stored_content_type = raw if raw is None: # default value according to RFC 2616 @@ -762,7 +751,7 @@ def content_type(self) -> str: return self._content_type @property - def charset(self) -> Optional[str]: + def charset(self) -> str | None: """The value of charset part for Content-Type HTTP header.""" raw = self._headers.get(hdrs.CONTENT_TYPE) if self._stored_content_type != raw: @@ -771,7 +760,7 @@ def charset(self) -> Optional[str]: return self._content_dict.get("charset") @property - def content_length(self) -> Optional[int]: + def content_length(self) -> int | None: """The value of Content-Length HTTP header.""" content_length = self._headers.get(hdrs.CONTENT_LENGTH) return None if content_length is None else int(content_length) @@ -788,14 +777,14 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None: class ErrorableProtocol(Protocol): def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = ..., ) -> None: ... 
def set_exception( fut: Union["asyncio.Future[_T]", ErrorableProtocol], - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: """Set future exception. @@ -825,10 +814,10 @@ class AppKey(Generic[_T]): # This may be set by Python when instantiating with a generic type. We need to # support this, in order to support types that are not concrete classes, # like Iterable, which can't be passed as the second parameter to __init__. - __orig_class__: Type[object] + __orig_class__: type[object] # TODO(PY314): Change Type to TypeForm (this should resolve unreachable below). - def __init__(self, name: str, t: Optional[Type[_T]] = None): + def __init__(self, name: str, t: type[_T] | None = None): # Prefix with module name to help deduplicate key names. frame = inspect.currentframe() while frame: @@ -868,16 +857,15 @@ def __repr__(self) -> str: @final -class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]): +class ChainMapProxy(Mapping[str | AppKey[Any], Any]): __slots__ = ("_maps",) - def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None: + def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None: self._maps = tuple(maps) def __init_subclass__(cls) -> None: raise TypeError( - "Inheritance class {} from ChainMapProxy " - "is forbidden".format(cls.__name__) + f"Inheritance class {cls.__name__} from ChainMapProxy " "is forbidden" ) @overload # type: ignore[override] @@ -886,7 +874,7 @@ def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload def __getitem__(self, key: str) -> Any: ... 
- def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: + def __getitem__(self, key: str | AppKey[_T]) -> Any: for mapping in self._maps: try: return mapping[key] @@ -895,15 +883,15 @@ def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: raise KeyError(key) @overload # type: ignore[override] - def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ... + def get(self, key: AppKey[_T], default: _S) -> _T | _S: ... @overload - def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... + def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ... @overload def get(self, key: str, default: Any = ...) -> Any: ... - def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: + def get(self, key: str | AppKey[_T], default: Any = None) -> Any: try: return self[key] except KeyError: @@ -913,8 +901,8 @@ def __len__(self) -> int: # reuses stored hash values if possible return len(set().union(*self._maps)) - def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: - d: Dict[Union[str, AppKey[Any]], Any] = {} + def __iter__(self) -> Iterator[str | AppKey[Any]]: + d: dict[str | AppKey[Any], Any] = {} for mapping in reversed(self._maps): # reuses stored hash values if possible d.update(mapping) @@ -934,7 +922,7 @@ def __repr__(self) -> str: class CookieMixin: """Mixin for handling cookies.""" - _cookies: Optional[SimpleCookie] = None + _cookies: SimpleCookie | None = None @property def cookies(self) -> SimpleCookie: @@ -947,14 +935,14 @@ def set_cookie( name: str, value: str, *, - expires: Optional[str] = None, - domain: Optional[str] = None, - max_age: Optional[Union[int, str]] = None, + expires: str | None = None, + domain: str | None = None, + max_age: int | str | None = None, path: str = "/", - secure: Optional[bool] = None, - httponly: Optional[bool] = None, - samesite: Optional[str] = None, - partitioned: Optional[bool] = None, + secure: bool | None = None, + httponly: bool | None = None, + samesite: str | None = None, + 
partitioned: bool | None = None, ) -> None: """Set or update response cookie. @@ -1005,11 +993,11 @@ def del_cookie( self, name: str, *, - domain: Optional[str] = None, + domain: str | None = None, path: str = "/", - secure: Optional[bool] = None, - httponly: Optional[bool] = None, - samesite: Optional[str] = None, + secure: bool | None = None, + httponly: bool | None = None, + samesite: str | None = None, ) -> None: """Delete cookie. @@ -1060,7 +1048,7 @@ def validate_etag_value(value: str) -> None: ) -def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]: +def parse_http_date(date_str: str | None) -> datetime.datetime | None: """Process a date string, return a datetime object""" if date_str is not None: timetuple = parsedate(date_str) diff --git a/aiohttp/http.py b/aiohttp/http.py index 244d71c4197..6dad94bb11c 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -55,6 +55,6 @@ ) -SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format( - sys.version_info, __version__ +SERVER_SOFTWARE: str = ( + f"Python/{sys.version_info[0]}.{sys.version_info[1]} aiohttp/{__version__}" ) diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index 51f19166d40..200cfeb3c68 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -1,7 +1,6 @@ """Low-level http related exceptions.""" from textwrap import indent -from typing import Optional, Union from .typedefs import _CIMultiDict @@ -25,9 +24,9 @@ class HttpProcessingError(Exception): def __init__( self, *, - code: Optional[int] = None, + code: int | None = None, message: str = "", - headers: Optional[_CIMultiDict] = None, + headers: _CIMultiDict | None = None, ) -> None: if code is not None: self.code = code @@ -46,7 +45,7 @@ class BadHttpMessage(HttpProcessingError): code = 400 message = "Bad Request" - def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None: + def __init__(self, message: str, *, headers: _CIMultiDict | None = None) -> 
None: super().__init__(message=message, headers=headers) self.args = (message,) @@ -83,7 +82,7 @@ def __init__( class InvalidHeader(BadHttpMessage): - def __init__(self, hdr: Union[bytes, str]) -> None: + def __init__(self, hdr: bytes | str) -> None: hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr super().__init__(f"Invalid HTTP header: {hdr!r}") self.hdr = hdr_s @@ -91,7 +90,7 @@ def __init__(self, hdr: Union[bytes, str]) -> None: class BadStatusLine(BadHttpMessage): - def __init__(self, line: str = "", error: Optional[str] = None) -> None: + def __init__(self, line: str = "", error: str | None = None) -> None: super().__init__(error or f"Bad status line {line!r}") self.args = (line,) self.line = line @@ -100,7 +99,7 @@ def __init__(self, line: str = "", error: Optional[str] = None) -> None: class BadHttpMethod(BadStatusLine): """Invalid HTTP method in status line.""" - def __init__(self, line: str = "", error: Optional[str] = None) -> None: + def __init__(self, line: str = "", error: str | None = None) -> None: super().__init__(line, error or f"Bad HTTP method in status line {line!r}") diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index e50fc5fdcc1..ce143819141 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -4,22 +4,8 @@ import string from contextlib import suppress from enum import IntEnum -from typing import ( - Any, - ClassVar, - Final, - Generic, - List, - Literal, - NamedTuple, - Optional, - Pattern, - Set, - Tuple, - Type, - TypeVar, - Union, -) +from re import Pattern +from typing import Any, ClassVar, Final, Generic, Literal, NamedTuple, TypeVar from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL @@ -68,7 +54,7 @@ _SEP = Literal[b"\r\n", b"\n"] -ASCIISET: Final[Set[str]] = set(string.printable) +ASCIISET: Final[set[str]] = set(string.printable) # See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview # and 
https://www.rfc-editor.org/rfc/rfc9110.html#name-tokens @@ -91,7 +77,7 @@ class RawRequestMessage(NamedTuple): headers: CIMultiDictProxy[str] raw_headers: RawHeaders should_close: bool - compression: Optional[str] + compression: str | None upgrade: bool chunked: bool url: URL @@ -104,7 +90,7 @@ class RawResponseMessage(NamedTuple): headers: CIMultiDictProxy[str] raw_headers: RawHeaders should_close: bool - compression: Optional[str] + compression: str | None upgrade: bool chunked: bool @@ -135,8 +121,8 @@ def __init__( self._lax = lax def parse_headers( - self, lines: List[bytes] - ) -> Tuple["CIMultiDictProxy[str]", RawHeaders]: + self, lines: list[bytes] + ) -> tuple["CIMultiDictProxy[str]", RawHeaders]: headers: CIMultiDict[str] = CIMultiDict() # note: "raw" does not mean inclusion of OWS before/after the field value raw_headers = [] @@ -244,10 +230,10 @@ def __init__( limit: int, max_line_size: int = 8190, max_field_size: int = 8190, - timer: Optional[BaseTimerContext] = None, - code: Optional[int] = None, - method: Optional[str] = None, - payload_exception: Optional[Type[BaseException]] = None, + timer: BaseTimerContext | None = None, + code: int | None = None, + method: str | None = None, + payload_exception: type[BaseException] | None = None, response_with_body: bool = True, read_until_eof: bool = False, auto_decompress: bool = True, @@ -263,22 +249,22 @@ def __init__( self.response_with_body = response_with_body self.read_until_eof = read_until_eof - self._lines: List[bytes] = [] + self._lines: list[bytes] = [] self._tail = b"" self._upgraded = False self._payload = None - self._payload_parser: Optional[HttpPayloadParser] = None + self._payload_parser: HttpPayloadParser | None = None self._auto_decompress = auto_decompress self._limit = limit self._headers_parser = HeadersParser(max_line_size, max_field_size, self.lax) @abc.abstractmethod - def parse_message(self, lines: List[bytes]) -> _MsgT: ... + def parse_message(self, lines: list[bytes]) -> _MsgT: ... 
@abc.abstractmethod def _is_chunked_te(self, te: str) -> bool: ... - def feed_eof(self) -> Optional[_MsgT]: + def feed_eof(self) -> _MsgT | None: if self._payload_parser is not None: self._payload_parser.feed_eof() self._payload_parser = None @@ -302,7 +288,7 @@ def feed_data( CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, METH_CONNECT: str = hdrs.METH_CONNECT, SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, - ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]: + ) -> tuple[list[tuple[_MsgT, StreamReader]], bool, bytes]: messages = [] if self._tail: @@ -341,7 +327,7 @@ def feed_data( finally: self._lines.clear() - def get_content_length() -> Optional[int]: + def get_content_length() -> int | None: # payload length length_hdr = msg.headers.get(CONTENT_LENGTH) if length_hdr is None: @@ -490,9 +476,9 @@ def get_content_length() -> Optional[int]: return messages, self._upgraded, data def parse_headers( - self, lines: List[bytes] - ) -> Tuple[ - "CIMultiDictProxy[str]", RawHeaders, Optional[bool], Optional[str], bool, bool + self, lines: list[bytes] + ) -> tuple[ + "CIMultiDictProxy[str]", RawHeaders, bool | None, str | None, bool, bool ]: """Parses RFC 5322 headers from a stream. @@ -571,7 +557,7 @@ class HttpRequestParser(HttpParser[RawRequestMessage]): Returns RawRequestMessage. 
""" - def parse_message(self, lines: List[bytes]) -> RawRequestMessage: + def parse_message(self, lines: list[bytes]) -> RawRequestMessage: # request line line = lines[0].decode("utf-8", "surrogateescape") try: @@ -676,15 +662,15 @@ class HttpResponseParser(HttpParser[RawResponseMessage]): def feed_data( self, data: bytes, - SEP: Optional[_SEP] = None, + SEP: _SEP | None = None, *args: Any, **kwargs: Any, - ) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]: + ) -> tuple[list[tuple[RawResponseMessage, StreamReader]], bool, bytes]: if SEP is None: SEP = b"\r\n" if DEBUG else b"\n" return super().feed_data(data, SEP, *args, **kwargs) - def parse_message(self, lines: List[bytes]) -> RawResponseMessage: + def parse_message(self, lines: list[bytes]) -> RawResponseMessage: line = lines[0].decode("utf-8", "surrogateescape") try: version, status = line.split(maxsplit=1) @@ -756,11 +742,11 @@ class HttpPayloadParser: def __init__( self, payload: StreamReader, - length: Optional[int] = None, + length: int | None = None, chunked: bool = False, - compression: Optional[str] = None, - code: Optional[int] = None, - method: Optional[str] = None, + compression: str | None = None, + code: int | None = None, + method: str | None = None, response_with_body: bool = True, auto_decompress: bool = True, lax: bool = False, @@ -780,7 +766,7 @@ def __init__( # payload decompression wrapper if response_with_body and compression and self._auto_decompress: - real_payload: Union[StreamReader, DeflateBuffer] = DeflateBuffer( + real_payload: StreamReader | DeflateBuffer = DeflateBuffer( payload, compression ) else: @@ -817,7 +803,7 @@ def feed_eof(self) -> None: def feed_data( self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";" - ) -> Tuple[bool, bytes]: + ) -> tuple[bool, bytes]: # Read specified amount of bytes if self._type == ParseState.PARSE_LENGTH: required = self._length @@ -933,14 +919,14 @@ def feed_data( class DeflateBuffer: """DeflateStream decompress 
stream and feed data into specified stream.""" - def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: + def __init__(self, out: StreamReader, encoding: str | None) -> None: self.out = out self.size = 0 out.total_compressed_bytes = self.size self.encoding = encoding self._started_decoding = False - self.decompressor: Union[BrotliDecompressor, ZLibDecompressor, ZSTDDecompressor] + self.decompressor: BrotliDecompressor | ZLibDecompressor | ZSTDDecompressor if encoding == "br": if not HAS_BROTLI: raise ContentEncodingError( @@ -960,7 +946,7 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: set_exception(self.out, exc, exc_cause) diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index c290d81db59..f1bafff86f4 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -56,10 +56,10 @@ class HttpVersion(NamedTuple): class StreamWriter(AbstractStreamWriter): - length: Optional[int] = None + length: int | None = None chunked: bool = False _eof: bool = False - _compress: Optional[ZLibCompressor] = None + _compress: ZLibCompressor | None = None def __init__( self, @@ -72,11 +72,11 @@ def __init__( self.loop = loop self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent self._on_headers_sent: _T_OnHeadersSent = on_headers_sent - self._headers_buf: Optional[bytes] = None + self._headers_buf: bytes | None = None self._headers_written: bool = False @property - def transport(self) -> Optional[asyncio.Transport]: + def transport(self) -> asyncio.Transport | None: return self._protocol.transport @property @@ -87,7 +87,7 @@ def enable_chunking(self) -> None: self.chunked = True def enable_compression( - self, encoding: str = "deflate", strategy: Optional[int] = None + self, encoding: str = "deflate", strategy: int | None = None ) -> None: self._compress = 
ZLibCompressor(encoding=encoding, strategy=strategy) @@ -281,7 +281,7 @@ async def write_eof(self, chunk: bytes = b"") -> None: # Handle body/compression if self._compress: - chunks: List[bytes] = [] + chunks: list[bytes] = [] chunks_len = 0 if chunk and (compressed_chunk := await self._compress.compress(chunk)): chunks_len = len(compressed_chunk) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index a3a40080b12..15901e89e0a 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -6,21 +6,9 @@ import uuid import warnings from collections import deque -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Deque, - Dict, - Iterator, - List, - Optional, - Tuple, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Union, cast from urllib.parse import parse_qsl, unquote, urlencode from multidict import CIMultiDict, CIMultiDictProxy @@ -78,8 +66,8 @@ class BadContentDispositionParam(RuntimeWarning): def parse_content_disposition( - header: Optional[str], -) -> Tuple[Optional[str], Dict[str, str]]: + header: str | None, +) -> tuple[str | None, dict[str, str]]: def is_token(string: str) -> bool: return bool(string) and TOKEN >= set(string) @@ -110,7 +98,7 @@ def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str: warnings.warn(BadContentDispositionHeader(header)) return None, {} - params: Dict[str, str] = {} + params: dict[str, str] = {} while parts: item = parts.pop(0) @@ -182,7 +170,7 @@ def unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str: def content_disposition_filename( params: Mapping[str, str], name: str = "filename" -) -> Optional[str]: +) -> str | None: name_suf = "%s*" % name if not params: return None @@ -245,7 +233,7 @@ def at_eof(self) -> bool: async def next( self, - ) -> Optional[Union["MultipartReader", "BodyPartReader"]]: + ) -> 
Union["MultipartReader", "BodyPartReader"] | None: """Emits next multipart reader object.""" item = await self.stream.next() if self.stream.at_eof(): @@ -272,7 +260,7 @@ def __init__( content: StreamReader, *, subtype: str = "mixed", - default_charset: Optional[str] = None, + default_charset: str | None = None, ) -> None: self.headers = headers self._boundary = boundary @@ -285,10 +273,10 @@ def __init__( length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None) self._length = int(length) if length is not None else None self._read_bytes = 0 - self._unread: Deque[bytes] = deque() - self._prev_chunk: Optional[bytes] = None + self._unread: deque[bytes] = deque() + self._prev_chunk: bytes | None = None self._content_eof = 0 - self._cache: Dict[str, Any] = {} + self._cache: dict[str, Any] = {} def __aiter__(self) -> Self: return self @@ -299,7 +287,7 @@ async def __anext__(self) -> bytes: raise StopAsyncIteration return part - async def next(self) -> Optional[bytes]: + async def next(self) -> bytes | None: item = await self.read() if not item: return None @@ -459,7 +447,7 @@ async def release(self) -> None: while not self._at_eof: await self.read_chunk(self.chunk_size) - async def text(self, *, encoding: Optional[str] = None) -> str: + async def text(self, *, encoding: str | None = None) -> str: """Like read(), but assumes that body part contains text data.""" data = await self.read(decode=True) # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm @@ -467,15 +455,15 @@ async def text(self, *, encoding: Optional[str] = None) -> str: encoding = encoding or self.get_charset(default="utf-8") return data.decode(encoding) - async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]: + async def json(self, *, encoding: str | None = None) -> dict[str, Any] | None: """Like read(), but assumes that body parts contains JSON data.""" data = await self.read(decode=True) if not data: return None encoding = 
encoding or self.get_charset(default="utf-8") - return cast(Dict[str, Any], json.loads(data.decode(encoding))) + return cast(dict[str, Any], json.loads(data.decode(encoding))) - async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]: + async def form(self, *, encoding: str | None = None) -> list[tuple[str, str]]: """Like read(), but assumes that body parts contain form urlencoded data.""" data = await self.read(decode=True) if not data: @@ -543,7 +531,7 @@ def get_charset(self, default: str) -> str: return mimetype.parameters.get("charset", self._default_charset or default) @reify - def name(self) -> Optional[str]: + def name(self) -> str | None: """Returns name specified in Content-Disposition header. If the header is missing or malformed, returns None. @@ -552,7 +540,7 @@ def name(self) -> Optional[str]: return content_disposition_filename(params, "name") @reify - def filename(self) -> Optional[str]: + def filename(self) -> str | None: """Returns filename specified in Content-Disposition header. Returns None if the header is missing or malformed. @@ -569,7 +557,7 @@ class BodyPartReaderPayload(Payload): def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: super().__init__(value, *args, **kwargs) - params: Dict[str, str] = {} + params: dict[str, str] = {} if value.name is not None: params["name"] = value.name if value.filename is not None: @@ -606,7 +594,7 @@ class MultipartReader: response_wrapper_cls = MultipartResponseWrapper #: Multipart reader class, used to handle multipart/* body parts. #: None points to type(self) - multipart_reader_cls: Optional[Type["MultipartReader"]] = None + multipart_reader_cls: type["MultipartReader"] | None = None #: Body part reader class for non multipart/* content types. 
part_reader_cls = BodyPartReader @@ -621,18 +609,18 @@ def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: self.headers = headers self._boundary = ("--" + self._get_boundary()).encode() self._content = content - self._default_charset: Optional[str] = None - self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None + self._default_charset: str | None = None + self._last_part: MultipartReader | BodyPartReader | None = None self._at_eof = False self._at_bof = True - self._unread: List[bytes] = [] + self._unread: list[bytes] = [] def __aiter__(self) -> Self: return self async def __anext__( self, - ) -> Optional[Union["MultipartReader", BodyPartReader]]: + ) -> Union["MultipartReader", BodyPartReader] | None: part = await self.next() if part is None: raise StopAsyncIteration @@ -658,7 +646,7 @@ def at_eof(self) -> bool: async def next( self, - ) -> Optional[Union["MultipartReader", BodyPartReader]]: + ) -> Union["MultipartReader", BodyPartReader] | None: """Emits the next multipart body part.""" # So, if we're at BOF, we need to skip till the boundary. if self._at_eof: @@ -801,7 +789,7 @@ async def _maybe_release_last_part(self) -> None: self._last_part = None -_Part = Tuple[Payload, str, str] +_Part = tuple[Payload, str, str] class MultipartWriter(Payload): @@ -811,7 +799,7 @@ class MultipartWriter(Payload): # _consumed = False (inherited) - Can be encoded multiple times _autoclose = True # No file handles, just collects parts in memory - def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None: + def __init__(self, subtype: str = "mixed", boundary: str | None = None) -> None: boundary = boundary if boundary is not None else uuid.uuid4().hex # The underlying Payload API demands a str (utf-8), not bytes, # so we need to ensure we don't lose anything during conversion. 
@@ -830,7 +818,7 @@ def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> No super().__init__(None, content_type=ctype) - self._parts: List[_Part] = [] + self._parts: list[_Part] = [] self._is_form_data = subtype == "form-data" def __enter__(self) -> "MultipartWriter": @@ -838,9 +826,9 @@ def __enter__(self) -> "MultipartWriter": def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: pass @@ -892,7 +880,7 @@ def _boundary_value(self) -> str: def boundary(self) -> str: return self._boundary.decode("ascii") - def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload: + def append(self, obj: Any, headers: Mapping[str, str] | None = None) -> Payload: if headers is None: headers = CIMultiDict() @@ -909,8 +897,8 @@ def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Paylo def append_payload(self, payload: Payload) -> Payload: """Adds a new body part to multipart writer.""" - encoding: Optional[str] = None - te_encoding: Optional[str] = None + encoding: str | None = None + te_encoding: str | None = None if self._is_form_data: # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7 # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 @@ -946,7 +934,7 @@ def append_payload(self, payload: Payload) -> Payload: return payload def append_json( - self, obj: Any, headers: Optional[Mapping[str, str]] = None + self, obj: Any, headers: Mapping[str, str] | None = None ) -> Payload: """Helper to append JSON part.""" if headers is None: @@ -956,8 +944,8 @@ def append_json( def append_form( self, - obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]], - headers: Optional[Mapping[str, str]] = None, + obj: Sequence[tuple[str, str]] | Mapping[str, str], + headers: Mapping[str, str] | None = None, ) -> Payload: 
"""Helper to append form urlencoded part.""" assert isinstance(obj, (Sequence, Mapping)) @@ -976,7 +964,7 @@ def append_form( ) @property - def size(self) -> Optional[int]: + def size(self) -> int | None: """Size of the payload.""" total = 0 for part, encoding, te_encoding in self._parts: @@ -1016,7 +1004,7 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt This method is async-safe and calls as_bytes on underlying payloads. """ - parts: List[bytes] = [] + parts: list[bytes] = [] # Process each part for part, _e, _te in self._parts: @@ -1094,9 +1082,9 @@ async def close(self) -> None: class MultipartPayloadWriter: def __init__(self, writer: Any) -> None: self._writer = writer - self._encoding: Optional[str] = None - self._compress: Optional[ZLibCompressor] = None - self._encoding_buffer: Optional[bytearray] = None + self._encoding: str | None = None + self._compress: ZLibCompressor | None = None + self._encoding_buffer: bytearray | None = None def enable_encoding(self, encoding: str) -> None: if encoding == "base64": @@ -1106,7 +1094,7 @@ def enable_encoding(self, encoding: str) -> None: self._encoding = "quoted-printable" def enable_compression( - self, encoding: str = "deflate", strategy: Optional[int] = None + self, encoding: str = "deflate", strategy: int | None = None ) -> None: self._compress = ZLibCompressor( encoding=encoding, diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 2ee8b8cb908..2bce326ccb4 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -9,20 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from itertools import chain -from typing import ( - IO, - TYPE_CHECKING, - Any, - Dict, - Final, - List, - Optional, - Set, - TextIO, - Tuple, - Type, - Union, -) +from typing import IO, TYPE_CHECKING, Any, Final, TextIO from multidict import CIMultiDict @@ -56,7 +43,7 @@ TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB READ_SIZE: Final[int] = 2**16 # 64 KB -_CLOSE_FUTURES: 
Set[asyncio.Future[None]] = set() +_CLOSE_FUTURES: set[asyncio.Future[None]] = set() class LookupError(Exception): @@ -74,7 +61,7 @@ def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload": def register_payload( - factory: Type["Payload"], type: Any, *, order: Order = Order.normal + factory: type["Payload"], type: Any, *, order: Order = Order.normal ) -> None: PAYLOAD_REGISTRY.register(factory, type, order=order) @@ -84,13 +71,13 @@ def __init__(self, type: Any, *, order: Order = Order.normal) -> None: self.type = type self.order = order - def __call__(self, factory: Type["Payload"]) -> Type["Payload"]: + def __call__(self, factory: type["Payload"]) -> type["Payload"]: register_payload(factory, self.type, order=self.order) return factory -PayloadType = Type["Payload"] -_PayloadRegistryItem = Tuple[PayloadType, Any] +PayloadType = type["Payload"] +_PayloadRegistryItem = tuple[PayloadType, Any] class PayloadRegistry: @@ -102,16 +89,16 @@ class PayloadRegistry: __slots__ = ("_first", "_normal", "_last", "_normal_lookup") def __init__(self) -> None: - self._first: List[_PayloadRegistryItem] = [] - self._normal: List[_PayloadRegistryItem] = [] - self._last: List[_PayloadRegistryItem] = [] - self._normal_lookup: Dict[Any, PayloadType] = {} + self._first: list[_PayloadRegistryItem] = [] + self._normal: list[_PayloadRegistryItem] = [] + self._last: list[_PayloadRegistryItem] = [] + self._normal_lookup: dict[Any, PayloadType] = {} def get( self, data: Any, *args: Any, - _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain, + _CHAIN: "type[chain[_PayloadRegistryItem]]" = chain, **kwargs: Any, ) -> "Payload": if self._first: @@ -150,19 +137,19 @@ def register( class Payload(ABC): _default_content_type: str = "application/octet-stream" - _size: Optional[int] = None + _size: int | None = None _consumed: bool = False # Default: payload has not been consumed yet _autoclose: bool = False # Default: assume resource needs explicit closing def __init__( self, value: Any, - 
headers: Optional[ - Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]] - ] = None, - content_type: Union[None, str, _SENTINEL] = sentinel, - filename: Optional[str] = None, - encoding: Optional[str] = None, + headers: ( + _CIMultiDict | dict[str, str] | Iterable[tuple[str, str]] | None + ) = None, + content_type: None | str | _SENTINEL = sentinel, + filename: str | None = None, + encoding: str | None = None, **kwargs: Any, ) -> None: self._encoding = encoding @@ -187,7 +174,7 @@ def __init__( self._headers.update(headers) @property - def size(self) -> Optional[int]: + def size(self) -> int | None: """Size of the payload in bytes. Returns the number of bytes that will be transmitted when the payload @@ -197,7 +184,7 @@ def size(self) -> Optional[int]: return self._size @property - def filename(self) -> Optional[str]: + def filename(self) -> str | None: """Filename of the payload.""" return self._filename @@ -216,7 +203,7 @@ def _binary_headers(self) -> bytes: ) @property - def encoding(self) -> Optional[str]: + def encoding(self) -> str | None: """Payload encoding""" return self._encoding @@ -283,7 +270,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: # write_with_length is new in aiohttp 3.12 # it should be overridden by subclasses async def write_with_length( - self, writer: AbstractStreamWriter, content_length: Optional[int] + self, writer: AbstractStreamWriter, content_length: int | None ) -> None: """ Write payload with a specific content length constraint. 
@@ -353,7 +340,7 @@ class BytesPayload(Payload): _autoclose = True # No file handle, just bytes in memory def __init__( - self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any + self, value: bytes | bytearray | memoryview, *args: Any, **kwargs: Any ) -> None: if "content_type" not in kwargs: kwargs["content_type"] = "application/octet-stream" @@ -406,7 +393,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: await writer.write(self._value) async def write_with_length( - self, writer: AbstractStreamWriter, content_length: Optional[int] + self, writer: AbstractStreamWriter, content_length: int | None ) -> None: """ Write bytes payload with a specific content length constraint. @@ -431,8 +418,8 @@ def __init__( self, value: str, *args: Any, - encoding: Optional[str] = None, - content_type: Optional[str] = None, + encoding: str | None = None, + content_type: str | None = None, **kwargs: Any, ) -> None: if encoding is None: @@ -464,7 +451,7 @@ def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None: class IOBasePayload(Payload): _value: io.IOBase # _consumed = False (inherited) - File can be re-read from the same position - _start_position: Optional[int] = None + _start_position: int | None = None # _autoclose = False (inherited) - Has file handle that needs explicit closing def __init__( @@ -494,8 +481,8 @@ def _set_or_restore_start_position(self) -> None: self._consumed = True def _read_and_available_len( - self, remaining_content_len: Optional[int] - ) -> Tuple[Optional[int], bytes]: + self, remaining_content_len: int | None + ) -> tuple[int | None, bytes]: """ Read the file-like object and return both its total size and the first chunk. 
@@ -519,7 +506,7 @@ def _read_and_available_len( min(READ_SIZE, size or READ_SIZE, remaining_content_len or READ_SIZE) ) - def _read(self, remaining_content_len: Optional[int]) -> bytes: + def _read(self, remaining_content_len: int | None) -> bytes: """ Read a chunk of data from the file-like object. @@ -538,7 +525,7 @@ def _read(self, remaining_content_len: Optional[int]) -> bytes: return self._value.read(remaining_content_len or READ_SIZE) # type: ignore[no-any-return] @property - def size(self) -> Optional[int]: + def size(self) -> int | None: """ Size of the payload in bytes. @@ -584,7 +571,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: await self.write_with_length(writer, None) async def write_with_length( - self, writer: AbstractStreamWriter, content_length: Optional[int] + self, writer: AbstractStreamWriter, content_length: int | None ) -> None: """ Write file-like payload with a specific content length constraint. @@ -646,9 +633,9 @@ async def write_with_length( def _should_stop_writing( self, - available_len: Optional[int], + available_len: int | None, total_written_len: int, - remaining_content_len: Optional[int], + remaining_content_len: int | None, ) -> bool: """ Determine if we should stop writing data. @@ -734,8 +721,8 @@ def __init__( self, value: TextIO, *args: Any, - encoding: Optional[str] = None, - content_type: Optional[str] = None, + encoding: str | None = None, + content_type: str | None = None, **kwargs: Any, ) -> None: if encoding is None: @@ -758,8 +745,8 @@ def __init__( ) def _read_and_available_len( - self, remaining_content_len: Optional[int] - ) -> Tuple[Optional[int], bytes]: + self, remaining_content_len: int | None + ) -> tuple[int | None, bytes]: """ Read the text file-like object and return both its total size and the first chunk. 
@@ -788,7 +775,7 @@ def _read_and_available_len( ) return size, chunk.encode(self._encoding) if self._encoding else chunk.encode() - def _read(self, remaining_content_len: Optional[int]) -> bytes: + def _read(self, remaining_content_len: int | None) -> bytes: """ Read a chunk of data from the text file-like object. @@ -866,7 +853,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: return await self.write_with_length(writer, None) async def write_with_length( - self, writer: AbstractStreamWriter, content_length: Optional[int] + self, writer: AbstractStreamWriter, content_length: int | None ) -> None: """ Write BytesIO payload with a specific content length constraint. @@ -953,7 +940,7 @@ def __init__( if TYPE_CHECKING: - from typing import AsyncIterable, AsyncIterator + from collections.abc import AsyncIterable, AsyncIterator _AsyncIterator = AsyncIterator[bytes] _AsyncIterable = AsyncIterable[bytes] @@ -965,9 +952,9 @@ def __init__( class AsyncIterablePayload(Payload): - _iter: Optional[_AsyncIterator] = None + _iter: _AsyncIterator | None = None _value: _AsyncIterable - _cached_chunks: Optional[List[bytes]] = None + _cached_chunks: list[bytes] | None = None # _consumed stays False to allow reuse with cached content _autoclose = True # Iterator doesn't need explicit closing @@ -976,7 +963,7 @@ def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: raise TypeError( "value argument must support " "collections.abc.AsyncIterable interface, " - "got {!r}".format(type(value)) + f"got {type(value)!r}" ) if "content_type" not in kwargs: @@ -1004,7 +991,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: await self.write_with_length(writer, None) async def write_with_length( - self, writer: AbstractStreamWriter, content_length: Optional[int] + self, writer: AbstractStreamWriter, content_length: int | None ) -> None: """ Write async iterable payload with a specific content length constraint. 
@@ -1043,10 +1030,7 @@ async def write_with_length( try: while True: - if sys.version_info >= (3, 10): - chunk = await anext(self._iter) - else: - chunk = await self._iter.__anext__() + chunk = await anext(self._iter) if remaining_bytes is None: await writer.write(chunk) # If we have a content length limit @@ -1084,7 +1068,7 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt return b"" # Read all chunks and cache them - chunks: List[bytes] = [] + chunks: list[bytes] = [] async for chunk in self._iter: chunks.append(chunk) diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index eeb33fafaf9..65f8dc8aa7f 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -2,19 +2,8 @@ import contextlib import inspect import warnings -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Iterator, - Optional, - Protocol, - Type, - TypeVar, - Union, - overload, -) +from collections.abc import Awaitable, Callable, Iterator +from typing import Any, Protocol, TypeVar, overload import pytest @@ -46,7 +35,7 @@ async def __call__( self, __param: Application, *, - server_kwargs: Optional[Dict[str, Any]] = None, + server_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> TestClient[Request, Application]: ... @overload @@ -54,14 +43,14 @@ async def __call__( self, __param: BaseTestServer[_Request], *, - server_kwargs: Optional[Dict[str, Any]] = None, + server_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> TestClient[_Request, None]: ... class AiohttpServer(Protocol): def __call__( - self, app: Application, *, port: Optional[int] = None, **kwargs: Any + self, app: Application, *, port: int | None = None, **kwargs: Any ) -> Awaitable[TestServer]: ... @@ -70,7 +59,7 @@ def __call__( self, handler: _RequestHandler[BaseRequest], *, - port: Optional[int] = None, + port: int | None = None, **kwargs: Any, ) -> Awaitable[RawTestServer]: ... 
@@ -177,7 +166,7 @@ def _runtime_warning_context() -> Iterator[None]: with warnings.catch_warnings(record=True) as _warnings: yield rw = [ - "{w.filename}:{w.lineno}:{w.message}".format(w=w) + f"{w.filename}:{w.lineno}:{w.message}" for w in _warnings if w.category == RuntimeWarning ] @@ -197,7 +186,7 @@ def _runtime_warning_context() -> Iterator[None]: @contextlib.contextmanager def _passthrough_loop_context( - loop: Optional[asyncio.AbstractEventLoop], fast: bool = False + loop: asyncio.AbstractEventLoop | None, fast: bool = False ) -> Iterator[asyncio.AbstractEventLoop]: """Passthrough loop context. @@ -315,7 +304,7 @@ async def go( app: Application, *, host: str = "127.0.0.1", - port: Optional[int] = None, + port: int | None = None, **kwargs: Any, ) -> TestServer: server = TestServer(app, host=host, port=port) @@ -343,7 +332,7 @@ def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawSe async def go( handler: _RequestHandler[BaseRequest], *, - port: Optional[int] = None, + port: int | None = None, **kwargs: Any, ) -> RawTestServer: server = RawTestServer(handler, port=port) @@ -361,7 +350,7 @@ async def finalize() -> None: @pytest.fixture -def aiohttp_client_cls() -> Type[TestClient[Any, Any]]: +def aiohttp_client_cls() -> type[TestClient[Any, Any]]: """ Client class to use in ``aiohttp_client`` factory. @@ -389,7 +378,7 @@ def test_login(aiohttp_client): @pytest.fixture def aiohttp_client( - loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient[Any, Any]] + loop: asyncio.AbstractEventLoop, aiohttp_client_cls: type[TestClient[Any, Any]] ) -> Iterator[AiohttpClient]: """Factory to create a TestClient instance. @@ -403,20 +392,20 @@ def aiohttp_client( async def go( __param: Application, *, - server_kwargs: Optional[Dict[str, Any]] = None, + server_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> TestClient[Request, Application]: ... 
@overload async def go( __param: BaseTestServer[_Request], *, - server_kwargs: Optional[Dict[str, Any]] = None, + server_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> TestClient[_Request, None]: ... async def go( - __param: Union[Application, BaseTestServer[Any]], + __param: Application | BaseTestServer[Any], *, - server_kwargs: Optional[Dict[str, Any]] = None, + server_kwargs: dict[str, Any] | None = None, **kwargs: Any, ) -> TestClient[Any, Any]: # TODO(PY311): Use Unpack to specify ClientSession kwargs and server_kwargs. diff --git a/aiohttp/resolver.py b/aiohttp/resolver.py index b07bd9716f2..8840a1ca3e1 100644 --- a/aiohttp/resolver.py +++ b/aiohttp/resolver.py @@ -1,7 +1,7 @@ import asyncio import socket import weakref -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Optional from .abc import AbstractResolver, ResolveResult @@ -36,7 +36,7 @@ def __init__(self) -> None: async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: infos = await self._loop.getaddrinfo( host, port, @@ -45,7 +45,7 @@ async def resolve( flags=_AI_ADDRCONFIG, ) - hosts: List[ResolveResult] = [] + hosts: list[ResolveResult] = [] for family, _, proto, _, address in infos: if family == socket.AF_INET6: if len(address) < 3: @@ -90,7 +90,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: raise RuntimeError("Resolver requires aiodns library") self._loop = asyncio.get_running_loop() - self._manager: Optional[_DNSResolverManager] = None + self._manager: _DNSResolverManager | None = None # If custom args are provided, create a dedicated resolver instance # This means each AsyncResolver with custom args gets its own # aiodns.DNSResolver instance @@ -103,7 +103,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: async def resolve( self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET - ) -> List[ResolveResult]: + ) 
-> list[ResolveResult]: try: resp = await self._resolver.getaddrinfo( host, @@ -115,9 +115,9 @@ async def resolve( except aiodns.error.DNSError as exc: msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed" raise OSError(None, msg) from exc - hosts: List[ResolveResult] = [] + hosts: list[ResolveResult] = [] for node in resp.nodes: - address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr + address: tuple[bytes, int] | tuple[bytes, int, int, int] = node.addr family = node.family if family == socket.AF_INET6: if len(address) > 3 and address[3]: @@ -184,7 +184,7 @@ def _init(self) -> None: # Use WeakKeyDictionary to allow event loops to be garbage collected self._loop_data: weakref.WeakKeyDictionary[ asyncio.AbstractEventLoop, - tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]], + tuple[aiodns.DNSResolver, weakref.WeakSet[AsyncResolver]], ] = weakref.WeakKeyDictionary() def get_resolver( @@ -200,7 +200,7 @@ def get_resolver( # Create a new resolver and client set for this loop if it doesn't exist if loop not in self._loop_data: resolver = aiodns.DNSResolver(loop=loop) - client_set: weakref.WeakSet["AsyncResolver"] = weakref.WeakSet() + client_set: weakref.WeakSet[AsyncResolver] = weakref.WeakSet() self._loop_data[loop] = (resolver, client_set) else: # Get the existing resolver and client set @@ -232,5 +232,5 @@ def release_resolver( del self._loop_data[loop] -_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]] +_DefaultType = type[AsyncResolver | ThreadedResolver] DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 1b675a1b73d..9d1858a3285 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -1,19 +1,8 @@ import asyncio import collections import warnings -from typing import ( - Awaitable, - Callable, - Deque, - Final, - Generic, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, -) +from collections.abc import 
Awaitable, Callable +from typing import Final, Generic, TypeVar from .base_protocol import BaseProtocol from .helpers import ( @@ -69,7 +58,7 @@ def __init__(self, stream: "StreamReader") -> None: def __aiter__(self) -> "ChunkTupleAsyncStreamIterator": return self - async def __anext__(self) -> Tuple[bytes, bool]: + async def __anext__(self) -> tuple[bytes, bool]: rv = await self._stream.readchunk() if rv == (b"", False): raise StopAsyncIteration @@ -140,7 +129,7 @@ def __init__( protocol: BaseProtocol, limit: int, *, - timer: Optional[BaseTimerContext] = None, + timer: BaseTimerContext | None = None, loop: asyncio.AbstractEventLoop, ) -> None: self._protocol = protocol @@ -149,18 +138,18 @@ def __init__( self._loop = loop self._size = 0 self._cursor = 0 - self._http_chunk_splits: Optional[List[int]] = None - self._buffer: Deque[bytes] = collections.deque() + self._http_chunk_splits: list[int] | None = None + self._buffer: collections.deque[bytes] = collections.deque() self._buffer_offset = 0 self._eof = False - self._waiter: Optional[asyncio.Future[None]] = None - self._eof_waiter: Optional[asyncio.Future[None]] = None - self._exception: Optional[Union[Type[BaseException], BaseException]] = None + self._waiter: asyncio.Future[None] | None = None + self._eof_waiter: asyncio.Future[None] | None = None + self._exception: type[BaseException] | BaseException | None = None self._timer = TimerNoop() if timer is None else timer - self._eof_callbacks: List[Callable[[], None]] = [] + self._eof_callbacks: list[Callable[[], None]] = [] self._eof_counter = 0 self.total_bytes = 0 - self.total_compressed_bytes: Optional[int] = None + self.total_compressed_bytes: int | None = None def __repr__(self) -> str: info = [self.__class__.__name__] @@ -176,15 +165,15 @@ def __repr__(self) -> str: info.append("e=%r" % self._exception) return "<%s>" % " ".join(info) - def get_read_buffer_limits(self) -> Tuple[int, int]: + def get_read_buffer_limits(self) -> tuple[int, int]: return 
(self._low_water, self._high_water) - def exception(self) -> Optional[Union[Type[BaseException], BaseException]]: + def exception(self) -> type[BaseException] | BaseException | None: return self._exception def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: self._exception = exc @@ -434,7 +423,7 @@ async def readany(self) -> bytes: return self._read_nowait(-1) - async def readchunk(self) -> Tuple[bytes, bool]: + async def readchunk(self) -> tuple[bytes, bool]: """Returns a tuple of (data, end_of_http_chunk). When chunked transfer @@ -472,7 +461,7 @@ async def readexactly(self, n: int) -> bytes: if self._exception is not None: raise self._exception - blocks: List[bytes] = [] + blocks: list[bytes] = [] while n > 0: block = await self.read(n) if not block: @@ -553,12 +542,12 @@ def __init__(self) -> None: def __repr__(self) -> str: return "<%s>" % self.__class__.__name__ - def exception(self) -> Optional[BaseException]: + def exception(self) -> BaseException | None: return None def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: pass @@ -595,7 +584,7 @@ async def read(self, n: int = -1) -> bytes: async def readany(self) -> bytes: return b"" - async def readchunk(self) -> Tuple[bytes, bool]: + async def readchunk(self) -> tuple[bytes, bool]: if not self._read_eof_chunk: self._read_eof_chunk = True return (b"", False) @@ -618,9 +607,9 @@ class DataQueue(Generic[_T]): def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop self._eof = False - self._waiter: Optional[asyncio.Future[None]] = None - self._exception: Union[Type[BaseException], BaseException, None] = None - self._buffer: Deque[_T] = collections.deque() + self._waiter: asyncio.Future[None] | None = None + self._exception: type[BaseException] | 
BaseException | None = None + self._buffer: collections.deque[_T] = collections.deque() def __len__(self) -> int: return len(self._buffer) @@ -631,12 +620,12 @@ def is_eof(self) -> bool: def at_eof(self) -> bool: return self._eof and not self._buffer - def exception(self) -> Optional[Union[Type[BaseException], BaseException]]: + def exception(self) -> type[BaseException] | BaseException | None: return self._exception def set_exception( self, - exc: Union[Type[BaseException], BaseException], + exc: type[BaseException] | BaseException, exc_cause: BaseException = _EXC_SENTINEL, ) -> None: self._eof = True diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index 5d1c885d8d5..192173b42c8 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -8,22 +8,9 @@ import socket import sys from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generic, - Iterator, - List, - Optional, - Type, - TypeVar, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal @@ -61,7 +48,7 @@ if TYPE_CHECKING: from ssl import SSLContext else: - SSLContext = None + SSLContext = Any if sys.version_info >= (3, 11) and TYPE_CHECKING: from typing import Unpack @@ -111,15 +98,15 @@ def __init__( *, scheme: str = "", host: str = "127.0.0.1", - port: Optional[int] = None, + port: int | None = None, skip_url_asserts: bool = False, socket_factory: Callable[ [str, int, socket.AddressFamily], socket.socket ] = get_port_socket, **kwargs: Any, ) -> None: - self.runner: Optional[BaseRunner[_Request]] = None - self._root: Optional[URL] = None + self.runner: BaseRunner[_Request] | None = None + self._root: URL | None = None self.host = host self.port = port or 0 self._closed = False @@ -210,9 +197,9 @@ async def __aenter__(self) -> Self: 
async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: await self.close() @@ -224,7 +211,7 @@ def __init__( *, scheme: str = "", host: str = "127.0.0.1", - port: Optional[int] = None, + port: int | None = None, **kwargs: Any, ): self.app = app @@ -242,7 +229,7 @@ def __init__( *, scheme: str = "", host: str = "127.0.0.1", - port: Optional[int] = None, + port: int | None = None, **kwargs: Any, ) -> None: self._handler = handler @@ -269,7 +256,7 @@ def __init__( self: "TestClient[Request, Application]", server: TestServer, *, - cookie_jar: Optional[AbstractCookieJar] = None, + cookie_jar: AbstractCookieJar | None = None, **kwargs: Any, ) -> None: ... @overload @@ -277,14 +264,14 @@ def __init__( self: "TestClient[_Request, None]", server: BaseTestServer[_Request], *, - cookie_jar: Optional[AbstractCookieJar] = None, + cookie_jar: AbstractCookieJar | None = None, **kwargs: Any, ) -> None: ... def __init__( # type: ignore[misc] self, server: BaseTestServer[_Request], *, - cookie_jar: Optional[AbstractCookieJar] = None, + cookie_jar: AbstractCookieJar | None = None, **kwargs: Any, ) -> None: # TODO(PY311): Use Unpack to specify ClientSession kwargs. 
@@ -298,14 +285,14 @@ def __init__( # type: ignore[misc] self._session = ClientSession(cookie_jar=cookie_jar, **kwargs) self._session._retry_connection = False self._closed = False - self._responses: List[ClientResponse] = [] - self._websockets: List[ClientWebSocketResponse] = [] + self._responses: list[ClientResponse] = [] + self._websockets: list[ClientWebSocketResponse] = [] async def start_server(self) -> None: await self._server.start_server() @property - def scheme(self) -> Union[str, object]: + def scheme(self) -> str | object: return self._server.scheme @property @@ -484,9 +471,9 @@ async def __aenter__(self) -> Self: async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> None: await self.close() @@ -591,10 +578,10 @@ def set_dict(app: Any, key: str, value: Any) -> None: return app -def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock: +def _create_transport(sslcontext: SSLContext | None = None) -> mock.Mock: transport = mock.Mock() - def get_extra_info(key: str) -> Optional[SSLContext]: + def get_extra_info(key: str) -> SSLContext | None: if key == "sslcontext": return sslcontext else: @@ -607,17 +594,17 @@ def get_extra_info(key: str) -> Optional[SSLContext]: def make_mocked_request( method: str, path: str, - headers: Optional[LooseHeaders] = None, + headers: LooseHeaders | None = None, *, - match_info: Optional[Dict[str, str]] = None, + match_info: dict[str, str] | None = None, version: HttpVersion = HttpVersion(1, 1), closing: bool = False, - app: Optional[Application] = None, - writer: Optional[AbstractStreamWriter] = None, - protocol: Optional[RequestHandler[Request]] = None, - transport: Optional[asyncio.Transport] = None, + app: Application | None = None, + writer: AbstractStreamWriter | None = None, + protocol: RequestHandler[Request] | None = None, 
+ transport: asyncio.Transport | None = None, payload: StreamReader = EMPTY_PAYLOAD, - sslcontext: Optional[SSLContext] = None, + sslcontext: SSLContext | None = None, client_max_size: int = 1024**2, loop: Any = ..., ) -> Request: diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index cc8c0825b4e..dd7ad257460 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -1,16 +1,7 @@ import json import os -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Iterable, - Mapping, - Protocol, - Tuple, - Union, -) +from collections.abc import Awaitable, Callable, Iterable, Mapping +from typing import TYPE_CHECKING, Any, Protocol, Union from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr from yarl import URL, Query as _Query @@ -42,14 +33,14 @@ Mapping[istr, str], _CIMultiDict, _CIMultiDictProxy, - Iterable[Tuple[Union[str, istr], str]], + Iterable[tuple[str | istr, str]], ] -RawHeaders = Tuple[Tuple[bytes, bytes], ...] +RawHeaders = tuple[tuple[bytes, bytes], ...] 
StrOrURL = Union[str, URL] LooseCookiesMappings = Mapping[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] LooseCookiesIterables = Iterable[ - Tuple[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] + tuple[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] ] LooseCookies = Union[ LooseCookiesMappings, diff --git a/aiohttp/web.py b/aiohttp/web.py index 13dd078e7f6..895d91e662e 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -5,22 +5,10 @@ import sys import warnings from argparse import ArgumentParser -from collections.abc import Iterable +from collections.abc import Awaitable, Callable, Iterable, Iterable as TypingIterable from contextlib import suppress from importlib import import_module -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Iterable as TypingIterable, - List, - Optional, - Set, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, cast from .abc import AbstractAccessLogger from .helpers import AppKey @@ -283,23 +271,23 @@ async def _run_app( - app: Union[Application, Awaitable[Application]], + app: Application | Awaitable[Application], *, - host: Optional[Union[str, HostSequence]] = None, - port: Optional[int] = None, - path: Union[PathLike, TypingIterable[PathLike], None] = None, - sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, + host: str | HostSequence | None = None, + port: int | None = None, + path: PathLike | TypingIterable[PathLike] | None = None, + sock: socket.socket | TypingIterable[socket.socket] | None = None, shutdown_timeout: float = 60.0, keepalive_timeout: float = 75.0, - ssl_context: Optional[SSLContext] = None, - print: Optional[Callable[..., None]] = print, + ssl_context: SSLContext | None = None, + print: Callable[..., None] | None = print, backlog: int = 128, - access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_class: type[AbstractAccessLogger] = AccessLogger, access_log_format: str = AccessLogger.LOG_FORMAT, - access_log: 
Optional[logging.Logger] = access_logger, + access_log: logging.Logger | None = access_logger, handle_signals: bool = True, - reuse_address: Optional[bool] = None, - reuse_port: Optional[bool] = None, + reuse_address: bool | None = None, + reuse_port: bool | None = None, handler_cancellation: bool = False, ) -> None: # An internal function to actually do all dirty job for application running @@ -321,7 +309,7 @@ async def _run_app( await runner.setup() - sites: List[BaseSite] = [] + sites: list[BaseSite] = [] try: if host is not None: @@ -421,7 +409,7 @@ async def _run_app( def _cancel_tasks( - to_cancel: Set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop + to_cancel: set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop ) -> None: if not to_cancel: return @@ -445,26 +433,26 @@ def _cancel_tasks( def run_app( - app: Union[Application, Awaitable[Application]], + app: Application | Awaitable[Application], *, debug: bool = False, - host: Optional[Union[str, HostSequence]] = None, - port: Optional[int] = None, - path: Union[PathLike, TypingIterable[PathLike], None] = None, - sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, + host: str | HostSequence | None = None, + port: int | None = None, + path: PathLike | TypingIterable[PathLike] | None = None, + sock: socket.socket | TypingIterable[socket.socket] | None = None, shutdown_timeout: float = 60.0, keepalive_timeout: float = 75.0, - ssl_context: Optional[SSLContext] = None, - print: Optional[Callable[..., None]] = print, + ssl_context: SSLContext | None = None, + print: Callable[..., None] | None = print, backlog: int = 128, - access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_class: type[AbstractAccessLogger] = AccessLogger, access_log_format: str = AccessLogger.LOG_FORMAT, - access_log: Optional[logging.Logger] = access_logger, + access_log: logging.Logger | None = access_logger, handle_signals: bool = True, - reuse_address: Optional[bool] = None, - 
reuse_port: Optional[bool] = None, + reuse_address: bool | None = None, + reuse_port: bool | None = None, handler_cancellation: bool = False, - loop: Optional[asyncio.AbstractEventLoop] = None, + loop: asyncio.AbstractEventLoop | None = None, ) -> None: """Run an app locally""" if loop is None: @@ -517,7 +505,7 @@ def run_app( asyncio.set_event_loop(None) -def main(argv: List[str]) -> None: +def main(argv: list[str]) -> None: arg_parser = ArgumentParser( description="aiohttp.web Application server", prog="aiohttp.web" ) diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index 69eafe5c7d4..bef158219e4 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -1,29 +1,18 @@ import asyncio import logging import warnings -from functools import lru_cache, partial, update_wrapper -from typing import ( - TYPE_CHECKING, - Any, +from collections.abc import ( AsyncIterator, Awaitable, Callable, - Dict, Iterable, Iterator, - List, Mapping, MutableMapping, - Optional, Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, - final, - overload, ) +from functools import lru_cache, partial, update_wrapper +from typing import TYPE_CHECKING, Any, TypeVar, cast, final, overload from aiosignal import Signal from frozenlist import FrozenList @@ -56,7 +45,7 @@ _RespPrepareSignal = Signal[Request, StreamResponse] _Middlewares = FrozenList[Middleware] _MiddlewaresHandlers = Sequence[Middleware] - _Subapps = List["Application"] + _Subapps = list["Application"] else: # No type checker mode, skip types _AppSignal = Signal @@ -64,7 +53,7 @@ _Handler = Callable _Middlewares = FrozenList _MiddlewaresHandlers = Sequence - _Subapps = List + _Subapps = list _T = TypeVar("_T") _U = TypeVar("_U") @@ -72,7 +61,7 @@ def _build_middlewares( - handler: Handler, apps: Tuple["Application", ...] + handler: Handler, apps: tuple["Application", ...] 
) -> Callable[[Request], Awaitable[StreamResponse]]: """Apply middlewares to handler.""" # The slice is to reverse the order of the apps @@ -88,7 +77,7 @@ def _build_middlewares( @final -class Application(MutableMapping[Union[str, AppKey[Any]], Any]): +class Application(MutableMapping[str | AppKey[Any], Any]): __slots__ = ( "logger", "_router", @@ -114,7 +103,7 @@ def __init__( *, logger: logging.Logger = web_logger, middlewares: Iterable[Middleware] = (), - handler_args: Optional[Mapping[str, Any]] = None, + handler_args: Mapping[str, Any] | None = None, client_max_size: int = 1024**2, debug: Any = ..., # mypy doesn't support ellipsis ) -> None: @@ -133,9 +122,9 @@ def __init__( # initialized on freezing self._middlewares_handlers: _MiddlewaresHandlers = tuple() # initialized on freezing - self._run_middlewares: Optional[bool] = None + self._run_middlewares: bool | None = None - self._state: Dict[Union[AppKey[Any], str], object] = {} + self._state: dict[AppKey[Any] | str, object] = {} self._frozen = False self._pre_frozen = False self._subapps: _Subapps = [] @@ -149,10 +138,9 @@ def __init__( self._on_cleanup.append(self._cleanup_ctx._on_cleanup) self._client_max_size = client_max_size - def __init_subclass__(cls: Type["Application"]) -> None: + def __init_subclass__(cls: type["Application"]) -> None: raise TypeError( - "Inheritance class {} from web.Application " - "is forbidden".format(cls.__name__) + f"Inheritance class {cls.__name__} from web.Application " "is forbidden" ) # MutableMapping API @@ -166,7 +154,7 @@ def __getitem__(self, key: AppKey[_T]) -> _T: ... @overload def __getitem__(self, key: str) -> Any: ... - def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: + def __getitem__(self, key: str | AppKey[_T]) -> Any: return self._state[key] def _check_frozen(self) -> None: @@ -181,7 +169,7 @@ def __setitem__(self, key: AppKey[_T], value: _T) -> None: ... @overload def __setitem__(self, key: str, value: Any) -> None: ... 
- def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None: + def __setitem__(self, key: str | AppKey[_T], value: Any) -> None: self._check_frozen() if not isinstance(key, AppKey): warnings.warn( @@ -193,33 +181,33 @@ def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None: ) self._state[key] = value - def __delitem__(self, key: Union[str, AppKey[_T]]) -> None: + def __delitem__(self, key: str | AppKey[_T]) -> None: self._check_frozen() del self._state[key] def __len__(self) -> int: return len(self._state) - def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: + def __iter__(self) -> Iterator[str | AppKey[Any]]: return iter(self._state) def __hash__(self) -> int: return id(self) @overload # type: ignore[override] - def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ... + def get(self, key: AppKey[_T], default: None = ...) -> _T | None: ... @overload - def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: ... + def get(self, key: AppKey[_T], default: _U) -> _T | _U: ... @overload def get(self, key: str, default: Any = ...) -> Any: ... 
- def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: + def get(self, key: str | AppKey[_T], default: Any = None) -> Any: return self._state.get(key, default) ######## - def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None: + def _set_loop(self, loop: asyncio.AbstractEventLoop | None) -> None: warnings.warn( "_set_loop() is no-op since 4.0 and scheduled for removal in 5.0", DeprecationWarning, @@ -324,7 +312,7 @@ def add_domain(self, domain: str, subapp: "Application") -> MatchedSubAppResourc factory = partial(MatchedSubAppResource, rule, subapp) return self._add_subapp(factory, subapp) - def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]: + def add_routes(self, routes: Iterable[AbstractRouteDef]) -> list[AbstractRoute]: return self.router.add_routes(routes) @property @@ -423,8 +411,8 @@ def __bool__(self) -> bool: class CleanupError(RuntimeError): @property - def exceptions(self) -> List[BaseException]: - return cast(List[BaseException], self.args[1]) + def exceptions(self) -> list[BaseException]: + return cast(list[BaseException], self.args[1]) if TYPE_CHECKING: @@ -436,7 +424,7 @@ def exceptions(self) -> List[BaseException]: class CleanupContext(_CleanupContextBase): def __init__(self) -> None: super().__init__() - self._exits: List[AsyncIterator[None]] = [] + self._exits: list[AsyncIterator[None]] = [] async def _on_startup(self, app: Application) -> None: for cb in self: diff --git a/aiohttp/web_exceptions.py b/aiohttp/web_exceptions.py index 5fdd27695f5..19cfdb0657d 100644 --- a/aiohttp/web_exceptions.py +++ b/aiohttp/web_exceptions.py @@ -1,6 +1,7 @@ import warnings +from collections.abc import Iterable from http import HTTPStatus -from typing import Any, Iterable, Optional, Set, Tuple +from typing import Any from multidict import CIMultiDict from yarl import URL @@ -91,10 +92,10 @@ class HTTPException(CookieMixin, Exception): def __init__( self, *, - headers: Optional[LooseHeaders] = None, - 
reason: Optional[str] = None, - text: Optional[str] = None, - content_type: Optional[str] = None, + headers: LooseHeaders | None = None, + reason: str | None = None, + text: str | None = None, + content_type: str | None = None, ) -> None: if reason is None: reason = self.default_reason @@ -107,11 +108,9 @@ def __init__( else: if self.empty_body: warnings.warn( - "text argument is deprecated for HTTP status {} " + f"text argument is deprecated for HTTP status {self.status_code} " "since 4.0 and scheduled for removal in 5.0 (#3462)," - "the response should be provided without a body".format( - self.status_code - ), + "the response should be provided without a body", DeprecationWarning, stacklevel=2, ) @@ -151,7 +150,7 @@ def reason(self) -> str: return self._reason @property - def text(self) -> Optional[str]: + def text(self) -> str | None: return self._text @property @@ -166,7 +165,7 @@ def __repr__(self) -> str: __reduce__ = object.__reduce__ - def __getnewargs__(self) -> Tuple[Any, ...]: + def __getnewargs__(self) -> tuple[Any, ...]: return self.args @@ -222,10 +221,10 @@ def __init__( self, location: StrOrURL, *, - headers: Optional[LooseHeaders] = None, - reason: Optional[str] = None, - text: Optional[str] = None, - content_type: Optional[str] = None, + headers: LooseHeaders | None = None, + reason: str | None = None, + text: str | None = None, + content_type: str | None = None, ) -> None: if not location: raise ValueError("HTTP redirects need a location to redirect to.") @@ -314,21 +313,21 @@ def __init__( method: str, allowed_methods: Iterable[str], *, - headers: Optional[LooseHeaders] = None, - reason: Optional[str] = None, - text: Optional[str] = None, - content_type: Optional[str] = None, + headers: LooseHeaders | None = None, + reason: str | None = None, + text: str | None = None, + content_type: str | None = None, ) -> None: allow = ",".join(sorted(allowed_methods)) super().__init__( headers=headers, reason=reason, text=text, content_type=content_type ) 
self.headers["Allow"] = allow - self._allowed: Set[str] = set(allowed_methods) + self._allowed: set[str] = set(allowed_methods) self._method = method @property - def allowed_methods(self) -> Set[str]: + def allowed_methods(self) -> set[str]: return self._allowed @property @@ -370,8 +369,8 @@ class HTTPRequestEntityTooLarge(HTTPClientError): def __init__(self, max_size: int, actual_size: int, **kwargs: Any) -> None: kwargs.setdefault( "text", - "Maximum request body size {} exceeded, " - "actual body size {}".format(max_size, actual_size), + f"Maximum request body size {max_size} exceeded, " + f"actual body size {actual_size}", ) super().__init__(**kwargs) @@ -425,12 +424,12 @@ class HTTPUnavailableForLegalReasons(HTTPClientError): def __init__( self, - link: Optional[StrOrURL], + link: StrOrURL | None, *, - headers: Optional[LooseHeaders] = None, - reason: Optional[str] = None, - text: Optional[str] = None, - content_type: Optional[str] = None, + headers: LooseHeaders | None = None, + reason: str | None = None, + text: str | None = None, + content_type: str | None = None, ) -> None: super().__init__( headers=headers, reason=reason, text=text, content_type=content_type @@ -441,7 +440,7 @@ def __init__( self.headers["Link"] = f'<{str(self._link)}>; rel="blocked-by"' @property - def link(self) -> Optional[URL]: + def link(self) -> URL | None: return self._link diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 6b7d002a86c..921741e8ef5 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -3,22 +3,13 @@ import os import pathlib import sys +from collections.abc import Awaitable, Callable from contextlib import suppress from enum import Enum, auto from mimetypes import MimeTypes from stat import S_ISREG from types import MappingProxyType -from typing import ( - IO, - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Final, - Optional, - Set, - Tuple, -) +from typing import IO, TYPE_CHECKING, Any, Final, Optional from . 
import hdrs from .abc import AbstractStreamWriter @@ -82,7 +73,7 @@ class _FileResponseResult(Enum): CONTENT_TYPES.add_type(content_type, extension) -_CLOSE_FUTURES: Set[asyncio.Future[None]] = set() +_CLOSE_FUTURES: set[asyncio.Future[None]] = set() class FileResponse(StreamResponse): @@ -93,8 +84,8 @@ def __init__( path: PathLike, chunk_size: int = 256 * 1024, status: int = 200, - reason: Optional[str] = None, - headers: Optional[LooseHeaders] = None, + reason: str | None = None, + headers: LooseHeaders | None = None, ) -> None: super().__init__(status=status, reason=reason, headers=headers) @@ -148,7 +139,7 @@ async def _sendfile( return writer @staticmethod - def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool: + def _etag_match(etag_value: str, etags: tuple[ETag, ...], *, weak: bool) -> bool: if len(etags) == 1 and etags[0].value == ETAG_ANY: return True return any( @@ -157,7 +148,7 @@ def _etag_match(etag_value: str, etags: Tuple[ETag, ...], *, weak: bool) -> bool async def _not_modified( self, request: "BaseRequest", etag_value: str, last_modified: float - ) -> Optional[AbstractStreamWriter]: + ) -> AbstractStreamWriter | None: self.set_status(HTTPNotModified.status_code) self._length_check = False self.etag = etag_value @@ -168,15 +159,15 @@ async def _not_modified( async def _precondition_failed( self, request: "BaseRequest" - ) -> Optional[AbstractStreamWriter]: + ) -> AbstractStreamWriter | None: self.set_status(HTTPPreconditionFailed.status_code) self.content_length = 0 return await super().prepare(request) def _make_response( self, request: "BaseRequest", accept_encoding: str - ) -> Tuple[ - _FileResponseResult, Optional[io.BufferedReader], os.stat_result, Optional[str] + ) -> tuple[ + _FileResponseResult, io.BufferedReader | None, os.stat_result, str | None ]: """Return the response result, io object, stat result, and encoding. 
@@ -231,7 +222,7 @@ def _make_response( def _get_file_path_stat_encoding( self, accept_encoding: str - ) -> Tuple[Optional[pathlib.Path], os.stat_result, Optional[str]]: + ) -> tuple[pathlib.Path | None, os.stat_result, str | None]: file_path = self._path for file_extension, file_encoding in ENCODING_EXTENSIONS.items(): if file_encoding not in accept_encoding: @@ -248,7 +239,7 @@ def _get_file_path_stat_encoding( st = file_path.stat() return file_path if S_ISREG(st.st_mode) else None, st, None - async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: + async def prepare(self, request: "BaseRequest") -> AbstractStreamWriter | None: loop = asyncio.get_running_loop() # Encoding comparisons should be case-insensitive # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 @@ -298,13 +289,13 @@ async def _prepare_open_file( request: "BaseRequest", fobj: io.BufferedReader, st: os.stat_result, - file_encoding: Optional[str], - ) -> Optional[AbstractStreamWriter]: + file_encoding: str | None, + ) -> AbstractStreamWriter | None: status = self._status file_size: int = st.st_size file_mtime: float = st.st_mtime count: int = file_size - start: Optional[int] = None + start: int | None = None if (ifrange := request.if_range) is None or file_mtime <= ifrange.timestamp(): # If-Range header check: @@ -317,7 +308,7 @@ async def _prepare_open_file( try: rng = request.http_range start = rng.start - end: Optional[int] = rng.stop + end: int | None = rng.stop except ValueError: # https://tools.ietf.org/html/rfc7233: # A server generating a 416 (Range Not Satisfiable) response to @@ -400,8 +391,8 @@ async def _prepare_open_file( if status == HTTPPartialContent.status_code: real_start = start assert real_start is not None - self._headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format( - real_start, real_start + count - 1, file_size + self._headers[hdrs.CONTENT_RANGE] = ( + f"bytes {real_start}-{real_start + count - 1}/{file_size}" ) # If we are sending 0 bytes 
calling sendfile() will throw a ValueError diff --git a/aiohttp/web_log.py b/aiohttp/web_log.py index aff1a73a81b..95b34e1029a 100644 --- a/aiohttp/web_log.py +++ b/aiohttp/web_log.py @@ -58,7 +58,7 @@ class AccessLogger(AbstractAccessLogger): LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' FORMAT_RE = re.compile(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)") CLEANUP_RE = re.compile(r"(%[^s])") - _FORMAT_CACHE: Dict[str, Tuple[str, List[KeyMethod]]] = {} + _FORMAT_CACHE: dict[str, tuple[str, list[KeyMethod]]] = {} def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None: """Initialise the logger. @@ -76,7 +76,7 @@ def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None self._log_format, self._methods = _compiled_format - def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]: + def compile_format(self, log_format: str) -> tuple[str, list[KeyMethod]]: """Translate log_format into form usable by modulo formatting All known atoms will be replaced with %s @@ -149,12 +149,7 @@ def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> st @staticmethod def _format_r(request: BaseRequest, response: StreamResponse, time: float) -> str: - return "{} {} HTTP/{}.{}".format( - request.method, - request.path_qs, - request.version.major, - request.version.minor, - ) + return f"{request.method} {request.path_qs} HTTP/{request.version.major}.{request.version.minor}" @staticmethod def _format_s(request: BaseRequest, response: StreamResponse, time: float) -> int: @@ -178,7 +173,7 @@ def _format_D(request: BaseRequest, response: StreamResponse, time: float) -> st def _format_line( self, request: BaseRequest, response: StreamResponse, time: float - ) -> Iterable[Tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]: + ) -> Iterable[tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]: return [(key, method(request, response, time)) for key, method in 
self._methods] @property diff --git a/aiohttp/web_middlewares.py b/aiohttp/web_middlewares.py index 22e63f872cf..5e26fe354cb 100644 --- a/aiohttp/web_middlewares.py +++ b/aiohttp/web_middlewares.py @@ -1,6 +1,6 @@ import re import warnings -from typing import TYPE_CHECKING, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, TypeVar from .typedefs import Handler, Middleware from .web_exceptions import HTTPMove, HTTPPermanentRedirect @@ -19,7 +19,7 @@ _Func = TypeVar("_Func") -async def _check_request_resolves(request: Request, path: str) -> Tuple[bool, Request]: +async def _check_request_resolves(request: Request, path: str) -> tuple[bool, Request]: alt_request = request.clone(rel_url=path) match_info = await request.app.router.resolve(alt_request) @@ -47,7 +47,7 @@ def normalize_path_middleware( append_slash: bool = True, remove_slash: bool = False, merge_slashes: bool = True, - redirect_class: Type[HTTPMove] = HTTPPermanentRedirect, + redirect_class: type[HTTPMove] = HTTPPermanentRedirect, ) -> Middleware: """Factory for producing a middleware that normalizes the path of a request. 
diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index cdeb423554f..0d58fee567b 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -3,25 +3,12 @@ import sys import traceback from collections import deque +from collections.abc import Awaitable, Callable, Sequence from contextlib import suppress from html import escape as html_escape from http import HTTPStatus from logging import Logger -from typing import ( - TYPE_CHECKING, - Any, - Awaitable, - Callable, - Deque, - Generic, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast import yarl from propcache import under_cached_property @@ -67,8 +54,8 @@ _RequestHandler = Callable[[_Request], Awaitable[StreamResponse]] _AnyAbstractAccessLogger = Union[ - Type[AbstractAsyncAccessLogger], - Type[AbstractAccessLogger], + type[AbstractAsyncAccessLogger], + type[AbstractAccessLogger], ] ERROR = RawRequestMessage( @@ -126,7 +113,7 @@ class _ErrInfo: message: str -_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader] +_MsgType = tuple[RawRequestMessage | _ErrInfo, StreamReader] class RequestHandler(BaseProtocol, Generic[_Request]): @@ -206,7 +193,7 @@ def __init__( tcp_keepalive: bool = True, logger: Logger = server_logger, access_log_class: _AnyAbstractAccessLogger = AccessLogger, - access_log: Optional[Logger] = access_logger, + access_log: Logger | None = access_logger, access_log_format: str = AccessLogger.LOG_FORMAT, max_line_size: int = 8190, max_field_size: int = 8190, @@ -220,32 +207,32 @@ def __init__( # _request_count is the number of requests processed with the same connection. 
self._request_count = 0 self._keepalive = False - self._current_request: Optional[_Request] = None - self._manager: Optional[Server[_Request]] = manager - self._request_handler: Optional[_RequestHandler[_Request]] = ( + self._current_request: _Request | None = None + self._manager: Server[_Request] | None = manager + self._request_handler: _RequestHandler[_Request] | None = ( manager.request_handler ) - self._request_factory: Optional[_RequestFactory[_Request]] = ( + self._request_factory: _RequestFactory[_Request] | None = ( manager.request_factory ) self._tcp_keepalive = tcp_keepalive # placeholder to be replaced on keepalive timeout setup self._next_keepalive_close_time = 0.0 - self._keepalive_handle: Optional[asyncio.Handle] = None + self._keepalive_handle: asyncio.Handle | None = None self._keepalive_timeout = keepalive_timeout self._lingering_time = float(lingering_time) - self._messages: Deque[_MsgType] = deque() + self._messages: deque[_MsgType] = deque() self._message_tail = b"" - self._waiter: Optional[asyncio.Future[None]] = None - self._handler_waiter: Optional[asyncio.Future[None]] = None - self._task_handler: Optional[asyncio.Task[None]] = None + self._waiter: asyncio.Future[None] | None = None + self._handler_waiter: asyncio.Future[None] | None = None + self._task_handler: asyncio.Task[None] | None = None self._upgrade = False self._payload_parser: Any = None - self._request_parser: Optional[HttpRequestParser] = HttpRequestParser( + self._request_parser: HttpRequestParser | None = HttpRequestParser( self, loop, read_bufsize, @@ -265,7 +252,7 @@ def __init__( self.access_log = access_log if access_log: if issubclass(access_log_class, AbstractAsyncAccessLogger): - self.access_logger: Optional[AbstractAsyncAccessLogger] = ( + self.access_logger: AbstractAsyncAccessLogger | None = ( access_log_class() ) else: @@ -302,7 +289,7 @@ def ssl_context(self) -> Optional["ssl.SSLContext"]: @under_cached_property def peername( self, - ) -> Optional[Union[str, 
Tuple[str, int, int, int], Tuple[str, int]]]: + ) -> str | tuple[str, int, int, int] | tuple[str, int] | None: """Return peername if available.""" return ( None @@ -314,7 +301,7 @@ def peername( def keepalive_timeout(self) -> float: return self._keepalive_timeout - async def shutdown(self, timeout: Optional[float] = 15.0) -> None: + async def shutdown(self, timeout: float | None = 15.0) -> None: """Do worker process exit preparations. We need to clean up everything and stop accepting requests. @@ -381,7 +368,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: task = loop.create_task(self.start()) self._task_handler = task - def connection_lost(self, exc: Optional[BaseException]) -> None: + def connection_lost(self, exc: BaseException | None) -> None: if self._manager is None: return self._manager.connection_lost(self, exc) @@ -498,7 +485,7 @@ async def log_access( self, request: BaseRequest, response: StreamResponse, - request_start: Optional[float], + request_start: float | None, ) -> None: if self._logging_enabled and self.access_logger is not None: if TYPE_CHECKING: @@ -532,9 +519,9 @@ def _process_keepalive(self) -> None: async def _handle_request( self, request: _Request, - start_time: Optional[float], + start_time: float | None, request_handler: Callable[[_Request], Awaitable[StreamResponse]], - ) -> Tuple[StreamResponse, bool]: + ) -> tuple[StreamResponse, bool]: self._request_in_progress = True try: try: @@ -706,8 +693,8 @@ async def start(self) -> None: self.transport.close() async def finish_response( - self, request: BaseRequest, resp: StreamResponse, start_time: Optional[float] - ) -> Tuple[StreamResponse, bool]: + self, request: BaseRequest, resp: StreamResponse, start_time: float | None + ) -> tuple[StreamResponse, bool]: """Prepare the response and write_eof, then log access. 
This has to @@ -729,8 +716,7 @@ async def finish_response( self.log_exception("Missing return statement on request handler") # type: ignore[unreachable] else: self.log_exception( - "Web-handler should return a response instance, " - "got {!r}".format(resp) + "Web-handler should return a response instance, " f"got {resp!r}" ) exc = HTTPInternalServerError() resp = Response( @@ -751,8 +737,8 @@ def handle_error( self, request: BaseRequest, status: int = 500, - exc: Optional[BaseException] = None, - message: Optional[str] = None, + exc: BaseException | None = None, + message: str | None = None, ) -> StreamResponse: """Handle errors. @@ -781,7 +767,7 @@ def handle_error( ct = "text/plain" if status == HTTPStatus.INTERNAL_SERVER_ERROR: - title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR) + title = f"{HTTPStatus.INTERNAL_SERVER_ERROR.value} {HTTPStatus.INTERNAL_SERVER_ERROR.phrase}" msg = HTTPStatus.INTERNAL_SERVER_ERROR.description tb = None if self._loop.get_debug(): @@ -794,10 +780,10 @@ def handle_error( msg = f"

Traceback:

\n
{tb}
" message = ( "" - "{title}" - "\n

{title}

" - "\n{msg}\n\n" - ).format(title=title, msg=msg) + f"{title}" + f"\n

{title}

" + f"\n{msg}\n\n" + ) ct = "text/html" else: if tb: diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 5f0317954d5..84a15753c4b 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -7,21 +7,10 @@ import sys import tempfile import types +from collections.abc import Iterator, Mapping, MutableMapping +from re import Pattern from types import MappingProxyType -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Final, - Iterator, - Mapping, - MutableMapping, - Optional, - Pattern, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Final, Optional, cast from urllib.parse import parse_qsl from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy @@ -98,15 +87,9 @@ class FileField: _QUOTED_PAIR: Final[str] = r"\\[\t !-~]" -_QUOTED_STRING: Final[str] = r'"(?:{quoted_pair}|{qdtext})*"'.format( - qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR -) +_QUOTED_STRING: Final[str] = rf'"(?:{_QUOTED_PAIR}|{_QDTEXT})*"' -_FORWARDED_PAIR: Final[str] = ( - r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( - token=_TOKEN, quoted_string=_QUOTED_STRING - ) -) +_FORWARDED_PAIR: Final[str] = rf"({_TOKEN})=({_TOKEN}|{_QUOTED_STRING})(:\d{{1,4}})?" 
_QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])") # same pattern as _QUOTED_PAIR but contains a capture group @@ -127,8 +110,8 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): hdrs.METH_DELETE, } - _post: Optional[MultiDictProxy[Union[str, bytes, FileField]]] = None - _read_bytes: Optional[bytes] = None + _post: MultiDictProxy[str | bytes | FileField] | None = None + _read_bytes: bytes | None = None def __init__( self, @@ -140,10 +123,10 @@ def __init__( loop: asyncio.AbstractEventLoop, *, client_max_size: int = 1024**2, - state: Optional[Dict[str, Any]] = None, - scheme: Optional[str] = None, - host: Optional[str] = None, - remote: Optional[str] = None, + state: dict[str, Any] | None = None, + scheme: str | None = None, + host: str | None = None, + remote: str | None = None, ) -> None: self._message = message self._protocol = protocol @@ -153,7 +136,7 @@ def __init__( self._headers: CIMultiDictProxy[str] = message.headers self._method = message.method self._version = message.version - self._cache: Dict[str, Any] = {} + self._cache: dict[str, Any] = {} url = message.url if url.absolute: if scheme is not None: @@ -188,13 +171,13 @@ def __init__( def clone( self, *, - method: Union[str, _SENTINEL] = sentinel, - rel_url: Union[StrOrURL, _SENTINEL] = sentinel, - headers: Union[LooseHeaders, _SENTINEL] = sentinel, - scheme: Union[str, _SENTINEL] = sentinel, - host: Union[str, _SENTINEL] = sentinel, - remote: Union[str, _SENTINEL] = sentinel, - client_max_size: Union[int, _SENTINEL] = sentinel, + method: str | _SENTINEL = sentinel, + rel_url: StrOrURL | _SENTINEL = sentinel, + headers: LooseHeaders | _SENTINEL = sentinel, + scheme: str | _SENTINEL = sentinel, + host: str | _SENTINEL = sentinel, + remote: str | _SENTINEL = sentinel, + client_max_size: int | _SENTINEL = sentinel, ) -> "BaseRequest": """Clone itself with replacement some attributes. 
@@ -205,7 +188,7 @@ def clone( if self._read_bytes: raise RuntimeError("Cannot clone request after reading its content") - dct: Dict[str, Any] = {} + dct: dict[str, Any] = {} if method is not sentinel: dct["method"] = method if rel_url is not sentinel: @@ -222,7 +205,7 @@ def clone( message = self._message._replace(**dct) - kwargs: Dict[str, str] = {} + kwargs: dict[str, str] = {} if scheme is not sentinel: kwargs["scheme"] = scheme if host is not sentinel: @@ -253,7 +236,7 @@ def protocol(self) -> "RequestHandler[Self]": return self._protocol @property - def transport(self) -> Optional[asyncio.Transport]: + def transport(self) -> asyncio.Transport | None: return self._protocol.transport @property @@ -293,7 +276,7 @@ def secure(self) -> bool: return self.scheme == "https" @reify - def forwarded(self) -> Tuple[Mapping[str, str], ...]: + def forwarded(self) -> tuple[Mapping[str, str], ...]: """A tuple containing all parsed Forwarded header(s). Makes an effort to parse Forwarded headers as specified by RFC 7239: @@ -317,7 +300,7 @@ def forwarded(self) -> Tuple[Mapping[str, str], ...]: length = len(field_value) pos = 0 need_separator = False - elem: Dict[str, str] = {} + elem: dict[str, str] = {} elems.append(types.MappingProxyType(elem)) while 0 <= pos < length: match = _FORWARDED_PAIR_RE.match(field_value, pos) @@ -405,7 +388,7 @@ def host(self) -> str: return socket.getfqdn() @reify - def remote(self) -> Optional[str]: + def remote(self) -> str | None: """Remote IP of client initiated HTTP request. The IP is resolved in this order: @@ -476,7 +459,7 @@ def raw_headers(self) -> RawHeaders: return self._message.raw_headers @reify - def if_modified_since(self) -> Optional[datetime.datetime]: + def if_modified_since(self) -> datetime.datetime | None: """The value of If-Modified-Since HTTP header, or None. This header is represented as a `datetime` object. 
@@ -484,7 +467,7 @@ def if_modified_since(self) -> Optional[datetime.datetime]: return parse_http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE)) @reify - def if_unmodified_since(self) -> Optional[datetime.datetime]: + def if_unmodified_since(self) -> datetime.datetime | None: """The value of If-Unmodified-Since HTTP header, or None. This header is represented as a `datetime` object. @@ -514,15 +497,15 @@ def _etag_values(etag_header: str) -> Iterator[ETag]: @classmethod def _if_match_or_none_impl( - cls, header_value: Optional[str] - ) -> Optional[Tuple[ETag, ...]]: + cls, header_value: str | None + ) -> tuple[ETag, ...] | None: if not header_value: return None return tuple(cls._etag_values(header_value)) @reify - def if_match(self) -> Optional[Tuple[ETag, ...]]: + def if_match(self) -> tuple[ETag, ...] | None: """The value of If-Match HTTP header, or None. This header is represented as a `tuple` of `ETag` objects. @@ -530,7 +513,7 @@ def if_match(self) -> Optional[Tuple[ETag, ...]]: return self._if_match_or_none_impl(self.headers.get(hdrs.IF_MATCH)) @reify - def if_none_match(self) -> Optional[Tuple[ETag, ...]]: + def if_none_match(self) -> tuple[ETag, ...] | None: """The value of If-None-Match HTTP header, or None. This header is represented as a `tuple` of `ETag` objects. @@ -538,7 +521,7 @@ def if_none_match(self) -> Optional[Tuple[ETag, ...]]: return self._if_match_or_none_impl(self.headers.get(hdrs.IF_NONE_MATCH)) @reify - def if_range(self) -> Optional[datetime.datetime]: + def if_range(self) -> datetime.datetime | None: """The value of If-Range HTTP header, or None. This header is represented as a `datetime` object. 
@@ -655,7 +638,7 @@ async def json( self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER, - content_type: Optional[str] = "application/json", + content_type: str | None = "application/json", ) -> Any: """Return BODY as JSON.""" body = await self.text() @@ -674,7 +657,7 @@ async def multipart(self) -> MultipartReader: """Return async iterator to process BODY as multipart.""" return MultipartReader(self._headers, self._payload) - async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]": + async def post(self) -> "MultiDictProxy[str | bytes | FileField]": """Return POST parameters.""" if self._post is not None: return self._post @@ -691,7 +674,7 @@ async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]": self._post = MultiDictProxy(MultiDict()) return self._post - out: MultiDict[Union[str, bytes, FileField]] = MultiDict() + out: MultiDict[str | bytes | FileField] = MultiDict() if content_type == "multipart/form-data": multipart = await self.multipart() @@ -785,9 +768,7 @@ def __repr__(self) -> str: ascii_encodable_path = self.path.encode("ascii", "backslashreplace").decode( "ascii" ) - return "<{} {} {} >".format( - self.__class__.__name__, self._method, ascii_encodable_path - ) + return f"<{self.__class__.__name__} {self._method} {ascii_encodable_path} >" def __eq__(self, other: object) -> bool: return id(self) == id(other) @@ -821,13 +802,13 @@ class Request(BaseRequest): def clone( self, *, - method: Union[str, _SENTINEL] = sentinel, - rel_url: Union[StrOrURL, _SENTINEL] = sentinel, - headers: Union[LooseHeaders, _SENTINEL] = sentinel, - scheme: Union[str, _SENTINEL] = sentinel, - host: Union[str, _SENTINEL] = sentinel, - remote: Union[str, _SENTINEL] = sentinel, - client_max_size: Union[int, _SENTINEL] = sentinel, + method: str | _SENTINEL = sentinel, + rel_url: StrOrURL | _SENTINEL = sentinel, + headers: LooseHeaders | _SENTINEL = sentinel, + scheme: str | _SENTINEL = sentinel, + host: str | _SENTINEL = sentinel, + remote: str | 
_SENTINEL = sentinel, + client_max_size: int | _SENTINEL = sentinel, ) -> "Request": ret = super().clone( method=method, diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 41559143b6a..2f921d4f559 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -6,18 +6,10 @@ import math import time import warnings +from collections.abc import Iterator, MutableMapping from concurrent.futures import Executor from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterator, - MutableMapping, - Optional, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Optional, Union, cast from multidict import CIMultiDict, istr @@ -76,18 +68,18 @@ class ContentCoding(enum.Enum): class StreamResponse(BaseClass, HeadersMixin, CookieMixin): - _body: Union[None, bytes, bytearray, Payload] + _body: None | bytes | bytearray | Payload _length_check = True _body = None - _keep_alive: Optional[bool] = None + _keep_alive: bool | None = None _chunked: bool = False _compression: bool = False - _compression_strategy: Optional[int] = None - _compression_force: Optional[ContentCoding] = None + _compression_strategy: int | None = None + _compression_force: ContentCoding | None = None _req: Optional["BaseRequest"] = None - _payload_writer: Optional[AbstractStreamWriter] = None + _payload_writer: AbstractStreamWriter | None = None _eof_sent: bool = False - _must_be_empty_body: Optional[bool] = None + _must_be_empty_body: bool | None = None _body_length = 0 _send_headers_immediately = True @@ -95,9 +87,9 @@ def __init__( self, *, status: int = 200, - reason: Optional[str] = None, - headers: Optional[LooseHeaders] = None, - _real_headers: Optional[CIMultiDict[str]] = None, + reason: str | None = None, + headers: LooseHeaders | None = None, + _real_headers: CIMultiDict[str] | None = None, ) -> None: """Initialize a new stream response object. @@ -106,7 +98,7 @@ def __init__( the headers when creating a new response object. 
It is not intended to be used by external code. """ - self._state: Dict[str, Any] = {} + self._state: dict[str, Any] = {} if _real_headers is not None: self._headers = _real_headers @@ -122,7 +114,7 @@ def prepared(self) -> bool: return self._eof_sent or self._payload_writer is not None @property - def task(self) -> "Optional[asyncio.Task[None]]": + def task(self) -> "asyncio.Task[None] | None": if self._req: return self._req.task else: @@ -147,14 +139,14 @@ def reason(self) -> str: def set_status( self, status: int, - reason: Optional[str] = None, + reason: str | None = None, ) -> None: assert ( not self.prepared ), "Cannot change the response status code after the headers have been sent" self._set_status(status, reason) - def _set_status(self, status: int, reason: Optional[str]) -> None: + def _set_status(self, status: int, reason: str | None) -> None: self._status = status if reason is None: reason = REASON_PHRASES.get(self._status, "") @@ -163,7 +155,7 @@ def _set_status(self, status: int, reason: Optional[str]) -> None: self._reason = reason @property - def keep_alive(self) -> Optional[bool]: + def keep_alive(self) -> bool | None: return self._keep_alive def force_close(self) -> None: @@ -183,8 +175,8 @@ def enable_chunked_encoding(self) -> None: def enable_compression( self, - force: Optional[ContentCoding] = None, - strategy: Optional[int] = None, + force: ContentCoding | None = None, + strategy: int | None = None, ) -> None: """Enables response compression encoding.""" # Don't enable compression if content is already encoded. 
@@ -203,12 +195,12 @@ def headers(self) -> "CIMultiDict[str]": return self._headers @property - def content_length(self) -> Optional[int]: + def content_length(self) -> int | None: # Just a placeholder for adding setter return super().content_length @content_length.setter - def content_length(self, value: Optional[int]) -> None: + def content_length(self, value: int | None) -> None: if value is not None: value = int(value) if self._chunked: @@ -231,12 +223,12 @@ def content_type(self, value: str) -> None: self._generate_content_type_header() @property - def charset(self) -> Optional[str]: + def charset(self) -> str | None: # Just a placeholder for adding setter return super().charset @charset.setter - def charset(self, value: Optional[str]) -> None: + def charset(self, value: str | None) -> None: ctype = self.content_type # read header values if needed if ctype == "application/octet-stream": raise RuntimeError( @@ -251,7 +243,7 @@ def charset(self, value: Optional[str]) -> None: self._generate_content_type_header() @property - def last_modified(self) -> Optional[datetime.datetime]: + def last_modified(self) -> datetime.datetime | None: """The value of Last-Modified HTTP header, or None. This header is represented as a `datetime` object. 
@@ -260,7 +252,7 @@ def last_modified(self) -> Optional[datetime.datetime]: @last_modified.setter def last_modified( - self, value: Optional[Union[int, float, datetime.datetime, str]] + self, value: int | float | datetime.datetime | str | None ) -> None: if value is None: self._headers.pop(hdrs.LAST_MODIFIED, None) @@ -279,7 +271,7 @@ def last_modified( raise TypeError(msg) @property - def etag(self) -> Optional[ETag]: + def etag(self) -> ETag | None: quoted_value = self._headers.get(hdrs.ETAG) if not quoted_value: return None @@ -295,7 +287,7 @@ def etag(self) -> Optional[ETag]: ) @etag.setter - def etag(self, value: Optional[Union[ETag, str]]) -> None: + def etag(self, value: ETag | str | None) -> None: if value is None: self._headers.pop(hdrs.ETAG, None) elif (isinstance(value, str) and value == ETAG_ANY) or ( @@ -351,7 +343,7 @@ async def _start_compression(self, request: "BaseRequest") -> None: await self._do_start_compression(coding) return - async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: + async def prepare(self, request: "BaseRequest") -> AbstractStreamWriter | None: if self._eof_sent: return None if self._payload_writer is not None: @@ -392,7 +384,7 @@ async def _prepare_headers(self) -> None: if version != HttpVersion11: raise RuntimeError( "Using chunked encoding is forbidden " - "for HTTP/{0.major}.{0.minor}".format(request.version) + f"for HTTP/{request.version.major}.{request.version.minor}" ) if not self._must_be_empty_body: writer.enable_chunking() @@ -523,7 +515,7 @@ def __bool__(self) -> bool: class Response(StreamResponse): - _compressed_body: Optional[bytes] = None + _compressed_body: bytes | None = None _send_headers_immediately = False def __init__( @@ -531,13 +523,13 @@ def __init__( *, body: Any = None, status: int = 200, - reason: Optional[str] = None, - text: Optional[str] = None, - headers: Optional[LooseHeaders] = None, - content_type: Optional[str] = None, - charset: Optional[str] = None, - 
zlib_executor_size: Optional[int] = None, - zlib_executor: Optional[Executor] = None, + reason: str | None = None, + text: str | None = None, + headers: LooseHeaders | None = None, + content_type: str | None = None, + charset: str | None = None, + zlib_executor_size: int | None = None, + zlib_executor: Executor | None = None, ) -> None: if body is not None and text is not None: raise ValueError("body and text are not allowed together") @@ -592,7 +584,7 @@ def __init__( self._zlib_executor = zlib_executor @property - def body(self) -> Optional[Union[bytes, bytearray, Payload]]: + def body(self) -> bytes | bytearray | Payload | None: return self._body @body.setter @@ -622,7 +614,7 @@ def body(self, body: Any) -> None: self._compressed_body = None @property - def text(self) -> Optional[str]: + def text(self) -> str | None: if self._body is None: return None # Note: When _body is a Payload (e.g. FilePayload), this may do blocking I/O @@ -643,7 +635,7 @@ def text(self, text: str) -> None: self._compressed_body = None @property - def content_length(self) -> Optional[int]: + def content_length(self) -> int | None: if self._chunked: return None @@ -662,7 +654,7 @@ def content_length(self) -> Optional[int]: return 0 @content_length.setter - def content_length(self, value: Optional[int]) -> None: + def content_length(self, value: int | None) -> None: raise RuntimeError("Content length is set automatically") async def write_eof(self, data: bytes = b"") -> None: @@ -732,11 +724,11 @@ async def _do_start_compression(self, coding: ContentCoding) -> None: def json_response( data: Any = sentinel, *, - text: Optional[str] = None, - body: Optional[bytes] = None, + text: str | None = None, + body: bytes | None = None, status: int = 200, - reason: Optional[str] = None, - headers: Optional[LooseHeaders] = None, + reason: str | None = None, + headers: LooseHeaders | None = None, content_type: str = "application/json", dumps: JSONEncoder = json.dumps, ) -> Response: diff --git 
a/aiohttp/web_routedef.py b/aiohttp/web_routedef.py index 3d644a0d404..dc566b64f03 100644 --- a/aiohttp/web_routedef.py +++ b/aiohttp/web_routedef.py @@ -1,18 +1,7 @@ import abc import dataclasses -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterator, - List, - Optional, - Sequence, - Type, - Union, - overload, -) +from collections.abc import Callable, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Union, overload from . import hdrs from .abc import AbstractView @@ -46,11 +35,11 @@ class AbstractRouteDef(abc.ABC): @abc.abstractmethod - def register(self, router: UrlDispatcher) -> List[AbstractRoute]: + def register(self, router: UrlDispatcher) -> list[AbstractRoute]: """Register itself into the given router.""" -_HandlerType = Union[Type[AbstractView], Handler] +_HandlerType = Union[type[AbstractView], Handler] @dataclasses.dataclass(frozen=True, repr=False) @@ -58,7 +47,7 @@ class RouteDef(AbstractRouteDef): method: str path: str handler: _HandlerType - kwargs: Dict[str, Any] + kwargs: dict[str, Any] def __repr__(self) -> str: info = [] @@ -68,7 +57,7 @@ def __repr__(self) -> str: method=self.method, path=self.path, handler=self.handler, info="".join(info) ) - def register(self, router: UrlDispatcher) -> List[AbstractRoute]: + def register(self, router: UrlDispatcher) -> list[AbstractRoute]: if self.method in hdrs.METH_ALL: reg = getattr(router, "add_" + self.method.lower()) return [reg(self.path, self.handler, **self.kwargs)] @@ -82,7 +71,7 @@ def register(self, router: UrlDispatcher) -> List[AbstractRoute]: class StaticDef(AbstractRouteDef): prefix: str path: PathLike - kwargs: Dict[str, Any] + kwargs: dict[str, Any] def __repr__(self) -> str: info = [] @@ -92,7 +81,7 @@ def __repr__(self) -> str: prefix=self.prefix, path=self.path, info="".join(info) ) - def register(self, router: UrlDispatcher) -> List[AbstractRoute]: + def register(self, router: UrlDispatcher) -> list[AbstractRoute]: resource = router.add_static(self.prefix, 
self.path, **self.kwargs) routes = resource.get_info().get("routes", {}) return list(routes.values()) @@ -114,7 +103,7 @@ def get( path: str, handler: _HandlerType, *, - name: Optional[str] = None, + name: str | None = None, allow_head: bool = True, **kwargs: Any, ) -> RouteDef: @@ -139,7 +128,7 @@ def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: return route(hdrs.METH_DELETE, path, handler, **kwargs) -def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef: +def view(path: str, handler: type[AbstractView], **kwargs: Any) -> RouteDef: return route(hdrs.METH_ANY, path, handler, **kwargs) @@ -154,7 +143,7 @@ class RouteTableDef(Sequence[AbstractRouteDef]): """Route definition table""" def __init__(self) -> None: - self._items: List[AbstractRouteDef] = [] + self._items: list[AbstractRouteDef] = [] def __repr__(self) -> str: return f"" @@ -163,11 +152,11 @@ def __repr__(self) -> str: def __getitem__(self, index: int) -> AbstractRouteDef: ... @overload - def __getitem__(self, index: "slice[int, int, int]") -> List[AbstractRouteDef]: ... + def __getitem__(self, index: "slice[int, int, int]") -> list[AbstractRouteDef]: ... 
def __getitem__( self, index: Union[int, "slice[int, int, int]"] - ) -> Union[AbstractRouteDef, List[AbstractRouteDef]]: + ) -> AbstractRouteDef | list[AbstractRouteDef]: return self._items[index] def __iter__(self) -> Iterator[AbstractRouteDef]: diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index c5e62003a34..b06c9ebce94 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,7 +2,7 @@ import signal import socket from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, List, Optional, Set, Type, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from yarl import URL @@ -54,7 +54,7 @@ def __init__( self, runner: "BaseRunner[Any]", *, - ssl_context: Optional[SSLContext] = None, + ssl_context: SSLContext | None = None, backlog: int = 128, ) -> None: if runner.server is None: @@ -62,7 +62,7 @@ def __init__( self._runner = runner self._ssl_context = ssl_context self._backlog = backlog - self._server: Optional[asyncio.Server] = None + self._server: asyncio.Server | None = None @property @abstractmethod @@ -87,13 +87,13 @@ class TCPSite(BaseSite): def __init__( self, runner: "BaseRunner[Any]", - host: Optional[str] = None, - port: Optional[int] = None, + host: str | None = None, + port: int | None = None, *, - ssl_context: Optional[SSLContext] = None, + ssl_context: SSLContext | None = None, backlog: int = 128, - reuse_address: Optional[bool] = None, - reuse_port: Optional[bool] = None, + reuse_address: bool | None = None, + reuse_port: bool | None = None, ) -> None: super().__init__( runner, @@ -137,7 +137,7 @@ def __init__( runner: "BaseRunner[Any]", path: PathLike, *, - ssl_context: Optional[SSLContext] = None, + ssl_context: SSLContext | None = None, backlog: int = 128, ) -> None: super().__init__( @@ -202,7 +202,7 @@ def __init__( runner: "BaseRunner[Any]", sock: socket.socket, *, - ssl_context: Optional[SSLContext] = None, + ssl_context: SSLContext | None = None, backlog: int = 128, ) -> None: 
super().__init__( @@ -245,17 +245,17 @@ def __init__( ) -> None: self._handle_signals = handle_signals self._kwargs = kwargs - self._server: Optional[Server[_Request]] = None - self._sites: List[BaseSite] = [] + self._server: Server[_Request] | None = None + self._sites: list[BaseSite] = [] self._shutdown_timeout = shutdown_timeout @property - def server(self) -> Optional[Server[_Request]]: + def server(self) -> Server[_Request] | None: return self._server @property - def addresses(self) -> List[Any]: - ret: List[Any] = [] + def addresses(self) -> list[Any]: + ret: list[Any] = [] for site in self._sites: server = site._server if server is not None: @@ -266,7 +266,7 @@ def addresses(self) -> List[Any]: return ret @property - def sites(self) -> Set[BaseSite]: + def sites(self) -> set[BaseSite]: return set(self._sites) async def setup(self) -> None: @@ -371,13 +371,12 @@ def __init__( app: Application, *, handle_signals: bool = False, - access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_class: type[AbstractAccessLogger] = AccessLogger, **kwargs: Any, ) -> None: if not isinstance(app, Application): raise TypeError( - "The first argument should be web.Application " - "instance, got {!r}".format(app) + "The first argument should be web.Application " f"instance, got {app!r}" ) kwargs["access_log_class"] = access_log_class @@ -421,7 +420,7 @@ def _make_request( protocol: RequestHandler[Request], writer: AbstractStreamWriter, task: "asyncio.Task[None]", - _cls: Type[Request] = Request, + _cls: type[Request] = Request, ) -> Request: loop = asyncio.get_running_loop() return _cls( diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 53827a62590..d02ead867e6 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -2,17 +2,8 @@ import asyncio import warnings -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generic, - List, - Optional, - TypeVar, - overload, -) +from collections.abc import Awaitable, Callable +from 
typing import Any, Generic, TypeVar, overload from .abc import AbstractStreamWriter from .http_parser import RawRequestMessage @@ -44,7 +35,7 @@ def __init__( self: "Server[BaseRequest]", handler: Callable[[_Request], Awaitable[StreamResponse]], *, - debug: Optional[bool] = None, + debug: bool | None = None, handler_cancellation: bool = False, **kwargs: Any, # TODO(PY311): Use Unpack to define kwargs from RequestHandler ) -> None: ... @@ -53,8 +44,8 @@ def __init__( self, handler: Callable[[_Request], Awaitable[StreamResponse]], *, - request_factory: Optional[_RequestFactory[_Request]], - debug: Optional[bool] = None, + request_factory: _RequestFactory[_Request] | None, + debug: bool | None = None, handler_cancellation: bool = False, **kwargs: Any, ) -> None: ... @@ -62,8 +53,8 @@ def __init__( self, handler: Callable[[_Request], Awaitable[StreamResponse]], *, - request_factory: Optional[_RequestFactory[_Request]] = None, - debug: Optional[bool] = None, + request_factory: _RequestFactory[_Request] | None = None, + debug: bool | None = None, handler_cancellation: bool = False, **kwargs: Any, ) -> None: @@ -74,7 +65,7 @@ def __init__( stacklevel=2, ) self._loop = asyncio.get_running_loop() - self._connections: Dict[RequestHandler[_Request], asyncio.Transport] = {} + self._connections: dict[RequestHandler[_Request], asyncio.Transport] = {} self._kwargs = kwargs # requests_count is the number of requests being processed by the server # for the lifetime of the server. 
@@ -84,7 +75,7 @@ def __init__( self.handler_cancellation = handler_cancellation @property - def connections(self) -> List[RequestHandler[_Request]]: + def connections(self) -> list[RequestHandler[_Request]]: return list(self._connections.keys()) def connection_made( @@ -93,7 +84,7 @@ def connection_made( self._connections[handler] = transport def connection_lost( - self, handler: RequestHandler[_Request], exc: Optional[BaseException] = None + self, handler: RequestHandler[_Request], exc: BaseException | None = None ) -> None: if handler in self._connections: if handler._task_handler: @@ -117,7 +108,7 @@ def pre_shutdown(self) -> None: for conn in self._connections: conn.close() - async def shutdown(self, timeout: Optional[float] = None) -> None: + async def shutdown(self, timeout: float | None = None) -> None: coros = (conn.shutdown(timeout) for conn in self._connections) await asyncio.gather(*coros) self._connections.clear() diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index 0ed199f7e59..124afbba929 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -9,32 +9,20 @@ import os import re import sys -from pathlib import Path -from types import MappingProxyType -from typing import ( - TYPE_CHECKING, - Any, +from collections.abc import ( Awaitable, Callable, Container, - Dict, - Final, Generator, Iterable, Iterator, - List, Mapping, - NoReturn, - Optional, - Pattern, - Set, Sized, - Tuple, - Type, - TypedDict, - Union, - cast, ) +from pathlib import Path +from re import Pattern +from types import MappingProxyType +from typing import TYPE_CHECKING, Any, Final, NoReturn, Optional, TypedDict, cast from yarl import URL @@ -72,15 +60,11 @@ if TYPE_CHECKING: from .web_app import Application - BaseDict = Dict[str, str] + BaseDict = dict[str, str] else: BaseDict = dict -CIRCULAR_SYMLINK_ERROR = ( - (OSError,) - if sys.version_info < (3, 10) and sys.platform.startswith("win32") - else (RuntimeError,) if sys.version_info < 
(3, 13) else () -) +CIRCULAR_SYMLINK_ERROR = (RuntimeError,) if sys.version_info < (3, 13) else () HTTP_METHOD_RE: Final[Pattern[str]] = re.compile( r"^[0-9A-Za-z!#\$%&'\*\+\-\.\^_`\|~]+$" @@ -91,8 +75,8 @@ PATH_SEP: Final[str] = re.escape("/") -_ExpectHandler = Callable[[Request], Awaitable[Optional[StreamResponse]]] -_Resolve = Tuple[Optional["UrlMappingMatchInfo"], Set[str]] +_ExpectHandler = Callable[[Request], Awaitable[StreamResponse | None]] +_Resolve = tuple[Optional["UrlMappingMatchInfo"], set[str]] html_escape = functools.partial(html.escape, quote=True) @@ -117,11 +101,11 @@ class _InfoDict(TypedDict, total=False): class AbstractResource(Sized, Iterable["AbstractRoute"]): - def __init__(self, *, name: Optional[str] = None) -> None: + def __init__(self, *, name: str | None = None) -> None: self._name = name @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return self._name @property @@ -167,10 +151,10 @@ class AbstractRoute(abc.ABC): def __init__( self, method: str, - handler: Union[Handler, Type[AbstractView]], + handler: Handler | type[AbstractView], *, - expect_handler: Optional[_ExpectHandler] = None, - resource: Optional[AbstractResource] = None, + expect_handler: _ExpectHandler | None = None, + resource: AbstractResource | None = None, ) -> None: if expect_handler is None: expect_handler = _default_expect_handler @@ -191,8 +175,7 @@ def __init__( pass else: raise TypeError( - "Only async functions are allowed as web-handlers " - ", got {!r}".format(handler) + "Only async functions are allowed as web-handlers " f", got {handler!r}" ) self._method = method @@ -210,11 +193,11 @@ def handler(self) -> Handler: @property @abc.abstractmethod - def name(self) -> Optional[str]: + def name(self) -> str | None: """Optional route's name, always equals to resource's name.""" @property - def resource(self) -> Optional[AbstractResource]: + def resource(self) -> AbstractResource | None: return self._resource @abc.abstractmethod @@ -225,7 
+208,7 @@ def get_info(self) -> _InfoDict: def url_for(self, *args: str, **kwargs: str) -> URL: """Construct url for route with additional params.""" - async def handle_expect_header(self, request: Request) -> Optional[StreamResponse]: + async def handle_expect_header(self, request: Request) -> StreamResponse | None: return await self._expect_handler(request) @@ -233,11 +216,11 @@ class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo): __slots__ = ("_route", "_apps", "_current_app", "_frozen") - def __init__(self, match_dict: Dict[str, str], route: AbstractRoute) -> None: + def __init__(self, match_dict: dict[str, str], route: AbstractRoute) -> None: super().__init__(match_dict) self._route = route - self._apps: List[Application] = [] - self._current_app: Optional[Application] = None + self._apps: list[Application] = [] + self._current_app: Application | None = None self._frozen = False @property @@ -253,14 +236,14 @@ def expect_handler(self) -> _ExpectHandler: return self._route.handle_expect_header @property - def http_exception(self) -> Optional[HTTPException]: + def http_exception(self) -> HTTPException | None: return None def get_info(self) -> _InfoDict: # type: ignore[override] return self._route.get_info() @property - def apps(self) -> Tuple["Application", ...]: + def apps(self) -> tuple["Application", ...]: return tuple(self._apps) def add_app(self, app: "Application") -> None: @@ -281,9 +264,7 @@ def current_app(self, app: "Application") -> None: if DEBUG: if app not in self._apps: raise RuntimeError( - "Expected one of the following apps {!r}, got {!r}".format( - self._apps, app - ) + f"Expected one of the following apps {self._apps!r}, got {app!r}" ) self._current_app = app @@ -307,9 +288,7 @@ def http_exception(self) -> HTTPException: return self._exception def __repr__(self) -> str: - return "".format( - self._exception.status, self._exception.reason - ) + return f"" async def _default_expect_handler(request: Request) -> None: @@ -329,18 +308,18 @@ 
async def _default_expect_handler(request: Request) -> None: class Resource(AbstractResource): - def __init__(self, *, name: Optional[str] = None) -> None: + def __init__(self, *, name: str | None = None) -> None: super().__init__(name=name) - self._routes: Dict[str, ResourceRoute] = {} - self._any_route: Optional[ResourceRoute] = None - self._allowed_methods: Set[str] = set() + self._routes: dict[str, ResourceRoute] = {} + self._any_route: ResourceRoute | None = None + self._allowed_methods: set[str] = set() def add_route( self, method: str, - handler: Union[Type[AbstractView], Handler], + handler: type[AbstractView] | Handler, *, - expect_handler: Optional[_ExpectHandler] = None, + expect_handler: _ExpectHandler | None = None, ) -> "ResourceRoute": if route := self._routes.get(method, self._any_route): raise RuntimeError( @@ -370,7 +349,7 @@ async def resolve(self, request: Request) -> _Resolve: return None, self._allowed_methods @abc.abstractmethod - def _match(self, path: str) -> Optional[Dict[str, str]]: + def _match(self, path: str) -> dict[str, str] | None: """Return dict of path values if path matches this resource, otherwise None.""" def __len__(self) -> int: @@ -383,7 +362,7 @@ def __iter__(self) -> Iterator["ResourceRoute"]: class PlainResource(Resource): - def __init__(self, path: str, *, name: Optional[str] = None) -> None: + def __init__(self, path: str, *, name: str | None = None) -> None: super().__init__(name=name) assert not path or path.startswith("/") self._path = path @@ -402,7 +381,7 @@ def add_prefix(self, prefix: str) -> None: assert len(prefix) > 1 self._path = prefix + self._path - def _match(self, path: str) -> Optional[Dict[str, str]]: + def _match(self, path: str) -> dict[str, str] | None: # string comparison is about 10 times faster than regexp matching if self._path == path: return {} @@ -427,7 +406,7 @@ class DynamicResource(Resource): DYN_WITH_RE = re.compile(r"\{(?P[_a-zA-Z][_a-zA-Z0-9]*):(?P.+)\}") GOOD = r"[^{}/]+" - def 
__init__(self, path: str, *, name: Optional[str] = None) -> None: + def __init__(self, path: str, *, name: str | None = None) -> None: super().__init__(name=name) self._orig_path = path pattern = "" @@ -472,7 +451,7 @@ def add_prefix(self, prefix: str) -> None: self._pattern = re.compile(re.escape(prefix) + self._pattern.pattern) self._formatter = prefix + self._formatter - def _match(self, path: str) -> Optional[Dict[str, str]]: + def _match(self, path: str) -> dict[str, str] | None: match = self._pattern.fullmatch(path) if match is None: return None @@ -492,13 +471,11 @@ def url_for(self, **parts: str) -> URL: def __repr__(self) -> str: name = "'" + self.name + "' " if self.name is not None else "" - return "".format( - name=name, formatter=self._formatter - ) + return f"" class PrefixResource(AbstractResource): - def __init__(self, prefix: str, *, name: Optional[str] = None) -> None: + def __init__(self, prefix: str, *, name: str | None = None) -> None: assert not prefix or prefix.startswith("/"), prefix assert prefix in ("", "/") or not prefix.endswith("/"), prefix super().__init__(name=name) @@ -530,8 +507,8 @@ def __init__( prefix: str, directory: PathLike, *, - name: Optional[str] = None, - expect_handler: Optional[_ExpectHandler] = None, + name: str | None = None, + expect_handler: _ExpectHandler | None = None, chunk_size: int = 256 * 1024, show_index: bool = False, follow_symlinks: bool = False, @@ -565,7 +542,7 @@ def url_for( # type: ignore[override] self, *, filename: PathLike, - append_version: Optional[bool] = None, + append_version: bool | None = None, ) -> URL: if append_version is None: append_version = self._append_version @@ -724,9 +701,7 @@ def _directory_as_html(self, dir_path: Path) -> str: def __repr__(self) -> str: name = "'" + self.name + "'" if self.name is not None else "" - return " {directory!r}>".format( - name=name, path=self._prefix, directory=self._directory - ) + return f" {self._directory!r}>" class 
PrefixedSubAppResource(PrefixResource): @@ -770,9 +745,7 @@ def __iter__(self) -> Iterator[AbstractRoute]: return iter(self._app.router.routes()) def __repr__(self) -> str: - return " {app!r}>".format( - prefix=self._prefix, app=self._app - ) + return f" {self._app!r}>" class AbstractRuleMatching(abc.ABC): @@ -881,22 +854,20 @@ class ResourceRoute(AbstractRoute): def __init__( self, method: str, - handler: Union[Handler, Type[AbstractView]], + handler: Handler | type[AbstractView], resource: AbstractResource, *, - expect_handler: Optional[_ExpectHandler] = None, + expect_handler: _ExpectHandler | None = None, ) -> None: super().__init__( method, handler, expect_handler=expect_handler, resource=resource ) def __repr__(self) -> str: - return " {handler!r}".format( - method=self.method, resource=self._resource, handler=self.handler - ) + return f" {self.handler!r}" @property - def name(self) -> Optional[str]: + def name(self) -> str | None: if self._resource is None: return None return self._resource.name @@ -920,7 +891,7 @@ def url_for(self, *args: str, **kwargs: str) -> URL: raise RuntimeError(".url_for() is not allowed for SystemRoute") @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return None def get_info(self) -> _InfoDict: @@ -938,14 +909,14 @@ def reason(self) -> str: return self._http_exception.reason def __repr__(self) -> str: - return "".format(self=self) + return f"" class View(AbstractView): async def _iter(self) -> StreamResponse: if self.request.method not in hdrs.METH_ALL: self._raise_allowed_methods() - method: Optional[Callable[[], Awaitable[StreamResponse]]] = getattr( + method: Callable[[], Awaitable[StreamResponse]] | None = getattr( self, self.request.method.lower(), None ) if method is None: @@ -961,7 +932,7 @@ def _raise_allowed_methods(self) -> NoReturn: class ResourcesView(Sized, Iterable[AbstractResource], Container[AbstractResource]): - def __init__(self, resources: List[AbstractResource]) -> None: + def 
__init__(self, resources: list[AbstractResource]) -> None: self._resources = resources def __len__(self) -> int: @@ -975,8 +946,8 @@ def __contains__(self, resource: object) -> bool: class RoutesView(Sized, Iterable[AbstractRoute], Container[AbstractRoute]): - def __init__(self, resources: List[AbstractResource]): - self._routes: List[AbstractRoute] = [] + def __init__(self, resources: list[AbstractResource]): + self._routes: list[AbstractRoute] = [] for resource in resources: for route in resource: self._routes.append(route) @@ -997,14 +968,14 @@ class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): def __init__(self) -> None: super().__init__() - self._resources: List[AbstractResource] = [] - self._named_resources: Dict[str, AbstractResource] = {} + self._resources: list[AbstractResource] = [] + self._named_resources: dict[str, AbstractResource] = {} self._resource_index: dict[str, list[AbstractResource]] = {} - self._matched_sub_app_resources: List[MatchedSubAppResource] = [] + self._matched_sub_app_resources: list[MatchedSubAppResource] = [] async def resolve(self, request: Request) -> UrlMappingMatchInfo: resource_index = self._resource_index - allowed_methods: Set[str] = set() + allowed_methods: set[str] = set() # Walk the url parts looking for candidates. We walk the url backwards # to ensure the most explicit match is found first. 
If there are multiple @@ -1084,15 +1055,15 @@ def register_resource(self, resource: AbstractResource) -> None: ) if not part.isidentifier(): raise ValueError( - "Incorrect route name {!r}, " + f"Incorrect route name {name!r}, " "the name should be a sequence of " "python identifiers separated " - "by dash, dot or column".format(name) + "by dash, dot or column" ) if name in self._named_resources: raise ValueError( - "Duplicate {!r}, " - "already handled by {!r}".format(name, self._named_resources[name]) + f"Duplicate {name!r}, " + f"already handled by {self._named_resources[name]!r}" ) self._named_resources[name] = resource self._resources.append(resource) @@ -1127,7 +1098,7 @@ def unindex_resource(self, resource: AbstractResource) -> None: resource_key = self._get_resource_index_key(resource) self._resource_index[resource_key].remove(resource) - def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource: + def add_resource(self, path: str, *, name: str | None = None) -> Resource: if path and not path.startswith("/"): raise ValueError("path should be started with / or be empty") # Reuse last added resource if path and name are the same @@ -1147,10 +1118,10 @@ def add_route( self, method: str, path: str, - handler: Union[Handler, Type[AbstractView]], + handler: Handler | type[AbstractView], *, - name: Optional[str] = None, - expect_handler: Optional[_ExpectHandler] = None, + name: str | None = None, + expect_handler: _ExpectHandler | None = None, ) -> AbstractRoute: resource = self.add_resource(path, name=name) return resource.add_route(method, handler, expect_handler=expect_handler) @@ -1160,8 +1131,8 @@ def add_static( prefix: str, path: PathLike, *, - name: Optional[str] = None, - expect_handler: Optional[_ExpectHandler] = None, + name: str | None = None, + expect_handler: _ExpectHandler | None = None, chunk_size: int = 256 * 1024, show_index: bool = False, follow_symlinks: bool = False, @@ -1202,7 +1173,7 @@ def add_get( path: str, handler: 
Handler, *, - name: Optional[str] = None, + name: str | None = None, allow_head: bool = True, **kwargs: Any, ) -> AbstractRoute: @@ -1233,7 +1204,7 @@ def add_delete(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRout return self.add_route(hdrs.METH_DELETE, path, handler, **kwargs) def add_view( - self, path: str, handler: Type[AbstractView], **kwargs: Any + self, path: str, handler: type[AbstractView], **kwargs: Any ) -> AbstractRoute: """Shortcut for add_route with ANY methods for a class-based view.""" return self.add_route(hdrs.METH_ANY, path, handler, **kwargs) @@ -1243,7 +1214,7 @@ def freeze(self) -> None: for resource in self._resources: resource.freeze() - def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]: + def add_routes(self, routes: Iterable[AbstractRouteDef]) -> list[AbstractRoute]: """Append routes to route table. Parameter should be a sequence of RouteDef objects. diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index a1fd12a1e97..8eee7e3ad71 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -4,7 +4,8 @@ import hashlib import json import sys -from typing import Any, Final, Iterable, Optional, Tuple, Union +from collections.abc import Iterable +from typing import Any, Final, Union from multidict import CIMultiDict @@ -57,7 +58,7 @@ @frozen_dataclass_decorator class WebSocketReady: ok: bool - protocol: Optional[str] + protocol: str | None def __bool__(self) -> bool: return self.ok @@ -66,30 +67,30 @@ def __bool__(self) -> bool: class WebSocketResponse(StreamResponse): _length_check: bool = False - _ws_protocol: Optional[str] = None - _writer: Optional[WebSocketWriter] = None - _reader: Optional[WebSocketDataQueue] = None + _ws_protocol: str | None = None + _writer: WebSocketWriter | None = None + _reader: WebSocketDataQueue | None = None _closed: bool = False _closing: bool = False _conn_lost: int = 0 - _close_code: Optional[int] = None - _loop: Optional[asyncio.AbstractEventLoop] = None + 
_close_code: int | None = None + _loop: asyncio.AbstractEventLoop | None = None _waiting: bool = False - _close_wait: Optional[asyncio.Future[None]] = None - _exception: Optional[BaseException] = None + _close_wait: asyncio.Future[None] | None = None + _exception: BaseException | None = None _heartbeat_when: float = 0.0 - _heartbeat_cb: Optional[asyncio.TimerHandle] = None - _pong_response_cb: Optional[asyncio.TimerHandle] = None - _ping_task: Optional[asyncio.Task[None]] = None + _heartbeat_cb: asyncio.TimerHandle | None = None + _pong_response_cb: asyncio.TimerHandle | None = None + _ping_task: asyncio.Task[None] | None = None def __init__( self, *, timeout: float = 10.0, - receive_timeout: Optional[float] = None, + receive_timeout: float | None = None, autoclose: bool = True, autoping: bool = True, - heartbeat: Optional[float] = None, + heartbeat: float | None = None, protocols: Iterable[str] = (), compress: bool = True, max_msg_size: int = 4 * 1024 * 1024, @@ -104,7 +105,7 @@ def __init__( self._heartbeat = heartbeat if heartbeat is not None: self._pong_heartbeat = heartbeat / 2.0 - self._compress: Union[bool, int] = compress + self._compress: bool | int = compress self._max_msg_size = max_msg_size self._writer_limit = writer_limit @@ -224,25 +225,23 @@ async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: def _handshake( self, request: BaseRequest - ) -> Tuple["CIMultiDict[str]", Optional[str], int, bool]: + ) -> tuple["CIMultiDict[str]", str | None, int, bool]: headers = request.headers if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip(): raise HTTPBadRequest( text=( - "No WebSocket UPGRADE hdr: {}\n Can " + f"No WebSocket UPGRADE hdr: {headers.get(hdrs.UPGRADE)}\n Can " '"Upgrade" only to "WebSocket".' 
- ).format(headers.get(hdrs.UPGRADE)) + ) ) if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower(): raise HTTPBadRequest( - text="No CONNECTION upgrade hdr: {}".format( - headers.get(hdrs.CONNECTION) - ) + text=f"No CONNECTION upgrade hdr: {headers.get(hdrs.CONNECTION)}" ) # find common sub-protocol between client and server - protocol: Optional[str] = None + protocol: str | None = None if hdrs.SEC_WEBSOCKET_PROTOCOL in headers: req_protocols = [ str(proto.strip()) @@ -308,7 +307,7 @@ def _handshake( notakeover, ) - def _pre_start(self, request: BaseRequest) -> Tuple[Optional[str], WebSocketWriter]: + def _pre_start(self, request: BaseRequest) -> tuple[str | None, WebSocketWriter]: self._loop = request._loop headers, protocol, compress, notakeover = self._handshake(request) @@ -330,7 +329,7 @@ def _pre_start(self, request: BaseRequest) -> Tuple[Optional[str], WebSocketWrit return protocol, writer def _post_start( - self, request: BaseRequest, protocol: Optional[str], writer: WebSocketWriter + self, request: BaseRequest, protocol: str | None, writer: WebSocketWriter ) -> None: self._ws_protocol = protocol self._writer = writer @@ -367,15 +366,15 @@ def closed(self) -> bool: return self._closed @property - def close_code(self) -> Optional[int]: + def close_code(self) -> int | None: return self._close_code @property - def ws_protocol(self) -> Optional[str]: + def ws_protocol(self) -> str | None: return self._ws_protocol @property - def compress(self) -> Union[int, bool]: + def compress(self) -> int | bool: return self._compress def get_extra_info(self, name: str, default: Any = None) -> Any: @@ -388,7 +387,7 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: return default return writer.transport.get_extra_info(name, default) - def exception(self) -> Optional[BaseException]: + def exception(self) -> BaseException | None: return self._exception async def ping(self, message: bytes = b"") -> None: @@ -403,14 +402,14 @@ async def pong(self, message: 
bytes = b"") -> None: await self._writer.send_frame(message, WSMsgType.PONG) async def send_frame( - self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None + self, message: bytes, opcode: WSMsgType, compress: int | None = None ) -> None: """Send a frame over the websocket.""" if self._writer is None: raise RuntimeError("Call .prepare() first") await self._writer.send_frame(message, opcode, compress) - async def send_str(self, data: str, compress: Optional[int] = None) -> None: + async def send_str(self, data: str, compress: int | None = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, str): @@ -419,7 +418,7 @@ async def send_str(self, data: str, compress: Optional[int] = None) -> None: data.encode("utf-8"), WSMsgType.TEXT, compress=compress ) - async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: + async def send_bytes(self, data: bytes, compress: int | None = None) -> None: if self._writer is None: raise RuntimeError("Call .prepare() first") if not isinstance(data, (bytes, bytearray, memoryview)): @@ -429,7 +428,7 @@ async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: async def send_json( self, data: Any, - compress: Optional[int] = None, + compress: int | None = None, *, dumps: JSONEncoder = json.dumps, ) -> None: @@ -515,7 +514,7 @@ def _close_transport(self) -> None: if self._req is not None and self._req.transport is not None: self._req.transport.close() - async def receive(self, timeout: Optional[float] = None) -> WSMessage: + async def receive(self, timeout: float | None = None) -> WSMessage: if self._reader is None: raise RuntimeError("Call .prepare() first") @@ -589,7 +588,7 @@ async def receive(self, timeout: Optional[float] = None) -> WSMessage: return msg - async def receive_str(self, *, timeout: Optional[float] = None) -> str: + async def receive_str(self, *, timeout: float | None = None) -> str: msg = await 
self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( @@ -597,7 +596,7 @@ async def receive_str(self, *, timeout: Optional[float] = None) -> str: ) return msg.data - async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: + async def receive_bytes(self, *, timeout: float | None = None) -> bytes: msg = await self.receive(timeout) if msg.type is not WSMsgType.BINARY: raise WSMessageTypeError( @@ -606,7 +605,7 @@ async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: return msg.data async def receive_json( - self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None + self, *, loads: JSONDecoder = json.loads, timeout: float | None = None ) -> Any: data = await self.receive_str(timeout=timeout) return loads(data) diff --git a/aiohttp/worker.py b/aiohttp/worker.py index ed6abd43eb6..da44d7253fc 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -42,9 +42,9 @@ class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported] def __init__(self, *args: Any, **kw: Any) -> None: super().__init__(*args, **kw) - self._task: Optional[asyncio.Task[None]] = None + self._task: asyncio.Task[None] | None = None self.exit_code = 0 - self._notify_waiter: Optional[asyncio.Future[bool]] = None + self._notify_waiter: asyncio.Future[bool] | None = None def init_process(self) -> None: # create new event_loop after fork @@ -81,7 +81,7 @@ async def _run(self) -> None: else: raise RuntimeError( "wsgi app should be either Application or " - "async function returning Application, got {}".format(self.wsgi) + f"async function returning Application, got {self.wsgi}" ) if runner is None: @@ -187,7 +187,7 @@ def init_signals(self) -> None: # Reset signals so Gunicorn doesn't swallow subprocess return codes # See: https://github.com/aio-libs/aiohttp/issues/6130 - def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None: + def handle_quit(self, sig: int, frame: FrameType | None) -> None: 
self.alive = False # worker_int callback @@ -196,7 +196,7 @@ def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None: # wakeup closing process self._notify_waiter_done() - def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None: + def handle_abort(self, sig: int, frame: FrameType | None) -> None: self.alive = False self.exit_code = 1 self.cfg.worker_abort(self) diff --git a/examples/background_tasks.py b/examples/background_tasks.py index 401f3ccebe8..a4b37a02bb3 100755 --- a/examples/background_tasks.py +++ b/examples/background_tasks.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 """Example of aiohttp.web.Application.on_startup signal handler""" import asyncio +from collections.abc import AsyncIterator from contextlib import suppress -from typing import AsyncIterator, List import valkey.asyncio as valkey from aiohttp import web valkey_listener = web.AppKey("valkey_listener", asyncio.Task[None]) -websockets = web.AppKey("websockets", List[web.WebSocketResponse]) +websockets = web.AppKey("websockets", list[web.WebSocketResponse]) async def websocket_handler(request: web.Request) -> web.StreamResponse: @@ -57,7 +57,7 @@ async def background_tasks(app: web.Application) -> AsyncIterator[None]: def init() -> web.Application: app = web.Application() - l: List[web.WebSocketResponse] = [] + l: list[web.WebSocketResponse] = [] app[websockets] = l app.router.add_get("/news", websocket_handler) app.cleanup_ctx.append(background_tasks) diff --git a/examples/cli_app.py b/examples/cli_app.py index 37531536f1f..e18ce1fb78b 100755 --- a/examples/cli_app.py +++ b/examples/cli_app.py @@ -14,7 +14,7 @@ """ from argparse import ArgumentParser, Namespace -from typing import Optional, Sequence +from collections.abc import Sequence from aiohttp import web @@ -27,7 +27,7 @@ async def display_message(req: web.Request) -> web.StreamResponse: return web.Response(text=text) -def init(argv: Optional[Sequence[str]]) -> web.Application: +def init(argv: Sequence[str] | 
None) -> web.Application: arg_parser = ArgumentParser( prog="aiohttp.web ...", description="Application CLI", add_help=False ) diff --git a/examples/combined_middleware.py b/examples/combined_middleware.py index 8646a182b98..a4e50b07414 100644 --- a/examples/combined_middleware.py +++ b/examples/combined_middleware.py @@ -18,7 +18,7 @@ import logging import time from http import HTTPStatus -from typing import TYPE_CHECKING, Set, Union +from typing import TYPE_CHECKING from aiohttp import ( ClientHandlerType, @@ -92,7 +92,7 @@ async def __call__( return await handler(request) -DEFAULT_RETRY_STATUSES: Set[HTTPStatus] = { +DEFAULT_RETRY_STATUSES: set[HTTPStatus] = { HTTPStatus.TOO_MANY_REQUESTS, HTTPStatus.INTERNAL_SERVER_ERROR, HTTPStatus.BAD_GATEWAY, @@ -107,7 +107,7 @@ class RetryMiddleware: def __init__( self, max_retries: int = 3, - retry_statuses: Union[Set[HTTPStatus], None] = None, + retry_statuses: set[HTTPStatus] | None = None, initial_delay: float = 1.0, backoff_factor: float = 2.0, ) -> None: @@ -122,7 +122,7 @@ async def __call__( handler: ClientHandlerType, ) -> ClientResponse: """Execute request with retry logic.""" - last_response: Union[ClientResponse, None] = None + last_response: ClientResponse | None = None delay = self.initial_delay for attempt in range(self.max_retries + 1): diff --git a/examples/fake_server.py b/examples/fake_server.py index d7be5954232..13cc2dfe77d 100755 --- a/examples/fake_server.py +++ b/examples/fake_server.py @@ -3,7 +3,6 @@ import pathlib import socket import ssl -from typing import Dict, List from aiohttp import ClientSession, TCPConnector, test_utils, web from aiohttp.abc import AbstractResolver, ResolveResult @@ -13,7 +12,7 @@ class FakeResolver(AbstractResolver): _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1", socket.AF_INET6: "::1"} - def __init__(self, fakes: Dict[str, int]) -> None: + def __init__(self, fakes: dict[str, int]) -> None: """fakes -- dns -> port dict""" self._fakes = fakes self._resolver = 
DefaultResolver() @@ -23,7 +22,7 @@ async def resolve( host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: fake_port = self._fakes.get(host) if fake_port is not None: return [ @@ -59,7 +58,7 @@ def __init__(self) -> None: self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.ssl_context.load_cert_chain(str(ssl_cert), str(ssl_key)) - async def start(self) -> Dict[str, int]: + async def start(self) -> dict[str, int]: port = test_utils.unused_port() await self.runner.setup() site = web.TCPSite(self.runner, "127.0.0.1", port, ssl_context=self.ssl_context) diff --git a/examples/logging_middleware.py b/examples/logging_middleware.py index b6345953db2..37027ea0acb 100644 --- a/examples/logging_middleware.py +++ b/examples/logging_middleware.py @@ -13,7 +13,8 @@ import json import logging import time -from typing import Any, Coroutine, List +from collections.abc import Coroutine +from typing import Any from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web @@ -141,7 +142,7 @@ async def run_tests() -> None: # Test 6: Multiple concurrent requests print("\n=== Test 6: Multiple concurrent requests ===") - coros: List[Coroutine[Any, Any, ClientResponse]] = [] + coros: list[Coroutine[Any, Any, ClientResponse]] = [] for i in range(3): coro = session.get(f"http://localhost:8080/hello/User{i}") coros.append(coro) diff --git a/examples/retry_middleware.py b/examples/retry_middleware.py index c8fa829455a..157bd5a268c 100644 --- a/examples/retry_middleware.py +++ b/examples/retry_middleware.py @@ -13,14 +13,14 @@ import asyncio import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, List, Set, Union +from typing import TYPE_CHECKING from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web logging.basicConfig(level=logging.INFO) _LOGGER = logging.getLogger(__name__) -DEFAULT_RETRY_STATUSES: 
Set[HTTPStatus] = { +DEFAULT_RETRY_STATUSES: set[HTTPStatus] = { HTTPStatus.TOO_MANY_REQUESTS, HTTPStatus.INTERNAL_SERVER_ERROR, HTTPStatus.BAD_GATEWAY, @@ -35,7 +35,7 @@ class RetryMiddleware: def __init__( self, max_retries: int = 3, - retry_statuses: Union[Set[HTTPStatus], None] = None, + retry_statuses: set[HTTPStatus] | None = None, initial_delay: float = 1.0, backoff_factor: float = 2.0, ) -> None: @@ -50,7 +50,7 @@ async def __call__( handler: ClientHandlerType, ) -> ClientResponse: """Execute request with retry logic.""" - last_response: Union[ClientResponse, None] = None + last_response: ClientResponse | None = None delay = self.initial_delay for attempt in range(self.max_retries + 1): @@ -92,8 +92,8 @@ class TestServer: """Test server with stateful endpoints for retry testing.""" def __init__(self) -> None: - self.request_counters: Dict[str, int] = {} - self.status_sequences: Dict[str, List[int]] = { + self.request_counters: dict[str, int] = {} + self.status_sequences: dict[str, list[int]] = { "eventually-ok": [500, 503, 502, 200], # Fails 3 times, then succeeds "always-error": [500, 500, 500, 500], # Always fails "immediate-ok": [200], # Succeeds immediately diff --git a/examples/token_refresh_middleware.py b/examples/token_refresh_middleware.py index 8a7ff963850..4f0a894c76c 100644 --- a/examples/token_refresh_middleware.py +++ b/examples/token_refresh_middleware.py @@ -20,8 +20,9 @@ import logging import secrets import time +from collections.abc import Coroutine from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Union +from typing import TYPE_CHECKING, Any from aiohttp import ( ClientHandlerType, @@ -42,8 +43,8 @@ class TokenRefreshMiddleware: def __init__(self, token_endpoint: str, refresh_token: str) -> None: self.token_endpoint = token_endpoint self.refresh_token = refresh_token - self.access_token: Union[str, None] = None - self.token_expires_at: Union[float, None] = None + self.access_token: str | None = 
None + self.token_expires_at: float | None = None self._refresh_lock = asyncio.Lock() async def _refresh_access_token(self, session: ClientSession) -> str: @@ -121,8 +122,8 @@ class TestServer: """Test server with JWT-like token authentication.""" def __init__(self) -> None: - self.tokens_db: Dict[str, Dict[str, Union[str, float]]] = {} - self.refresh_tokens_db: Dict[str, Dict[str, Union[str, float]]] = { + self.tokens_db: dict[str, dict[str, str | float]] = {} + self.refresh_tokens_db: dict[str, dict[str, str | float]] = { # Hash of refresh token -> user data hashlib.sha256(b"demo_refresh_token_12345").hexdigest(): { "user_id": "user123", @@ -135,7 +136,7 @@ def generate_access_token(self) -> str: """Generate a secure random access token.""" return secrets.token_urlsafe(32) - async def _process_token_refresh(self, data: Dict[str, str]) -> web.Response: + async def _process_token_refresh(self, data: dict[str, str]) -> web.Response: """Process the token refresh request.""" refresh_token = data.get("refresh_token") @@ -189,7 +190,7 @@ async def handle_token_refresh(self, request: web.Request) -> web.Response: async def verify_bearer_token( self, request: web.Request - ) -> Union[Dict[str, Union[str, float]], None]: + ) -> dict[str, str | float] | None: """Verify bearer token and return user data if valid.""" auth_header = request.headers.get(hdrs.AUTHORIZATION, "") @@ -285,7 +286,7 @@ async def run_tests() -> None: print("\n=== Test 3: Multiple concurrent requests ===") print("(Should only refresh token once)") - coros: List[Coroutine[Any, Any, ClientResponse]] = [] + coros: list[Coroutine[Any, Any, ClientResponse]] = [] for i in range(3): coro = session.get("http://localhost:8080/api/protected") coros.append(coro) diff --git a/examples/web_ws.py b/examples/web_ws.py index a5626ac90c9..4051ae12890 100755 --- a/examples/web_ws.py +++ b/examples/web_ws.py @@ -6,15 +6,14 @@ # mypy: disallow-any-expr, disallow-any-unimported, disallow-subclassing-any import os -from 
typing import List, Union from aiohttp import web WS_FILE = os.path.join(os.path.dirname(__file__), "websocket.html") -sockets = web.AppKey("sockets", List[web.WebSocketResponse]) +sockets = web.AppKey("sockets", list[web.WebSocketResponse]) -async def wshandler(request: web.Request) -> Union[web.WebSocketResponse, web.Response]: +async def wshandler(request: web.Request) -> web.WebSocketResponse | web.Response: resp = web.WebSocketResponse() available = resp.can_prepare(request) if not available: @@ -54,7 +53,7 @@ async def on_shutdown(app: web.Application) -> None: def init() -> web.Application: app = web.Application() - l: List[web.WebSocketResponse] = [] + l: list[web.WebSocketResponse] = [] app[sockets] = l app.router.add_get("/", wshandler) app.on_shutdown.append(on_shutdown) diff --git a/pyproject.toml b/pyproject.toml index a9b4200a06c..c09362d66fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,4 +92,4 @@ exclude-modules = "(^aiohttp\\.helpers)" [tool.black] # TODO: Remove when project metadata is moved here. # Black can read the value from [project.requires-python]. 
-target-version = ["py39", "py310", "py311", "py312", "py313"] +target-version = ["py310", "py311", "py312", "py313", "py314"] diff --git a/setup.cfg b/setup.cfg index 07a6a2d0c93..7fd0f0c59c6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,6 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 - Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Programming Language :: Python :: 3.12 @@ -43,7 +42,7 @@ classifiers = Topic :: Internet :: WWW/HTTP [options] -python_requires = >=3.9 +python_requires = >=3.10 packages = aiohttp aiohttp._websocket diff --git a/setup.py b/setup.py index fded89876f2..96cfa50d47e 100644 --- a/setup.py +++ b/setup.py @@ -4,8 +4,8 @@ from setuptools import Extension, setup -if sys.version_info < (3, 9): - raise RuntimeError("aiohttp 4.x requires Python 3.9+") +if sys.version_info < (3, 10): + raise RuntimeError("aiohttp 4.x requires Python 3.10+") USE_SYSTEM_DEPS = bool( diff --git a/tests/autobahn/Dockerfile.aiohttp b/tests/autobahn/Dockerfile.aiohttp index 2d37683a1ad..0f57652415c 100644 --- a/tests/autobahn/Dockerfile.aiohttp +++ b/tests/autobahn/Dockerfile.aiohttp @@ -1,4 +1,4 @@ -FROM python:3.9.5 +FROM python:3.14 COPY ./ /src diff --git a/tests/autobahn/server/server.py b/tests/autobahn/server/server.py index d2801e5ef3d..47122ae2578 100644 --- a/tests/autobahn/server/server.py +++ b/tests/autobahn/server/server.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 import logging -from typing import List from aiohttp import WSCloseCode, web -websockets = web.AppKey("websockets", List[web.WebSocketResponse]) +websockets = web.AppKey("websockets", list[web.WebSocketResponse]) async def wshandler(request: web.Request) -> web.WebSocketResponse: @@ -46,7 +45,7 @@ async def on_shutdown(app: web.Application) -> None: ) app = web.Application() - l: List[web.WebSocketResponse] = [] + l: list[web.WebSocketResponse] = [] app[websockets] = l 
app.router.add_route("GET", "/", wshandler) app.on_shutdown.append(on_shutdown) diff --git a/tests/autobahn/test_autobahn.py b/tests/autobahn/test_autobahn.py index 96b45148799..89388309073 100644 --- a/tests/autobahn/test_autobahn.py +++ b/tests/autobahn/test_autobahn.py @@ -1,8 +1,9 @@ import json import subprocess import sys +from collections.abc import Generator from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Generator, List +from typing import TYPE_CHECKING, Any import pytest from pytest import TempPathFactory @@ -37,7 +38,7 @@ def build_autobahn_testsuite() -> Generator[None, None, None]: docker.image.remove(x="autobahn-testsuite") -def get_failed_tests(report_path: str, name: str) -> List[Dict[str, Any]]: +def get_failed_tests(report_path: str, name: str) -> list[dict[str, Any]]: path = Path(report_path) result_summary = json.loads((path / "index.json").read_text())[name] failed_messages = [] diff --git a/tests/conftest.py b/tests/conftest.py index 41ed907ea70..5e872dec5c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,11 @@ import socket import ssl import sys +from collections.abc import AsyncIterator, Callable, Generator, Iterator from hashlib import md5, sha1, sha256 from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, AsyncIterator, Callable, Generator, Iterator +from typing import Any from unittest import mock from uuid import uuid4 diff --git a/tests/test_benchmarks_client_request.py b/tests/test_benchmarks_client_request.py index c430ef3a49c..ea5e1c28a48 100644 --- a/tests/test_benchmarks_client_request.py +++ b/tests/test_benchmarks_client_request.py @@ -2,7 +2,6 @@ import asyncio from http.cookies import BaseCookie -from typing import Union from multidict import CIMultiDict from pytest_codspeed import BenchmarkFixture @@ -122,7 +121,7 @@ def is_closing(self) -> bool: """Swallow is_closing.""" return False - def write(self, data: Union[bytes, bytearray, memoryview]) -> 
None: + def write(self, data: bytes | bytearray | memoryview) -> None: """Swallow writes.""" class MockProtocol(asyncio.BaseProtocol): diff --git a/tests/test_benchmarks_http_websocket.py b/tests/test_benchmarks_http_websocket.py index 728921b53eb..2aa9b8294bd 100644 --- a/tests/test_benchmarks_http_websocket.py +++ b/tests/test_benchmarks_http_websocket.py @@ -1,7 +1,6 @@ """codspeed benchmarks for http websocket.""" import asyncio -from typing import Union import pytest from pytest_codspeed import BenchmarkFixture @@ -60,7 +59,7 @@ def is_closing(self) -> bool: """Swallow is_closing.""" return False - def write(self, data: Union[bytes, bytearray, memoryview]) -> None: + def write(self, data: bytes | bytearray | memoryview) -> None: """Swallow writes.""" diff --git a/tests/test_benchmarks_web_urldispatcher.py b/tests/test_benchmarks_web_urldispatcher.py index 936ed6320ed..339eaef8a0e 100644 --- a/tests/test_benchmarks_web_urldispatcher.py +++ b/tests/test_benchmarks_web_urldispatcher.py @@ -6,7 +6,7 @@ import random import string from pathlib import Path -from typing import NoReturn, Optional, cast +from typing import NoReturn, cast from unittest import mock import pytest @@ -68,7 +68,7 @@ async def handler(request: web.Request) -> NoReturn: router = app.router request = _mock_request(method="GET", path="/") - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for _ in range(resolve_count): ret = await router.resolve(request) @@ -106,7 +106,7 @@ async def handler(request: web.Request) -> NoReturn: router = app.router request = _mock_request(method="GET", path="/") - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for _ in range(resolve_count): ret = await router.resolve(request) @@ -136,7 +136,7 @@ def 
test_resolve_static_root_route( router = app.router request = _mock_request(method="GET", path="/") - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for _ in range(resolve_count): ret = await router.resolve(request) @@ -169,7 +169,7 @@ async def handler(request: web.Request) -> NoReturn: router = app.router request = _mock_request(method="GET", path="/api/server/dispatch/1/update") - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for _ in range(resolve_count): ret = await router.resolve(request) @@ -205,7 +205,7 @@ async def handler(request: web.Request) -> NoReturn: for count in range(250) ] - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) @@ -245,7 +245,7 @@ async def handler(request: web.Request) -> NoReturn: requests = [(_mock_request(method="GET", path=url), url) for url in urls] - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request, path in requests: ret = await router.resolve(request) @@ -282,7 +282,7 @@ async def handler(request: web.Request) -> NoReturn: for customer in range(250) ] - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) @@ -323,7 +323,7 @@ async def handler(request: web.Request) -> NoReturn: for customer in range(250) ] - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def 
run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) @@ -362,7 +362,7 @@ async def handler(request: web.Request) -> NoReturn: for customer in range(250) ] - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) @@ -411,7 +411,7 @@ async def handler(request: web.Request) -> NoReturn: ) ) - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) @@ -482,7 +482,7 @@ async def handler(request: web.Request) -> NoReturn: ) ) - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) @@ -518,7 +518,7 @@ async def handler(request: web.Request) -> NoReturn: request = _mock_request(method="GET", path="/") - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for i in range(250): ret = await router.resolve(request) @@ -558,7 +558,7 @@ async def handler(request: web.Request) -> NoReturn: for customer in range(250) ] - async def run_url_dispatcher_benchmark() -> Optional[web.UrlMappingMatchInfo]: + async def run_url_dispatcher_benchmark() -> web.UrlMappingMatchInfo | None: ret = None for request in requests: ret = await router.resolve(request) diff --git a/tests/test_circular_imports.py b/tests/test_circular_imports.py index d513e9bde8b..9b5d7ed2697 100644 --- a/tests/test_circular_imports.py +++ b/tests/test_circular_imports.py @@ -14,10 +14,11 @@ import socket import 
subprocess import sys +from collections.abc import Generator from itertools import chain from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Generator, List, Union +from typing import TYPE_CHECKING, Union import pytest @@ -28,8 +29,8 @@ def _mark_aiohttp_worker_for_skipping( - importables: List[str], -) -> List[Union[str, "ParameterSet"]]: + importables: list[str], +) -> list[Union[str, "ParameterSet"]]: return [ ( pytest.param( @@ -45,7 +46,7 @@ def _mark_aiohttp_worker_for_skipping( ] -def _find_all_importables(pkg: ModuleType) -> List[str]: +def _find_all_importables(pkg: ModuleType) -> list[str]: """Find all importables in the project. Return them in order. diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 3433226db49..5006a745346 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -13,19 +13,9 @@ import tarfile import time import zipfile +from collections.abc import AsyncIterator, Awaitable, Callable from contextlib import suppress -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Dict, - List, - NoReturn, - Optional, - Type, - Union, -) +from typing import Any, NoReturn from unittest import mock import pytest @@ -1721,7 +1711,7 @@ async def write_bytes( self: ClientRequest, writer: StreamWriter, conn: Connection, - content_length: Optional[int] = None, + content_length: int | None = None, ) -> None: nonlocal write_mock, writelines_mock original_write = writer._write @@ -2590,7 +2580,7 @@ async def handler(request: web.Request) -> web.Response: assert request.cookies["test3"] == "456" return web.Response() - c: "http.cookies.Morsel[str]" = http.cookies.Morsel() + c: http.cookies.Morsel[str] = http.cookies.Morsel() c.set("test3", "456", "456") app = web.Application() @@ -2610,17 +2600,17 @@ async def handler(request: web.Request) -> web.Response: assert request.cookies["test6"] == "abc" return web.Response() - c: 
"http.cookies.Morsel[str]" = http.cookies.Morsel() + c: http.cookies.Morsel[str] = http.cookies.Morsel() c.set("test3", "456", "456") app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app, cookies={"test1": "123", "test2": c}) - rc: "http.cookies.Morsel[str]" = http.cookies.Morsel() + rc: http.cookies.Morsel[str] = http.cookies.Morsel() rc.set("test6", "abc", "abc") - cookies: Dict[str, Union[str, "http.cookies.Morsel[str]"]] + cookies: dict[str, str | http.cookies.Morsel[str]] cookies = {"test4": "789", "test5": rc} async with client.get("/", cookies=cookies) as resp: assert 200 == resp.status @@ -2681,7 +2671,7 @@ async def handler(request: web.Request) -> web.Response: assert request.cookies["test3"] == "456" return web.Response() - c: "http.cookies.Morsel[str]" = http.cookies.Morsel() + c: http.cookies.Morsel[str] = http.cookies.Morsel() c.set("test3", "456", "456") c["httponly"] = True c["secure"] = True @@ -2946,7 +2936,7 @@ async def handler_redirect(request: web.Request) -> web.Response: ), ) async def test_invalid_and_non_http_url( - url: str, error_message_url: str, expected_exception_class: Type[Exception] + url: str, error_message_url: str, expected_exception_class: type[Exception] ) -> None: async with aiohttp.ClientSession() as http_session: with pytest.raises( @@ -2973,7 +2963,7 @@ async def test_invalid_redirect_url( aiohttp_client: AiohttpClient, invalid_redirect_url: str, error_message_url: str, - expected_exception_class: Type[Exception], + expected_exception_class: type[Exception], ) -> None: headers = {hdrs.LOCATION: invalid_redirect_url} @@ -3008,7 +2998,7 @@ async def test_invalid_redirect_url_multiple_redirects( aiohttp_client: AiohttpClient, invalid_redirect_url: str, error_message_url: str, - expected_exception_class: Type[Exception], + expected_exception_class: type[Exception], ) -> None: app = web.Application() @@ -3225,7 +3215,7 @@ async def resolve( host: str, port: int = 0, family: 
socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None @@ -3338,7 +3328,7 @@ async def resolve( host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None @@ -3409,7 +3399,7 @@ async def resolve( host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None @@ -3474,7 +3464,7 @@ async def resolve( host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: server = etc_hosts[(host, port)] assert server.port is not None @@ -3588,7 +3578,7 @@ async def resolve( host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [] async def close(self) -> None: @@ -3845,7 +3835,7 @@ async def test_server_close_keepalive_connection() -> None: class Proto(asyncio.Protocol): def connection_made(self, transport: asyncio.BaseTransport) -> None: assert isinstance(transport, asyncio.Transport) - self.transp: Optional[asyncio.Transport] = transport + self.transp: asyncio.Transport | None = transport self.data = b"" def data_received(self, data: bytes) -> None: @@ -3861,7 +3851,7 @@ def data_received(self, data: bytes) -> None: ) self.transp.close() - def connection_lost(self, exc: Optional[BaseException]) -> None: + def connection_lost(self, exc: BaseException | None) -> None: self.transp = None server = await loop.create_server(Proto, "127.0.0.1", unused_port()) @@ -3886,7 +3876,7 @@ async def test_handle_keepalive_on_closed_connection() -> None: class Proto(asyncio.Protocol): def connection_made(self, transport: asyncio.BaseTransport) -> None: assert isinstance(transport, 
asyncio.Transport) - self.transp: Optional[asyncio.Transport] = transport + self.transp: asyncio.Transport | None = transport self.data = b"" def data_received(self, data: bytes) -> None: @@ -3896,7 +3886,7 @@ def data_received(self, data: bytes) -> None: self.transp.write(b"HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 2\r\n\r\nok") self.transp.close() - def connection_lost(self, exc: Optional[BaseException]) -> None: + def connection_lost(self, exc: BaseException | None) -> None: self.transp = None server = await loop.create_server(Proto, "127.0.0.1", unused_port()) @@ -4421,7 +4411,7 @@ async def test_request_with_wrong_ssl_type(aiohttp_client: AiohttpClient) -> Non [(42, TypeError), ("InvalidUrl", InvalidURL)], ) async def test_request_with_wrong_proxy( - aiohttp_client: AiohttpClient, value: Union[int, str], exc_type: Type[Exception] + aiohttp_client: AiohttpClient, value: int | str, exc_type: type[Exception] ) -> None: app = web.Application() session = await aiohttp_client(app) @@ -5227,7 +5217,7 @@ def __init__(self, port: int): async def resolve( self, host: str, port: int = 0, family: int = 0 - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: if host in ("amazon.it", "www.amazon.it"): return [ { diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py index 217877759c0..da5bcece6e8 100644 --- a/tests/test_client_middleware.py +++ b/tests/test_client_middleware.py @@ -2,7 +2,7 @@ import json import socket -from typing import Dict, List, NoReturn, Optional, Union +from typing import NoReturn import pytest @@ -243,7 +243,7 @@ async def handler(request: web.Request) -> web.Response: async def challenge_auth_middleware( request: ClientRequest, handler: ClientHandlerType ) -> ClientResponse: - nonce: Optional[str] = None + nonce: str | None = None attempted: bool = False while True: @@ -285,7 +285,7 @@ async def challenge_auth_middleware( async def test_client_middleware_multi_step_auth(aiohttp_server: AiohttpServer) -> None: """Test middleware 
with multi-step authentication flow.""" auth_state: dict[str, int] = {} - middleware_state: Dict[str, Optional[Union[int, str]]] = { + middleware_state: dict[str, int | str | None] = { "step": 0, "session": None, "challenge": None, @@ -372,7 +372,7 @@ async def test_client_middleware_conditional_retry( ) -> None: """Test middleware with conditional retry based on response content.""" request_count = 0 - token_state: Dict[str, Union[str, bool]] = { + token_state: dict[str, str | bool] = { "token": "old-token", "refreshed": False, } @@ -735,7 +735,7 @@ async def test_client_middleware_blocks_connection_before_established( ) -> None: """Test that middleware can block connections before they are established.""" blocked_hosts = {"blocked.example.com", "evil.com"} - connection_attempts: List[str] = [] + connection_attempts: list[str] = [] async def handler(request: web.Request) -> web.Response: return web.Response(text="Reached") @@ -801,7 +801,7 @@ async def test_client_middleware_blocks_connection_without_dns_lookup( ) -> None: """Test that middleware prevents DNS lookups for blocked hosts.""" blocked_hosts = {"blocked.domain.tld"} - dns_lookups_made: List[str] = [] + dns_lookups_made: list[str] = [] # Create a simple server for the allowed request async def handler(request: web.Request) -> web.Response: @@ -817,7 +817,7 @@ async def resolve( hostname: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: dns_lookups_made.append(hostname) return await super().resolve(hostname, port, family) @@ -878,7 +878,7 @@ class TrackingConnector(TCPConnector): connection_attempts = 0 async def _create_connection( - self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: self.connection_attempts += 1 return await super()._create_connection(req, traces, timeout) @@ -927,11 +927,11 @@ async def 
test_middleware_uses_session_avoids_recursion_with_path_check( aiohttp_server: AiohttpServer, ) -> None: """Test that middleware can avoid infinite recursion using a path check.""" - log_collector: List[Dict[str, str]] = [] + log_collector: list[dict[str, str]] = [] async def log_api_handler(request: web.Request) -> web.Response: """Handle log API requests.""" - data: Dict[str, str] = await request.json() + data: dict[str, str] = await request.json() log_collector.append(data) return web.Response(text="OK") @@ -993,14 +993,14 @@ async def test_middleware_uses_session_avoids_recursion_with_disabled_middleware aiohttp_server: AiohttpServer, ) -> None: """Test that middleware can avoid infinite recursion by disabling middleware.""" - log_collector: List[Dict[str, str]] = [] + log_collector: list[dict[str, str]] = [] request_count = 0 async def log_api_handler(request: web.Request) -> web.Response: """Handle log API requests.""" nonlocal request_count request_count += 1 - data: Dict[str, str] = await request.json() + data: dict[str, str] = await request.json() log_collector.append(data) return web.Response(text="OK") @@ -1061,8 +1061,8 @@ async def test_middleware_can_check_request_body( aiohttp_server: AiohttpServer, ) -> None: """Test that middleware can check request body.""" - received_bodies: List[str] = [] - received_headers: List[Dict[str, str]] = [] + received_bodies: list[str] = [] + received_headers: list[dict[str, str]] = [] async def handler(request: web.Request) -> web.Response: """Server handler that receives requests.""" diff --git a/tests/test_client_middleware_digest_auth.py b/tests/test_client_middleware_digest_auth.py index 2059bfea337..064d4d78239 100644 --- a/tests/test_client_middleware_digest_auth.py +++ b/tests/test_client_middleware_digest_auth.py @@ -2,8 +2,9 @@ import io import re +from collections.abc import Generator from hashlib import md5, sha1 -from typing import Generator, Literal, Union +from typing import Literal from unittest import 
mock import pytest @@ -323,7 +324,7 @@ def KD(secret: str, data: str) -> str: async def test_digest_response_exact_match( qop: str, algorithm: str, - body: Union[Literal[b""], BytesIOPayload], + body: Literal[b""] | BytesIOPayload, body_str: str, mock_sha1_digest: mock.MagicMock, ) -> None: diff --git a/tests/test_client_request.py b/tests/test_client_request.py index de25a53dc93..ef444f1008f 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -4,19 +4,9 @@ import pathlib import sys import warnings +from collections.abc import AsyncIterator, Callable, Iterable, Iterator from http.cookies import BaseCookie, Morsel, SimpleCookie -from typing import ( - Any, - AsyncIterator, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Protocol, - Union, -) +from typing import Any, Protocol from unittest import mock import pytest @@ -321,7 +311,7 @@ def test_host_header_ipv6_with_port(make_request: _RequestMaker) -> None: ), ) def test_host_header_fqdn( - make_request: _RequestMaker, url: str, headers: Dict[str, str], expected: str + make_request: _RequestMaker, url: str, headers: dict[str, str], expected: str ) -> None: req = make_request("get", url, headers=headers) assert req.headers["HOST"] == expected @@ -1171,7 +1161,7 @@ async def gen() -> AsyncIterator[bytes]: original_write_bytes = req.write_bytes async def _mock_write_bytes( - writer: AbstractStreamWriter, conn: mock.Mock, content_length: Optional[int] + writer: AbstractStreamWriter, conn: mock.Mock, content_length: int | None ) -> None: # Ensure the task is scheduled await asyncio.sleep(0) @@ -1548,10 +1538,10 @@ def test_insecure_fingerprint_sha1(loop: asyncio.AbstractEventLoop) -> None: def test_loose_cookies_types(loop: asyncio.AbstractEventLoop) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) - morsel: "Morsel[str]" = Morsel() + morsel: Morsel[str] = Morsel() morsel.set(key="string", val="Another string", coded_val="really") - accepted_types: 
List[LooseCookies] = [ + accepted_types: list[LooseCookies] = [ [("str", BaseCookie())], [("str", morsel)], [ @@ -1749,7 +1739,7 @@ async def test_write_bytes_with_iterable_content_length_limit( loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock, - data: Union[List[bytes], bytes], + data: list[bytes] | bytes, ) -> None: """Test that write_bytes respects content_length limit for iterable data.""" # Test with iterable data @@ -2212,8 +2202,8 @@ async def test_expect100_with_body_becomes_none() -> None: ) def test_content_length_for_methods( method: str, - data: Optional[bytes], - expected_content_length: Optional[str], + data: bytes | None, + expected_content_length: str | None, loop: asyncio.AbstractEventLoop, ) -> None: """Test that Content-Length header is set correctly for all HTTP methods.""" diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 1b8c3878828..11a815a325e 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -6,20 +6,9 @@ import sys import warnings from collections import deque +from collections.abc import Awaitable, Callable, Iterator from http.cookies import BaseCookie, SimpleCookie -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Iterator, - List, - NoReturn, - Optional, - TypedDict, - Union, - cast, -) +from typing import Any, NoReturn, TypedDict, cast from unittest import mock from uuid import uuid4 @@ -41,7 +30,7 @@ class _Params(TypedDict): - headers: Dict[str, str] + headers: dict[str, str] max_redirects: int compress: str chunked: bool @@ -594,7 +583,7 @@ class UnexpectedException(BaseException): original_connect = session._connector.connect async def connect( - req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout + req: ClientRequest, traces: list[Trace], timeout: aiohttp.ClientTimeout ) -> Connection: conn = await original_connect(req, traces, timeout) connections.append(conn) @@ -657,7 +646,7 @@ async def 
test_ws_connect_allowed_protocols( # type: ignore[misc] original_connect = session._connector.connect async def connect( - req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout + req: ClientRequest, traces: list[Trace], timeout: aiohttp.ClientTimeout ) -> Connection: conn = await original_connect(req, traces, timeout) connections.append(conn) @@ -720,7 +709,7 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] original_connect = session._connector.connect async def connect( - req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout + req: ClientRequest, traces: list[Trace], timeout: aiohttp.ClientTimeout ) -> Connection: conn = await original_connect(req, traces, timeout) connections.append(conn) @@ -761,13 +750,13 @@ def __init__(self) -> None: self._filter_cookies_mock = mock.Mock(return_value=BaseCookie()) self._clear_mock = mock.Mock() self._clear_domain_mock = mock.Mock() - self._items: List[Any] = [] + self._items: list[Any] = [] @property def quote_cookie(self) -> bool: return True - def clear(self, predicate: Optional[abc.ClearCookiePredicate] = None) -> None: + def clear(self, predicate: abc.ClearCookiePredicate | None = None) -> None: self._clear_mock(predicate) def clear_domain(self, domain: str) -> None: @@ -1033,7 +1022,7 @@ def reset_mocks() -> None: for m in mocks: m.reset_mock() - def to_trace_urls(mock_func: mock.Mock) -> List[URL]: + def to_trace_urls(mock_func: mock.Mock) -> list[URL]: return [call_args[0][-1].url for call_args in mock_func.call_args_list] def to_url(path: str) -> URL: @@ -1294,8 +1283,8 @@ async def test_requote_redirect_url_default_disable() -> None: ) async def test_build_url_returns_expected_url( # type: ignore[misc] create_session: Callable[..., Awaitable[ClientSession]], - base_url: Union[URL, str, None], - url: Union[URL, str], + base_url: URL | str | None, + url: URL | str, expected_url: URL, ) -> None: session = await create_session(base_url) diff --git 
a/tests/test_client_ws.py b/tests/test_client_ws.py index 500564896c8..9e1d8457586 100644 --- a/tests/test_client_ws.py +++ b/tests/test_client_ws.py @@ -2,7 +2,7 @@ import base64 import hashlib import os -from typing import Mapping, Type +from collections.abc import Mapping from unittest import mock import pytest @@ -544,7 +544,7 @@ async def test_close_exc2( @pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError)) async def test_send_data_after_close( - exc: Type[Exception], + exc: type[Exception], ws_key: str, key_data: bytes, loop: asyncio.AbstractEventLoop, diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 3e871d8d29a..0bc05f300d4 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import List, NoReturn, Optional +from typing import NoReturn from unittest import mock import pytest @@ -360,7 +360,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_concurrent_close(aiohttp_client: AiohttpClient) -> None: - client_ws: Optional[aiohttp.ClientWebSocketResponse] = None + client_ws: aiohttp.ClientWebSocketResponse | None = None async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() @@ -956,7 +956,7 @@ async def handler(request: web.Request) -> NoReturn: ping_started = loop.create_future() async def delayed_send_frame( - message: bytes, opcode: int, compress: Optional[int] = None + message: bytes, opcode: int, compress: int | None = None ) -> None: assert opcode == WSMsgType.PING nonlocal cancelled @@ -1277,7 +1277,7 @@ async def handler(request: web.Request) -> NoReturn: app = web.Application() app.router.add_route("GET", "/", handler) - sync_future: "asyncio.Future[List[aiohttp.ClientWebSocketResponse]]" = ( + sync_future: asyncio.Future[list[aiohttp.ClientWebSocketResponse]] = ( loop.create_future() ) client = await 
aiohttp_client(app) diff --git a/tests/test_connector.py b/tests/test_connector.py index 53704fc5ea5..91796280b27 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -9,21 +9,10 @@ import uuid import warnings from collections import defaultdict, deque +from collections.abc import Awaitable, Callable, Iterator, Sequence from concurrent import futures from contextlib import closing, suppress -from typing import ( - Awaitable, - Callable, - DefaultDict, - Deque, - Iterator, - List, - Literal, - NoReturn, - Optional, - Sequence, - Tuple, -) +from typing import Literal, NoReturn from unittest import mock import pytest @@ -118,7 +107,7 @@ async def go(app: web.Application) -> None: def create_mocked_conn( - conn_closing_result: Optional[asyncio.AbstractEventLoop] = None, + conn_closing_result: asyncio.AbstractEventLoop | None = None, should_close: bool = True, **kwargs: object, ) -> mock.Mock: @@ -744,7 +733,7 @@ async def test_tcp_connector_multiple_hosts_errors( async def _resolve_host( host: str, port: int, traces: object = None - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [ { "hostname": host, @@ -772,8 +761,8 @@ async def start_connection( return mock_socket # type: ignore[no-any-return] async def create_connection( - *args: object, sock: Optional[socket.socket] = None, **kwargs: object - ) -> Tuple[ResponseHandler, ResponseHandler]: + *args: object, sock: socket.socket | None = None, **kwargs: object + ) -> tuple[ResponseHandler, ResponseHandler]: nonlocal os_error, certificate_error, ssl_error, fingerprint_error nonlocal connected @@ -889,7 +878,7 @@ def get_extra_info(param: str) -> object: [0.1, 0.25, None], ) async def test_tcp_connector_happy_eyeballs( - loop: asyncio.AbstractEventLoop, happy_eyeballs_delay: Optional[float] + loop: asyncio.AbstractEventLoop, happy_eyeballs_delay: float | None ) -> None: conn = aiohttp.TCPConnector(happy_eyeballs_delay=happy_eyeballs_delay) @@ -906,7 +895,7 @@ async def 
test_tcp_connector_happy_eyeballs( async def _resolve_host( host: str, port: int, traces: object = None - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [ { "hostname": host, @@ -922,7 +911,7 @@ async def _resolve_host( os_error = False connected = False - async def sock_connect(*args: Tuple[str, int], **kwargs: object) -> None: + async def sock_connect(*args: tuple[str, int], **kwargs: object) -> None: addr = args[1] nonlocal os_error @@ -933,8 +922,8 @@ async def sock_connect(*args: Tuple[str, int], **kwargs: object) -> None: raise OSError async def create_connection( - *args: object, sock: Optional[socket.socket] = None, **kwargs: object - ) -> Tuple[ResponseHandler, ResponseHandler]: + *args: object, sock: socket.socket | None = None, **kwargs: object + ) -> tuple[ResponseHandler, ResponseHandler]: assert isinstance(sock, socket.socket) # Close the socket since we are not actually connecting # and we don't want to leak it. @@ -995,7 +984,7 @@ async def test_tcp_connector_interleave(loop: asyncio.AbstractEventLoop) -> None async def _resolve_host( host: str, port: int, traces: object = None - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [ { "hostname": host, @@ -1011,7 +1000,7 @@ async def _resolve_host( async def start_connection( addr_infos: Sequence[AddrInfoType], *, - interleave: Optional[int] = None, + interleave: int | None = None, **kwargs: object, ) -> socket.socket: nonlocal interleave_val @@ -1024,8 +1013,8 @@ async def start_connection( return mock_socket # type: ignore[no-any-return] async def create_connection( - *args: object, sock: Optional[socket.socket] = None, **kwargs: object - ) -> Tuple[ResponseHandler, ResponseHandler]: + *args: object, sock: socket.socket | None = None, **kwargs: object + ) -> tuple[ResponseHandler, ResponseHandler]: assert isinstance(sock, socket.socket) addr_info = sock.getpeername() ip = addr_info[0] @@ -1085,7 +1074,7 @@ async def test_tcp_connector_family_is_respected( async def 
_resolve_host( host: str, port: int, traces: object = None - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [ { "hostname": host, @@ -1100,13 +1089,13 @@ async def _resolve_host( connected = False - async def sock_connect(*args: Tuple[str, int], **kwargs: object) -> None: + async def sock_connect(*args: tuple[str, int], **kwargs: object) -> None: addr = args[1] addrs_tried.append(addr) async def create_connection( - *args: object, sock: Optional[socket.socket] = None, **kwargs: object - ) -> Tuple[ResponseHandler, ResponseHandler]: + *args: object, sock: socket.socket | None = None, **kwargs: object + ) -> tuple[ResponseHandler, ResponseHandler]: assert isinstance(sock, socket.socket) # Close the socket since we are not actually connecting # and we don't want to leak it. @@ -1175,7 +1164,7 @@ async def test_tcp_connector_multiple_hosts_one_timeout( async def _resolve_host( host: str, port: int, traces: object = None - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [ { "hostname": host, @@ -1191,7 +1180,7 @@ async def _resolve_host( async def start_connection( addr_infos: Sequence[AddrInfoType], *, - interleave: Optional[int] = None, + interleave: int | None = None, **kwargs: object, ) -> socket.socket: nonlocal timeout_error @@ -1216,8 +1205,8 @@ async def start_connection( assert False async def create_connection( - *args: object, sock: Optional[socket.socket] = None, **kwargs: object - ) -> Tuple[ResponseHandler, ResponseHandler]: + *args: object, sock: socket.socket | None = None, **kwargs: object + ) -> tuple[ResponseHandler, ResponseHandler]: nonlocal connected assert isinstance(sock, socket.socket) @@ -1288,8 +1277,8 @@ async def test_tcp_connector_resolve_host(loop: asyncio.AbstractEventLoop) -> No @pytest.fixture -def dns_response(loop: asyncio.AbstractEventLoop) -> Callable[[], Awaitable[List[str]]]: - async def coro() -> List[str]: +def dns_response(loop: asyncio.AbstractEventLoop) -> Callable[[], Awaitable[list[str]]]: + async 
def coro() -> list[str]: # simulates a network operation await asyncio.sleep(0) return ["127.0.0.1"] @@ -1298,7 +1287,7 @@ async def coro() -> List[str]: async def test_tcp_connector_dns_cache_not_expired( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) @@ -1312,7 +1301,7 @@ async def test_tcp_connector_dns_cache_not_expired( async def test_tcp_connector_dns_cache_forever( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) @@ -1326,7 +1315,7 @@ async def test_tcp_connector_dns_cache_forever( async def test_tcp_connector_use_dns_cache_disabled( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=False) @@ -1345,7 +1334,7 @@ async def test_tcp_connector_use_dns_cache_disabled( async def test_tcp_connector_dns_throttle_requests( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) @@ -1385,7 +1374,7 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread( async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( - loop: 
asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) @@ -1452,7 +1441,7 @@ def exception_handler(loop: asyncio.AbstractEventLoop, context: object) -> None: async def test_tcp_connector_dns_tracing( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1500,7 +1489,7 @@ async def test_tcp_connector_dns_tracing( async def test_tcp_connector_dns_tracing_cache_disabled( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1558,7 +1547,7 @@ async def test_tcp_connector_dns_tracing_cache_disabled( async def test_tcp_connector_dns_tracing_throttle_requests( - loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[List[str]]] + loop: asyncio.AbstractEventLoop, dns_response: Callable[[], Awaitable[list[str]]] ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1957,7 +1946,7 @@ async def test_cleanup(key: ConnectionKey) -> None: m2 = mock.Mock() m1.is_connected.return_value = True m2.is_connected.return_value = False - testset: DefaultDict[ConnectionKey, Deque[Tuple[ResponseHandler, float]]] = ( + testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[key] = deque([(m1, 10), (m2, 300)]) @@ -1982,7 +1971,7 @@ async def test_cleanup_close_ssl_transport( # type: ignore[misc] ) -> None: proto = create_mocked_conn(loop) transport = proto.transport - testset: DefaultDict[ConnectionKey, 
Deque[Tuple[ResponseHandler, float]]] = ( + testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[ssl_key] = deque([(proto, 10)]) @@ -2008,7 +1997,7 @@ async def test_cleanup_close_ssl_transport( # type: ignore[misc] async def test_cleanup2(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: m = create_mocked_conn() m.is_connected.return_value = True - testset: DefaultDict[ConnectionKey, Deque[Tuple[ResponseHandler, float]]] = ( + testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[key] = deque([(m, 300)]) @@ -2029,7 +2018,7 @@ async def test_cleanup2(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> async def test_cleanup3(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: m = create_mocked_conn(loop) m.is_connected.return_value = True - testset: DefaultDict[ConnectionKey, Deque[Tuple[ResponseHandler, float]]] = ( + testset: defaultdict[ConnectionKey, deque[tuple[ResponseHandler, float]]] = ( defaultdict(deque) ) testset[key] = deque([(m, 290.1), (create_mocked_conn(loop), 305.1)]) @@ -2796,7 +2785,7 @@ async def test_multiple_dns_resolution_requests_success( ) -> None: """Verify that multiple DNS resolution requests are handled correctly.""" - async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: """Delayed resolve() task.""" for _ in range(3): await asyncio.sleep(0) @@ -2858,7 +2847,7 @@ async def test_multiple_dns_resolution_requests_failure( ) -> None: """Verify that DNS resolution failure for multiple requests is handled correctly.""" - async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: """Delayed resolve() task.""" for _ in range(3): await asyncio.sleep(0) @@ -2911,7 +2900,7 @@ async def 
test_multiple_dns_resolution_requests_cancelled( ) -> None: """Verify that DNS resolution cancellation does not affect other tasks.""" - async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: """Delayed resolve() task.""" for _ in range(3): await asyncio.sleep(0) @@ -2963,7 +2952,7 @@ async def test_multiple_dns_resolution_requests_first_cancelled( ) -> None: """Verify that first DNS resolution cancellation does not make other resolutions fail.""" - async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: """Delayed resolve() task.""" for _ in range(3): await asyncio.sleep(0) @@ -3027,7 +3016,7 @@ async def test_multiple_dns_resolution_requests_first_fails_second_successful( """Verify that first DNS resolution fails the first time and is successful the second time.""" attempt = 0 - async def delay_resolve(*args: object, **kwargs: object) -> List[ResolveResult]: + async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: """Delayed resolve() task.""" nonlocal attempt for _ in range(3): @@ -3880,7 +3869,7 @@ async def handler(request: web.Request) -> web.Response: # resolving something.localhost with the real DNS resolver does not work on macOS, so we have a stub. 
async def _resolve_host( host: str, port: int, traces: object = None - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: return [ { "hostname": host, @@ -4221,7 +4210,7 @@ async def send_dns_cache_miss(self, *args: object, **kwargs: object) -> None: if request_count <= 1: raise Exception("first attempt") - async def resolve_response() -> List[ResolveResult]: + async def resolve_response() -> list[ResolveResult]: await asyncio.sleep(0) return [token] diff --git a/tests/test_cookiejar.py b/tests/test_cookiejar.py index 7477f018eda..9ca7bacecaf 100644 --- a/tests/test_cookiejar.py +++ b/tests/test_cookiejar.py @@ -7,7 +7,6 @@ from http.cookies import BaseCookie, Morsel, SimpleCookie from operator import not_ from pathlib import Path -from typing import List, Set, Tuple, Union from unittest import mock import pytest @@ -310,7 +309,7 @@ async def test_filter_cookies_str_deprecated() -> None: ) async def test_filter_cookies_with_domain_path_lookup_multilevelpath( url: str, - expected_cookies: Set[str], + expected_cookies: set[str], ) -> None: jar = CookieJar() cookie = SimpleCookie( @@ -439,7 +438,7 @@ def setup_cookies( def request_reply_with_same_url( self, url: str - ) -> Tuple["BaseCookie[str]", SimpleCookie]: + ) -> tuple["BaseCookie[str]", SimpleCookie]: jar = CookieJar() jar.update_cookies(self.cookies_to_send) cookies_sent = jar.filter_cookies(URL(url)) @@ -459,8 +458,8 @@ def timed_request( self, url: str, update_time: float, send_time: float ) -> "BaseCookie[str]": jar = CookieJar() - freeze_update_time: Union[datetime.datetime, datetime.timedelta] - freeze_send_time: Union[datetime.datetime, datetime.timedelta] + freeze_update_time: datetime.datetime | datetime.timedelta + freeze_send_time: datetime.datetime | datetime.timedelta if isinstance(update_time, int): freeze_update_time = datetime.timedelta(seconds=update_time) else: @@ -1073,7 +1072,7 @@ def test_pickle_format(cookies_to_send: SimpleCookie) -> None: ], ) async def 
test_treat_as_secure_origin_init( - url: Union[str, URL, List[str], List[URL]], + url: str | URL | list[str] | list[URL], ) -> None: jar = CookieJar(unsafe=True, treat_as_secure_origin=url) assert jar._treat_as_secure_origin == frozenset({URL("http://127.0.0.1")}) @@ -1204,12 +1203,12 @@ def test_update_cookies_from_headers_duplicate_names() -> None: assert len(jar) == 3 # Verify we have both session-id cookies - all_cookies: List[Morsel[str]] = list(jar) - session_ids: List[Morsel[str]] = [c for c in all_cookies if c.key == "session-id"] + all_cookies: list[Morsel[str]] = list(jar) + session_ids: list[Morsel[str]] = [c for c in all_cookies if c.key == "session-id"] assert len(session_ids) == 2 # Check their domains are different - domains: Set[str] = {c["domain"] for c in session_ids} + domains: set[str] = {c["domain"] for c in session_ids} assert domains == {"example.com", "www.example.com"} diff --git a/tests/test_helpers.py b/tests/test_helpers.py index e36b376bd73..fc9069729d5 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -4,9 +4,9 @@ import gc import sys import weakref +from collections.abc import Iterator from math import ceil, modf from pathlib import Path -from typing import Dict, Iterator, Optional, Union from unittest import mock from urllib.request import getproxies_environment @@ -528,7 +528,7 @@ async def test_ceil_timeout_small_with_overriden_threshold( ], ) def test_content_disposition( - params: Dict[str, str], quote_fields: bool, _charset: str, expected: str + params: dict[str, str], quote_fields: bool, _charset: str, expected: str ) -> None: result = helpers.content_disposition_header( "attachment", quote_fields=quote_fields, _charset=_charset, params=params @@ -601,8 +601,8 @@ def test_proxies_from_env_skipped( url = URL(url_input) assert helpers.proxies_from_env() == {} assert len(caplog.records) == 1 - log_message = "{proto!s} proxies {url!s} are not supported, ignoring".format( - proto=expected_scheme.upper(), url=url + 
log_message = ( + f"{expected_scheme.upper()!s} proxies {url!s} are not supported, ignoring" ) assert caplog.record_tuples == [("aiohttp.client", 30, log_message)] @@ -716,7 +716,7 @@ def test_get_env_proxy_for_url_negative(url_input: str, expected_err_msg: str) - "url_scheme_match_http_proxy_list_multiple", ), ) -def test_get_env_proxy_for_url(proxy_env_vars: Dict[str, str], url_input: str) -> None: +def test_get_env_proxy_for_url(proxy_env_vars: dict[str, str], url_input: str) -> None: url = URL(url_input) proxy, proxy_auth = helpers.get_env_proxy_for_url(url) proxy_list = proxy_env_vars[url.scheme + "_proxy"] @@ -760,7 +760,7 @@ async def test_set_exception_cancelled(loop: asyncio.AbstractEventLoop) -> None: # ----------- ChainMapProxy -------------------------- -AppKeyDict = Dict[Union[str, web.AppKey[object]], object] +AppKeyDict = dict[str | web.AppKey[object], object] class TestChainMapProxy: @@ -1040,7 +1040,7 @@ def test_populate_with_cookies() -> None: ), ], ) -def test_parse_http_date(value: str, expected: Optional[datetime.datetime]) -> None: +def test_parse_http_date(value: str, expected: datetime.datetime | None) -> None: assert parse_http_date(value) == expected diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 69e1fd7d4cc..4aa965e0ae4 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -3,8 +3,9 @@ import asyncio import re import sys +from collections.abc import Iterable from contextlib import suppress -from typing import Any, Dict, Iterable, List, Type +from typing import Any from unittest import mock from urllib.parse import quote @@ -59,7 +60,7 @@ def protocol() -> Any: return mock.create_autospec(BaseProtocol, spec_set=True, instance=True) -def _gen_ids(parsers: Iterable[Type[HttpParser[Any]]]) -> List[str]: +def _gen_ids(parsers: Iterable[type[HttpParser[Any]]]) -> list[str]: return [ "py-parser" if parser.__module__ == "aiohttp.http_parser" else "c-parser" for parser in parsers @@ -83,7 +84,7 @@ def 
parser( @pytest.fixture(params=REQUEST_PARSERS, ids=_gen_ids(REQUEST_PARSERS)) -def request_cls(request: pytest.FixtureRequest) -> Type[HttpRequestParser]: +def request_cls(request: pytest.FixtureRequest) -> type[HttpRequestParser]: # Request Parser class return request.param # type: ignore[no-any-return] @@ -106,7 +107,7 @@ def response( @pytest.fixture(params=RESPONSE_PARSERS, ids=_gen_ids(RESPONSE_PARSERS)) -def response_cls(request: pytest.FixtureRequest) -> Type[HttpResponseParser]: +def response_cls(request: pytest.FixtureRequest) -> type[HttpResponseParser]: # Parser implementations return request.param # type: ignore[no-any-return] @@ -644,7 +645,7 @@ def test_headers_content_length_err_2(parser: HttpRequestParser) -> None: parser.feed_data(text) -_pad: Dict[bytes, str] = { +_pad: dict[bytes, str] = { b"": "empty", # not a typo. Python likes triple zero b"\000": "NUL", @@ -804,7 +805,7 @@ def test_http_request_bad_status_line(parser: HttpRequestParser) -> None: assert r"\n" not in exc_info.value.message -_num: Dict[bytes, str] = { +_num: dict[bytes, str] = { # dangerous: accepted by Python int() # unicodedata.category("\U0001D7D9") == 'Nd' "\N{MATHEMATICAL DOUBLE-STRUCK DIGIT ONE}".encode(): "utf8digit", @@ -1356,7 +1357,7 @@ def test_parse_no_length_or_te_on_post( def test_parse_payload_response_without_body( loop: asyncio.AbstractEventLoop, protocol: BaseProtocol, - response_cls: Type[HttpResponseParser], + response_cls: type[HttpResponseParser], ) -> None: parser = response_cls(protocol, loop, 2**16, response_with_body=False) text = b"HTTP/1.1 200 Ok\r\ncontent-length: 10\r\n\r\n" @@ -1574,7 +1575,7 @@ def test_partial_url(parser: HttpRequestParser) -> None: ], ) def test_parse_uri_percent_encoded( - parser: HttpRequestParser, uri: str, path: str, query: Dict[str, str], fragment: str + parser: HttpRequestParser, uri: str, path: str, query: dict[str, str], fragment: str ) -> None: text = (f"GET {uri} HTTP/1.1\r\n\r\n").encode() messages, upgrade, tail = 
parser.feed_data(text) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 1d6cf439e4e..e6b0bac97a1 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -2,7 +2,8 @@ import array import asyncio import zlib -from typing import Any, Generator, Iterable, Union +from collections.abc import Generator, Iterable +from typing import Any from unittest import mock import pytest @@ -66,7 +67,7 @@ def decompress(data: bytes) -> bytes: return d.decompress(data) -def decode_chunked(chunked: Union[bytes, bytearray]) -> bytes: +def decode_chunked(chunked: bytes | bytearray) -> bytes: i = 0 out = b"" while i < len(chunked): diff --git a/tests/test_multipart.py b/tests/test_multipart.py index f0efa7284bc..fcceec12396 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -4,7 +4,6 @@ import pathlib import sys from types import TracebackType -from typing import Dict, Optional, Tuple, Type from unittest import mock import pytest @@ -79,7 +78,7 @@ class Stream(StreamReader): def __init__(self, content: bytes) -> None: self.content = io.BytesIO(content) - async def read(self, size: Optional[int] = None) -> bytes: + async def read(self, size: int | None = None) -> bytes: return self.content.read(size) def at_eof(self) -> bool: @@ -96,9 +95,9 @@ def __enter__(self) -> Self: def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.content.close() @@ -114,7 +113,7 @@ def __init__(self, content: bytes) -> None: self._first = True super().__init__(content) - async def read(self, size: Optional[int] = None) -> bytes: + async def read(self, size: int | None = None) -> bytes: if size is not None and self._first: self._first = False size = size // 2 @@ -718,7 +717,6 @@ async def test_invalid_boundary(self) -> None: with 
pytest.raises(ValueError): await reader.next() - @pytest.mark.skipif(sys.version_info < (3, 10), reason="Needs anext()") async def test_read_boundary_across_chunks(self) -> None: class SplitBoundaryStream(StreamReader): def __init__(self) -> None: @@ -734,7 +732,7 @@ def __init__(self) -> None: b"oobar--", ] - async def read(self, size: Optional[int] = None) -> bytes: + async def read(self, size: int | None = None) -> bytes: chunk = self.content.pop(0) assert size is not None and len(chunk) <= size return chunk @@ -1448,7 +1446,7 @@ async def test_reset_content_disposition_header( async def test_async_for_reader() -> None: - data: Tuple[Dict[str, str], int, bytes, bytes, bytes] = ( + data: tuple[dict[str, str], int, bytes, bytes, bytes] = ( {"test": "passed"}, 42, b"plain text", diff --git a/tests/test_payload.py b/tests/test_payload.py index aef64aaa53d..cb38cb5a6d0 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -3,9 +3,10 @@ import io import json import unittest.mock +from collections.abc import AsyncIterator, Iterator from io import StringIO from pathlib import Path -from typing import AsyncIterator, Iterator, List, Optional, TextIO, Union +from typing import TextIO, Union import pytest from multidict import CIMultiDict @@ -33,7 +34,7 @@ async def drain(self) -> None: """No-op for test writer.""" def enable_compression( - self, encoding: str = "deflate", strategy: Optional[int] = None + self, encoding: str = "deflate", strategy: int | None = None ) -> None: """Compression not implemented for test writer.""" @@ -170,7 +171,7 @@ class MockStreamWriter(AbstractStreamWriter): """Mock stream writer for testing payload writes.""" def __init__(self) -> None: - self.written: List[bytes] = [] + self.written: list[bytes] = [] async def write( self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] @@ -178,14 +179,14 @@ async def write( """Store the chunk in the written list.""" self.written.append(bytes(chunk)) - async def 
write_eof(self, chunk: Optional[bytes] = None) -> None: + async def write_eof(self, chunk: bytes | None = None) -> None: """write_eof implementation - no-op for tests.""" async def drain(self) -> None: """Drain implementation - no-op for tests.""" def enable_compression( - self, encoding: str = "deflate", strategy: Optional[int] = None + self, encoding: str = "deflate", strategy: int | None = None ) -> None: """Enable compression - no-op for tests.""" @@ -310,7 +311,7 @@ async def test_bytesio_payload_write_with_length_remaining_zero() -> None: original_read = bio.read read_calls = 0 - def mock_read(size: Optional[int] = None) -> bytes: + def mock_read(size: int | None = None) -> bytes: nonlocal read_calls read_calls += 1 if read_calls == 1: @@ -414,9 +415,9 @@ async def test_iobase_payload_large_content_length() -> None: class TrackingBytesIO(io.BytesIO): def __init__(self, data: bytes) -> None: super().__init__(data) - self.read_sizes: List[int] = [] + self.read_sizes: list[int] = [] - def read(self, size: Optional[int] = -1) -> bytes: + def read(self, size: int | None = -1) -> bytes: self.read_sizes.append(size if size is not None else -1) return super().read(size) @@ -489,9 +490,9 @@ async def test_textio_payload_large_content_length() -> None: class TrackingStringIO(io.StringIO): def __init__(self, data: str) -> None: super().__init__(data) - self.read_sizes: List[int] = [] + self.read_sizes: list[int] = [] - def read(self, size: Optional[int] = -1) -> str: + def read(self, size: int | None = -1) -> str: self.read_sizes.append(size if size is not None else -1) return super().read(size) diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 89aba898c27..dc30bd36f5c 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -4,18 +4,10 @@ import platform import ssl import sys +from collections.abc import Awaitable, Callable, Iterator from contextlib import suppress from re import match as match_regex -from 
typing import ( - TYPE_CHECKING, - Awaitable, - Callable, - Dict, - Iterator, - Optional, - TypedDict, - Union, -) +from typing import TYPE_CHECKING, TypedDict from unittest import mock from uuid import uuid4 @@ -35,8 +27,8 @@ class _ResponseArgs(TypedDict): status: int - headers: Optional[Dict[str, str]] - body: Optional[bytes] + headers: dict[str, str] | None + body: bytes | None if sys.version_info >= (3, 11) and TYPE_CHECKING: @@ -45,7 +37,7 @@ class _ResponseArgs(TypedDict): async def get_request( method: str = "GET", *, - url: Union[str, URL], + url: str | URL, trust_env: bool = False, **kwargs: Unpack[_RequestOptions], ) -> ClientResponse: ... @@ -56,7 +48,7 @@ async def get_request( async def get_request( method: str = "GET", *, - url: Union[str, URL], + url: str | URL, trust_env: bool = False, **kwargs: Any, ) -> ClientResponse: @@ -840,10 +832,7 @@ async def test_proxy_from_env_http_with_auth_from_netrc( proxy = await proxy_test_server() auth = aiohttp.BasicAuth("user", "pass") netrc_file = tmp_path / "test_netrc" - netrc_file_data = "machine 127.0.0.1 login {} password {}".format( - auth.login, - auth.password, - ) + netrc_file_data = f"machine 127.0.0.1 login {auth.login} password {auth.password}" with netrc_file.open("w") as f: f.write(netrc_file_data) mocker.patch.dict( @@ -868,10 +857,7 @@ async def test_proxy_from_env_http_without_auth_from_netrc( proxy = await proxy_test_server() auth = aiohttp.BasicAuth("user", "pass") netrc_file = tmp_path / "test_netrc" - netrc_file_data = "machine 127.0.0.2 login {} password {}".format( - auth.login, - auth.password, - ) + netrc_file_data = f"machine 127.0.0.2 login {auth.login} password {auth.password}" with netrc_file.open("w") as f: f.write(netrc_file_data) mocker.patch.dict( diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 3a9d1a70a23..13e494d24cd 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -2,19 +2,9 @@ import gc import ipaddress import socket -from collections.abc 
import Generator +from collections.abc import Awaitable, Callable, Collection, Generator, Iterable from ipaddress import ip_address -from typing import ( - Any, - Awaitable, - Callable, - Collection, - Iterable, - List, - NamedTuple, - Tuple, - Union, -) +from typing import Any, NamedTuple from unittest.mock import Mock, create_autospec, patch import pytest @@ -35,16 +25,16 @@ aiodns = None # type: ignore[assignment] getaddrinfo = False -_AddrInfo4 = List[ - Tuple[socket.AddressFamily, None, socket.SocketKind, None, Tuple[str, int]] +_AddrInfo4 = list[ + tuple[socket.AddressFamily, None, socket.SocketKind, None, tuple[str, int]] ] -_AddrInfo6 = List[ - Tuple[ - socket.AddressFamily, None, socket.SocketKind, None, Tuple[str, int, int, int] +_AddrInfo6 = list[ + tuple[ + socket.AddressFamily, None, socket.SocketKind, None, tuple[str, int, int, int] ] ] -_UnknownAddrInfo = List[ - Tuple[socket.AddressFamily, socket.SocketKind, int, str, Tuple[int, bytes]] +_UnknownAddrInfo = list[ + tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[int, bytes]] ] @@ -93,7 +83,7 @@ def dns_resolver_manager() -> Generator[_DNSResolverManager, None, None]: class FakeAIODNSAddrInfoNode(NamedTuple): family: int - addr: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] + addr: tuple[bytes, int] | tuple[bytes, int, int, int] class FakeAIODNSAddrInfoIPv4Result: @@ -143,7 +133,7 @@ async def fake_aiodns_getnameinfo_ipv6_result( return FakeAIODNSNameInfoIPv6Result(host) -async def fake_query_result(result: Iterable[str]) -> List[FakeQueryResult]: +async def fake_query_result(result: Iterable[str]) -> list[FakeQueryResult]: return [FakeQueryResult(host=h) for h in result] @@ -176,8 +166,8 @@ async def fake(*args: Any, **kwargs: Any) -> _AddrInfo6: return fake -def fake_ipv6_nameinfo(host: str) -> Callable[..., Awaitable[Tuple[str, int]]]: - async def fake(*args: Any, **kwargs: Any) -> Tuple[str, int]: +def fake_ipv6_nameinfo(host: str) -> Callable[..., Awaitable[tuple[str, 
int]]]: + async def fake(*args: Any, **kwargs: Any) -> tuple[str, int]: return host, 0 return fake @@ -321,7 +311,7 @@ async def test_threaded_resolver_multiple_replies() -> None: async def test_threaded_negative_lookup() -> None: loop = Mock() - ips: List[str] = [] + ips: list[str] = [] loop.getaddrinfo = fake_addrinfo(ips) resolver = ThreadedResolver() resolver._loop = loop @@ -331,7 +321,7 @@ async def test_threaded_negative_lookup() -> None: async def test_threaded_negative_ipv6_lookup() -> None: loop = Mock() - ips: List[str] = [] + ips: list[str] = [] loop.getaddrinfo = fake_ipv6_addrinfo(ips) resolver = ThreadedResolver() resolver._loop = loop diff --git a/tests/test_run_app.py b/tests/test_run_app.py index dc2d791ad1f..e4f306b364b 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -9,20 +9,8 @@ import subprocess import sys import time -from typing import ( - Any, - AsyncIterator, - Awaitable, - Callable, - Coroutine, - Dict, - Iterator, - List, - NoReturn, - Optional, - Set, - Tuple, -) +from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine, Iterator +from typing import Any, NoReturn from unittest import mock from uuid import uuid4 @@ -175,8 +163,8 @@ def test_run_app_close_loop(patched_loop: asyncio.AbstractEventLoop) -> None: ) ] mock_socket = mock.Mock(getsockname=lambda: ("mock-socket", 123)) -mixed_bindings_tests: Tuple[ - Tuple[str, Dict[str, Any], List[mock._Call], List[mock._Call]], ... +mixed_bindings_tests: tuple[ + tuple[str, dict[str, Any], list[mock._Call], list[mock._Call]], ... 
] = ( ( "Nothing Specified", @@ -442,9 +430,9 @@ def test_run_app_close_loop(patched_loop: asyncio.AbstractEventLoop) -> None: ids=mixed_bindings_test_ids, ) def test_run_app_mixed_bindings( # type: ignore[misc] - run_app_kwargs: Dict[str, Any], - expected_server_calls: List[mock._Call], - expected_unix_server_calls: List[mock._Call], + run_app_kwargs: dict[str, Any], + expected_server_calls: list[mock._Call], + expected_unix_server_calls: list[mock._Call], patched_loop: asyncio.AbstractEventLoop, ) -> None: app = web.Application() @@ -1024,13 +1012,13 @@ def run_app( sock: socket.socket, timeout: int, task: Callable[[], Coroutine[None, None, None]], - extra_test: Optional[Callable[[ClientSession], Awaitable[None]]] = None, - ) -> Tuple["asyncio.Task[None]", int]: + extra_test: Callable[[ClientSession], Awaitable[None]] | None = None, + ) -> tuple["asyncio.Task[None]", int]: num_connections = -1 t = test_task = None port = sock.getsockname()[1] - class DictRecordClear(Dict[RequestHandler[web.Request], asyncio.Transport]): + class DictRecordClear(dict[RequestHandler[web.Request], asyncio.Transport]): def clear(self) -> None: nonlocal num_connections # During Server.shutdown() we want to know how many connections still @@ -1255,7 +1243,7 @@ async def run_test(app: web.Application) -> AsyncIterator[None]: def test_shutdown_close_websockets(self, unused_port_socket: socket.socket) -> None: sock = unused_port_socket port = sock.getsockname()[1] - WS = web.AppKey("ws", Set[web.WebSocketResponse]) + WS = web.AppKey("ws", set[web.WebSocketResponse]) client_finished = server_finished = False t = None diff --git a/tests/test_streams.py b/tests/test_streams.py index 4305f892eea..b59eb77db96 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -5,8 +5,9 @@ import gc import types from collections import defaultdict +from collections.abc import Iterator, Sequence from itertools import groupby -from typing import DefaultDict, Iterator, Sequence, TypeVar +from typing 
import TypeVar from unittest import mock import pytest @@ -49,7 +50,7 @@ def get_memory_usage(obj: object) -> int: objs = [obj] # Memory leak may be caused by leaked links to same objects. # Without link counting, [1,2,3] is indistinguishable from [1,2,3,3,3,3,3,3] - known: DefaultDict[int, int] = defaultdict(int) + known: defaultdict[int, int] = defaultdict(int) known[id(obj)] += 1 while objs: diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 026be2f906c..bbab015061f 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -2,7 +2,8 @@ import gzip import socket import sys -from typing import Iterator, Mapping, NoReturn +from collections.abc import Iterator, Mapping +from typing import NoReturn from unittest import mock import pytest diff --git a/tests/test_tracing.py b/tests/test_tracing.py index e50d908b2ae..7ee7e6ae6d7 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -1,6 +1,6 @@ import sys from types import SimpleNamespace -from typing import Any, Tuple +from typing import Any from unittest import mock from unittest.mock import Mock @@ -119,7 +119,7 @@ class TestTrace: ], ) async def test_send( # type: ignore[misc] - self, signal: str, params: Tuple[Mock, ...], param_obj: Any + self, signal: str, params: tuple[Mock, ...], param_obj: Any ) -> None: session = Mock() trace_request_ctx = Mock() diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index 01eb686fd3c..abe13726353 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -1,9 +1,17 @@ import asyncio import pathlib import re -from collections.abc import Container, Iterable, Mapping, MutableMapping, Sized +from collections.abc import ( + Awaitable, + Callable, + Container, + Iterable, + Mapping, + MutableMapping, + Sized, +) from functools import partial -from typing import Awaitable, Callable, Dict, List, NoReturn, Optional, Type +from typing import NoReturn from urllib.parse import quote, unquote import pytest @@ -46,8 
+54,8 @@ def router(app: web.Application) -> web.UrlDispatcher: @pytest.fixture -def fill_routes(router: web.UrlDispatcher) -> Callable[[], List[web.AbstractRoute]]: - def go() -> List[web.AbstractRoute]: +def fill_routes(router: web.UrlDispatcher) -> Callable[[], list[web.AbstractRoute]]: + def go() -> list[web.AbstractRoute]: route1 = router.add_route("GET", "/plain", make_handler()) route2 = router.add_route("GET", "/variable/{name}", make_handler()) resource = router.add_static("/static", pathlib.Path(aiohttp.__file__).parent) @@ -614,7 +622,7 @@ def test_static_remove_trailing_slash(router: web.UrlDispatcher) -> None: ), ) async def test_add_route_with_re( - router: web.UrlDispatcher, pattern: str, url: str, expected: Dict[str, str] + router: web.UrlDispatcher, pattern: str, url: str, expected: dict[str, str] ) -> None: handler = make_handler() router.add_route("GET", f"/handler/{pattern}", handler) @@ -853,21 +861,21 @@ def test_add_route_invalid_method(router: web.UrlDispatcher) -> None: def test_routes_view_len( - router: web.UrlDispatcher, fill_routes: Callable[[], List[web.AbstractRoute]] + router: web.UrlDispatcher, fill_routes: Callable[[], list[web.AbstractRoute]] ) -> None: fill_routes() assert 4 == len(router.routes()) def test_routes_view_iter( - router: web.UrlDispatcher, fill_routes: Callable[[], List[web.AbstractRoute]] + router: web.UrlDispatcher, fill_routes: Callable[[], list[web.AbstractRoute]] ) -> None: routes = fill_routes() assert list(routes) == list(router.routes()) def test_routes_view_contains( - router: web.UrlDispatcher, fill_routes: Callable[[], List[web.AbstractRoute]] + router: web.UrlDispatcher, fill_routes: Callable[[], list[web.AbstractRoute]] ) -> None: routes = fill_routes() for route in routes: @@ -1119,7 +1127,7 @@ def test_subapp_get_info(app: web.Application) -> None: ("example$com", ValueError), ], ) -def test_domain_validation_error(domain: Optional[str], error: Type[Exception]) -> None: +def 
test_domain_validation_error(domain: str | None, error: type[Exception]) -> None: with pytest.raises(error): Domain(domain) # type: ignore[arg-type] diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 41a4fcba7ff..2d2d21dbc42 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -1,5 +1,7 @@ import asyncio -from typing import AsyncIterator, Callable, Iterator, NoReturn, Type +import sys +from collections.abc import AsyncIterator, Callable, Iterator +from typing import NoReturn from unittest import mock import pytest @@ -126,20 +128,34 @@ def test_appkey_repr_concrete() -> None: def test_appkey_repr_nonconcrete() -> None: key = web.AppKey("key", Iterator[int]) - assert repr(key) in ( - # pytest-xdist: - "", - "", - ) + if sys.version_info < (3, 11): + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + else: + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) def test_appkey_repr_annotated() -> None: key = web.AppKey[Iterator[int]]("key") - assert repr(key) in ( - # pytest-xdist: - "", - "", - ) + if sys.version_info < (3, 11): + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + else: + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) def test_app_str_keys() -> None: @@ -336,7 +352,7 @@ async def fail_ctx(app: web.Application) -> AsyncIterator[NoReturn]: @pytest.mark.parametrize("exc_cls", (Exception, asyncio.CancelledError)) async def test_cleanup_ctx_exception_on_cleanup_multiple( - exc_cls: Type[BaseException], + exc_cls: type[BaseException], ) -> None: app = web.Application() out = [] diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index c7e156ad875..9371f860af8 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -1,7 +1,8 @@ import collections import pickle +from collections.abc import Mapping from traceback import format_exception -from typing import Mapping, NoReturn +from typing import NoReturn import pytest from yarl import URL diff --git 
a/tests/test_web_functional.py b/tests/test_web_functional.py index 42c3edf20d1..5ff56cc2dab 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -4,17 +4,8 @@ import pathlib import socket import sys -from typing import ( - AsyncIterator, - Awaitable, - Callable, - Dict, - Generator, - List, - NoReturn, - Optional, - Tuple, -) +from collections.abc import AsyncIterator, Awaitable, Callable, Generator +from typing import NoReturn from unittest import mock import pytest @@ -650,7 +641,7 @@ async def test_expect_handler_custom_response(aiohttp_client: AiohttpClient) -> async def handler(request: web.Request) -> web.Response: return web.Response(text="handler") - async def expect_handler(request: web.Request) -> Optional[web.Response]: + async def expect_handler(request: web.Request) -> web.Response | None: k = request.headers["X-Key"] cached_value = cache.get(k) return web.Response(text=cached_value) if cached_value else None @@ -1103,11 +1094,11 @@ async def handler(request: web.Request) -> web.Response: def compressor_case( request: pytest.FixtureRequest, parametrize_zlib_backend: None, -) -> Generator[Tuple[ZLibCompressObjProtocol, str], None, None]: +) -> Generator[tuple[ZLibCompressObjProtocol, str], None, None]: encoding: str = request.param max_wbits: int = ZLibBackend.MAX_WBITS - encoding_to_wbits: Dict[str, int] = { + encoding_to_wbits: dict[str, int] = { "deflate": max_wbits, "deflate-raw": -max_wbits, "gzip": 16 + max_wbits, @@ -1119,7 +1110,7 @@ def compressor_case( async def test_response_with_precompressed_body( aiohttp_client: AiohttpClient, - compressor_case: Tuple[ZLibCompressObjProtocol, str], + compressor_case: tuple[ZLibCompressObjProtocol, str], ) -> None: compressor, encoding = compressor_case @@ -1592,7 +1583,7 @@ async def on_signal(app: web.Application) -> None: ], ) async def test_subapp_middleware_context( - aiohttp_client: AiohttpClient, route: str, expected: List[str], middlewares: str + aiohttp_client: 
AiohttpClient, route: str, expected: list[str], middlewares: str ) -> None: values = [] @@ -2052,7 +2043,7 @@ async def redirected(request: web.Request) -> web.Response: class FakeResolver(AbstractResolver): _LOCAL_HOST = {0: "127.0.0.1", socket.AF_INET: "127.0.0.1"} - def __init__(self, fakes: Dict[str, int]): + def __init__(self, fakes: dict[str, int]): # fakes -- dns -> port dict self._fakes = fakes self._resolver = aiohttp.DefaultResolver() @@ -2065,7 +2056,7 @@ async def resolve( host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET, - ) -> List[ResolveResult]: + ) -> list[ResolveResult]: fake_port = self._fakes.get(host) assert fake_port is not None return [ @@ -2295,7 +2286,7 @@ async def handler(request: web.Request) -> web.Response: async def test_keepalive_race_condition(aiohttp_client: AiohttpClient) -> None: - protocol: Optional[RequestHandler[web.Request]] = None + protocol: RequestHandler[web.Request] | None = None orig_data_received = RequestHandler.data_received def delay_received(self: RequestHandler[web.Request], data: bytes) -> None: diff --git a/tests/test_web_log.py b/tests/test_web_log.py index 6456735de4a..70419e7d70a 100644 --- a/tests/test_web_log.py +++ b/tests/test_web_log.py @@ -3,7 +3,7 @@ import platform import sys from contextvars import ContextVar -from typing import Dict, NoReturn, Optional +from typing import NoReturn from unittest import mock import pytest @@ -89,11 +89,11 @@ def test_access_logger_atoms( monkeypatch: pytest.MonkeyPatch, log_format: str, expected: str, - extra: Dict[str, object], + extra: dict[str, object], ) -> None: class PatchedDatetime(datetime.datetime): @classmethod - def now(cls, tz: Optional[datetime.tzinfo] = None) -> Self: + def now(cls, tz: datetime.tzinfo | None = None) -> Self: return cls(1843, 1, 1, 0, 30, tzinfo=tz) monkeypatch.setattr("datetime.datetime", PatchedDatetime) diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py index 5b8f1c78166..00301eccb31 100644 
--- a/tests/test_web_middleware.py +++ b/tests/test_web_middleware.py @@ -1,5 +1,6 @@ import asyncio -from typing import Awaitable, Callable, Iterable, NoReturn +from collections.abc import Awaitable, Callable, Iterable +from typing import NoReturn import pytest from yarl import URL diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 51d6e1b108b..2f59fc77bed 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -4,7 +4,7 @@ import ssl import weakref from collections.abc import MutableMapping -from typing import NoReturn, Optional, Tuple +from typing import NoReturn from unittest import mock import pytest @@ -1026,7 +1026,7 @@ def test_weakref_creation() -> None: ), ) def test_etag_headers( - header: str, header_attr: str, header_val: str, expected: Tuple[ETag, ...] + header: str, header_attr: str, header_val: str, expected: tuple[ETag, ...] ) -> None: req = make_mocked_request("GET", "/", headers={header: header_val}) assert getattr(req, header_attr) == expected @@ -1056,7 +1056,7 @@ def test_datetime_headers( header: str, header_attr: str, header_val: str, - expected: Optional[datetime.datetime], + expected: datetime.datetime | None, ) -> None: req = make_mocked_request("GET", "/", headers={header: header_val}) assert getattr(req, header_attr) == expected diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 8a47aa2f873..8e2864fc5fc 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -5,8 +5,8 @@ import json import re import weakref +from collections.abc import AsyncIterator from concurrent.futures import ThreadPoolExecutor -from typing import AsyncIterator, Optional, Union from unittest import mock import aiosignal @@ -29,8 +29,8 @@ def make_request( headers: LooseHeaders = CIMultiDict(), version: HttpVersion = HttpVersion11, *, - app: Optional[web.Application] = None, - writer: Optional[AbstractStreamWriter] = None, + app: web.Application | None = None, + writer: 
AbstractStreamWriter | None = None, ) -> web.Request: if app is None: app = mock.create_autospec( @@ -305,7 +305,7 @@ def test_etag_any() -> None: ETag(value="bad ©®"), ), ) -def test_etag_invalid_value_set(invalid_value: Union[str, ETag]) -> None: +def test_etag_invalid_value_set(invalid_value: str | ETag) -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="is not a valid etag"): resp.etag = invalid_value @@ -325,7 +325,7 @@ def test_etag_invalid_value_get(header: str) -> None: @pytest.mark.parametrize("invalid", (123, ETag(value=123, is_weak=True))) # type: ignore[arg-type] -def test_etag_invalid_value_class(invalid: Union[int, ETag]) -> None: +def test_etag_invalid_value_class(invalid: int | ETag) -> None: resp = web.StreamResponse() with pytest.raises(ValueError, match="Unsupported etag type"): resp.etag = invalid # type: ignore[assignment] @@ -1115,7 +1115,7 @@ def read(self, size: int = -1) -> bytes: ), ), ) -def test_payload_body_get_text(payload: object, expected: Optional[str]) -> None: +def test_payload_body_get_text(payload: object, expected: str | None) -> None: resp = web.Response(body=payload) if expected is None: with pytest.raises(TypeError): diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index ff796bebde5..5631deb526a 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -1,7 +1,8 @@ import asyncio import platform import signal -from typing import Any, Iterator, NoReturn, Protocol, Union +from collections.abc import Iterator +from typing import Any, NoReturn, Protocol from unittest import mock import pytest @@ -158,7 +159,7 @@ async def test_app_handler_args_failure() -> None: ), ) async def test_app_handler_args_ceil_threshold( - value: Union[int, str, None], expected: int + value: int | str | None, expected: int ) -> None: app = web.Application(handler_args={"timeout_ceil_threshold": value}) runner = web.AppRunner(app) diff --git a/tests/test_web_sendfile_functional.py 
b/tests/test_web_sendfile_functional.py index 591d7af7fce..5a44207b7c1 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -3,7 +3,8 @@ import gzip import pathlib import socket -from typing import Iterable, Iterator, NoReturn, Optional, Protocol, Tuple +from collections.abc import Iterable, Iterator +from typing import NoReturn, Protocol from unittest import mock import pytest @@ -320,7 +321,7 @@ async def test_static_file_with_encoding_and_enable_compression( sender: _Sender, accept_encoding: str, expect_encoding: str, - forced_compression: Optional[web.ContentCoding], + forced_compression: web.ContentCoding | None, ) -> None: """Test that enable_compression does not double compress when an encoded file is also present.""" @@ -491,7 +492,7 @@ async def test_static_file_if_match_custom_tags( aiohttp_client: AiohttpClient, app_with_static_route: web.Application, if_unmodified_since: str, - etags: Tuple[str], + etags: tuple[str], expected_status: int, ) -> None: client = await aiohttp_client(app_with_static_route) @@ -721,18 +722,18 @@ async def handler(request: web.Request) -> web.FileResponse: ) assert len(responses) == 3 assert responses[0].status == 206, "failed 'bytes=0-999': %s" % responses[0].reason - assert responses[0].headers["Content-Range"] == "bytes 0-999/{}".format( - filesize + assert ( + responses[0].headers["Content-Range"] == f"bytes 0-999/{filesize}" ), "failed: Content-Range Error" assert responses[1].status == 206, ( "failed 'bytes=1000-1999': %s" % responses[1].reason ) - assert responses[1].headers["Content-Range"] == "bytes 1000-1999/{}".format( - filesize + assert ( + responses[1].headers["Content-Range"] == f"bytes 1000-1999/{filesize}" ), "failed: Content-Range Error" assert responses[2].status == 206, "failed 'bytes=2000-': %s" % responses[2].reason - assert responses[2].headers["Content-Range"] == "bytes 2000-{}/{}".format( - filesize - 1, filesize + assert ( + 
responses[2].headers["Content-Range"] == f"bytes 2000-{filesize - 1}/{filesize}" ), "failed: Content-Range Error" body = await asyncio.gather( diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index a26e3d7ae9b..50b1ad2e9e2 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -4,8 +4,9 @@ import pathlib import socket import sys +from collections.abc import Generator from stat import S_IFIFO, S_IMODE -from typing import Any, Generator, NoReturn, Optional +from typing import Any, NoReturn import pytest import yarl @@ -57,7 +58,7 @@ async def test_access_root_of_static_handler( status: int, prefix: str, request_path: str, - data: Optional[bytes], + data: bytes | None, ) -> None: # Tests the operation of static file server. # Try to access the root of static file server, and make @@ -142,7 +143,7 @@ async def test_access_root_of_static_handler_xss( status: int, prefix: str, request_path: str, - data: Optional[bytes], + data: bytes | None, ) -> None: # Tests the operation of static file server. # Try to access the root of static file server, and make diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index d082f7d5d47..33380d94560 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Optional, Protocol +from typing import Protocol from unittest import mock import aiosignal @@ -21,7 +21,7 @@ def __call__( self, method: str, path: str, - headers: Optional[CIMultiDict[str]] = None, + headers: CIMultiDict[str] | None = None, protocols: bool = False, ) -> web.Request: ... 
@@ -48,7 +48,7 @@ def make_request( def maker( method: str, path: str, - headers: Optional[CIMultiDict[str]] = None, + headers: CIMultiDict[str] | None = None, protocols: bool = False, ) -> web.Request: if headers is None: @@ -658,7 +658,7 @@ async def test_no_transfer_encoding_header( async def test_get_extra_info( make_request: _RequestMaker, mocker: MockerFixture, - ws_transport: Optional[mock.MagicMock], + ws_transport: mock.MagicMock | None, expected_result: str, ) -> None: valid_key = "test" diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index faee34cf811..afa76e2d742 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -4,7 +4,7 @@ import contextlib import sys import weakref -from typing import NoReturn, Optional +from typing import NoReturn from unittest import mock import pytest @@ -381,7 +381,7 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async def test_close_op_code_from_client( loop: asyncio.AbstractEventLoop, aiohttp_client: AiohttpClient ) -> None: - srv_ws: Optional[web.WebSocketResponse] = None + srv_ws: web.WebSocketResponse | None = None async def handler(request: web.Request) -> web.WebSocketResponse: nonlocal srv_ws @@ -1271,7 +1271,7 @@ async def test_abnormal_closure_when_client_does_not_close( aiohttp_client: AiohttpClient, ) -> None: """Test abnormal closure when the server closes and the client doesn't respond.""" - close_code: Optional[WSCloseCode] = None + close_code: WSCloseCode | None = None async def handler(request: web.Request) -> web.WebSocketResponse: # Setting a short close timeout @@ -1298,7 +1298,7 @@ async def test_normal_closure_while_client_sends_msg( aiohttp_client: AiohttpClient, ) -> None: """Test normal closure when the server closes and the client responds properly.""" - close_code: Optional[WSCloseCode] = None + close_code: WSCloseCode | None = None got_close_code = asyncio.Event() async def 
handler(request: web.Request) -> web.WebSocketResponse: diff --git a/tests/test_websocket_handshake.py b/tests/test_websocket_handshake.py index e069795af73..d347407fac1 100644 --- a/tests/test_websocket_handshake.py +++ b/tests/test_websocket_handshake.py @@ -2,7 +2,6 @@ import base64 import os -from typing import List, Tuple import pytest @@ -16,7 +15,7 @@ def gen_ws_headers( extension_text: str = "", server_notakeover: bool = False, client_notakeover: bool = False, -) -> Tuple[List[Tuple[str, str]], str]: +) -> tuple[list[tuple[str, str]], str]: key = base64.b64encode(os.urandom(16)).decode() hdrs = [ ("Upgrade", "websocket"), diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 808cac3380a..ce181b43f73 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -2,7 +2,6 @@ import pickle import random import struct -from typing import Optional, Union from unittest import mock import pytest @@ -36,14 +35,14 @@ class PatchableWebSocketReader(WebSocketReader): def parse_frame( self, data: bytes - ) -> list[tuple[bool, int, Union[bytes, bytearray], int]]: + ) -> list[tuple[bool, int, bytes | bytearray, int]]: # This method is overridden to allow for patching in tests. - frames: list[tuple[bool, int, Union[bytes, bytearray], int]] = [] + frames: list[tuple[bool, int, bytes | bytearray, int]] = [] def _handle_frame( fin: bool, opcode: int, - payload: Union[bytes, bytearray], + payload: bytes | bytearray, compressed: int, ) -> None: # This method is overridden to allow for patching in tests. @@ -59,7 +58,7 @@ def build_frame( opcode: int, noheader: bool = False, is_fin: bool = True, - ZLibBackend: Optional[ZLibBackendWrapper] = None, + ZLibBackend: ZLibBackendWrapper | None = None, mask: bool = False, ) -> bytes: # Send a frame over the websocket with message as its payload. 
@@ -262,7 +261,7 @@ def test_parse_frame_header_payload_size( def test_ping_frame( out: WebSocketDataQueue, parser: PatchableWebSocketReader, - data: Union[bytes, bytearray, memoryview], + data: bytes | bytearray | memoryview, ) -> None: parser._handle_frame(True, WSMsgType.PING, b"data", 0) res = out._buffer[0] diff --git a/tests/test_websocket_writer.py b/tests/test_websocket_writer.py index 313290349a5..3fcd9f06eb4 100644 --- a/tests/test_websocket_writer.py +++ b/tests/test_websocket_writer.py @@ -1,6 +1,6 @@ import asyncio import random -from typing import Callable +from collections.abc import Callable from unittest import mock import pytest diff --git a/tests/test_worker.py b/tests/test_worker.py index afaf9814e44..0c0be51e53e 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -3,7 +3,7 @@ import os import socket import ssl -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING from unittest import mock import pytest @@ -29,9 +29,9 @@ class BaseTestWorker: def __init__(self) -> None: - self.servers: Dict[object, object] = {} + self.servers: dict[object, object] = {} self.exit_code = 0 - self._notify_waiter: Optional[asyncio.Future[bool]] = None + self._notify_waiter: asyncio.Future[bool] | None = None self.cfg = mock.Mock() self.cfg.graceful_timeout = 100 self.pid = "pid" diff --git a/tools/bench-asyncio-write.py b/tools/bench-asyncio-write.py index 3c35f295a58..9450da37fc0 100644 --- a/tools/bench-asyncio-write.py +++ b/tools/bench-asyncio-write.py @@ -3,7 +3,6 @@ import math import os import signal -from typing import List, Tuple PORT = 8888 @@ -52,7 +51,7 @@ def fm_time(s, _fms=("", "m", "µ", "n")): return f"{s:.2f}{_fms[i]}s" -def _job(j: List[int]) -> Tuple[str, List[bytes]]: +def _job(j: list[int]) -> tuple[str, list[bytes]]: # Always start with a 256B headers chunk body = [b"0" * s for s in [256] + list(j)] job_title = f"{fm_size(sum(j))} / {len(j)}"