From 1c0172666f26e11e99850198154336390fdacb09 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 26 May 2025 00:31:06 -0500 Subject: [PATCH] Support Reusable Request Bodies and Improve Payload Handling (#11017) --- CHANGES/11017.feature.rst | 3 + CHANGES/5530.feature.rst | 1 + CHANGES/5577.feature.rst | 1 + CHANGES/9201.feature.rst | 1 + CONTRIBUTORS.txt | 1 + aiohttp/client.py | 12 + aiohttp/client_middleware_digest_auth.py | 14 +- aiohttp/client_reqrep.py | 196 ++++-- aiohttp/formdata.py | 5 +- aiohttp/multipart.py | 71 ++ aiohttp/payload.py | 377 +++++++++-- aiohttp/web_response.py | 4 + docs/client_reference.rst | 91 +++ tests/test_client_functional.py | 546 +++++++++++++++- tests/test_client_middleware.py | 108 ++++ tests/test_client_middleware_digest_auth.py | 49 +- tests/test_client_request.py | 476 +++++++++++++- tests/test_client_session.py | 10 +- tests/test_formdata.py | 172 ++++- tests/test_multipart.py | 199 +++++- tests/test_payload.py | 675 +++++++++++++++++++- 21 files changed, 2860 insertions(+), 152 deletions(-) create mode 100644 CHANGES/11017.feature.rst create mode 120000 CHANGES/5530.feature.rst create mode 120000 CHANGES/5577.feature.rst create mode 120000 CHANGES/9201.feature.rst diff --git a/CHANGES/11017.feature.rst b/CHANGES/11017.feature.rst new file mode 100644 index 00000000000..361c56e3fe8 --- /dev/null +++ b/CHANGES/11017.feature.rst @@ -0,0 +1,3 @@ +Added support for reusable request bodies to enable retries, redirects, and digest authentication -- by :user:`bdraco` and :user:`GLGDLY`. + +Most payloads can now be safely reused multiple times, fixing long-standing issues where POST requests with form data or file uploads would fail on redirects with errors like "Form data has been processed already" or "I/O operation on closed file". This also enables digest authentication to work with request bodies and allows retry mechanisms to resend requests without consuming the payload. Note that payloads derived from async iterables may still not be reusable in some cases. diff --git a/CHANGES/5530.feature.rst b/CHANGES/5530.feature.rst new file mode 120000 index 00000000000..63bf4429e55 --- /dev/null +++ b/CHANGES/5530.feature.rst @@ -0,0 +1 @@ +11017.feature.rst \ No newline at end of file diff --git a/CHANGES/5577.feature.rst b/CHANGES/5577.feature.rst new file mode 120000 index 00000000000..63bf4429e55 --- /dev/null +++ b/CHANGES/5577.feature.rst @@ -0,0 +1 @@ +11017.feature.rst \ No newline at end of file diff --git a/CHANGES/9201.feature.rst b/CHANGES/9201.feature.rst new file mode 120000 index 00000000000..63bf4429e55 --- /dev/null +++ b/CHANGES/9201.feature.rst @@ -0,0 +1 @@ +11017.feature.rst \ No newline at end of file diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index ada385c74e3..42328be3848 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -143,6 +143,7 @@ Frederik Gladhorn Frederik Peter Aalund Gabriel Tremblay Gang Ji +Gary Leung Gary Wilson Jr. 
Gennady Andreyev Georges Dubus diff --git a/aiohttp/client.py b/aiohttp/client.py index 2b1ccb8ee03..20e7ce6cebb 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -734,6 +734,8 @@ async def _connect_and_send_request( redirects += 1 history.append(resp) if max_redirects and redirects >= max_redirects: + if req._body is not None: + await req._body.close() resp.close() raise TooManyRedirects( history[0].request_info, tuple(history) @@ -765,6 +767,9 @@ async def _connect_and_send_request( r_url, encoded=not self._requote_redirect_url ) except ValueError as e: + if req._body is not None: + await req._body.close() + resp.close() raise InvalidUrlRedirectClientError( r_url, "Server attempted redirecting to a location that does not look like a URL", @@ -772,6 +777,8 @@ async def _connect_and_send_request( scheme = parsed_redirect_url.scheme if scheme not in HTTP_AND_EMPTY_SCHEMA_SET: + if req._body is not None: + await req._body.close() resp.close() raise NonHttpUrlRedirectClientError(r_url) elif not scheme: @@ -786,6 +793,9 @@ async def _connect_and_send_request( try: redirect_origin = parsed_redirect_url.origin() except ValueError as origin_val_err: + if req._body is not None: + await req._body.close() + resp.close() raise InvalidUrlRedirectClientError( parsed_redirect_url, "Invalid redirect URL origin", @@ -805,6 +815,8 @@ async def _connect_and_send_request( break + if req._body is not None: + await req._body.close() # check response status if raise_for_status is None: raise_for_status = self._raise_for_status diff --git a/aiohttp/client_middleware_digest_auth.py b/aiohttp/client_middleware_digest_auth.py index b63efaf0142..9a8ffc18313 100644 --- a/aiohttp/client_middleware_digest_auth.py +++ b/aiohttp/client_middleware_digest_auth.py @@ -29,6 +29,7 @@ from .client_exceptions import ClientError from .client_middlewares import ClientHandlerType from .client_reqrep import ClientRequest, ClientResponse +from .payload import Payload class DigestAuthChallenge(TypedDict, total=False): @@ -192,7 +193,7 @@ def __init__( self._nonce_count = 0 self._challenge: DigestAuthChallenge = {} - def _encode(self, method: str, url: URL, body: Union[bytes, str]) -> str: + async def _encode(self, method: str, url: URL, body: Union[bytes, Payload]) -> str: """ Build digest authorization header for the current challenge. 
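For readers unfamiliar with RFC 7616 ``auth-int``, here is a minimal sketch of the entity-body hashing that the now-async ``_encode`` performs; ``h`` and ``auth_int_a2`` are illustrative names, not aiohttp API. The key point of the change above is that the body is now a reusable ``Payload`` whose bytes are obtained with ``await body.as_bytes()`` instead of assuming a ``str``/``bytes`` body:

.. code-block:: python

    import hashlib
    from typing import Union

    from aiohttp.payload import Payload

    def h(data: bytes) -> bytes:
        # H() from the digest algorithm (MD5 shown for brevity)
        return hashlib.md5(data).hexdigest().encode()

    async def auth_int_a2(method: str, path: str, body: Union[bytes, Payload]) -> bytes:
        # body is empty bytes when there is no payload, otherwise a Payload
        entity = body if isinstance(body, bytes) else await body.as_bytes()
        # auth-int appends the hash of the entity body to A2
        return b":".join((f"{method.upper()}:{path}".encode(), h(entity)))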
@@ -207,6 +208,7 @@ def _encode(self, method: str, url: URL, body: Union[bytes, str]) -> str: Raises: ClientError: If the challenge is missing required parameters or contains unsupported values + """ challenge = self._challenge if "realm" not in challenge: @@ -272,11 +274,11 @@ def KD(s: bytes, d: bytes) -> bytes: A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes)) A2 = f"{method.upper()}:{path}".encode() if qop == "auth-int": - if isinstance(body, str): - entity_str = body.encode("utf-8", errors="replace") + if isinstance(body, bytes): # will always be empty bytes unless Payload + entity_bytes = body else: - entity_str = body - entity_hash = H(entity_str) + entity_bytes = await body.as_bytes() # Get bytes from Payload + entity_hash = H(entity_bytes) A2 = b":".join((A2, entity_hash)) HA1 = H(A1) @@ -398,7 +400,7 @@ async def __call__( for retry_count in range(2): # Apply authorization header if we have a challenge (on second attempt) if retry_count > 0: - request.headers[hdrs.AUTHORIZATION] = self._encode( + request.headers[hdrs.AUTHORIZATION] = await self._encode( request.method, request.url, request.body ) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 384087cd8b3..59a11be3764 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -190,6 +190,25 @@ class ConnectionKey(NamedTuple): proxy_headers_hash: Optional[int] # hash(CIMultiDict) +def _warn_if_unclosed_payload(payload: payload.Payload, stacklevel: int = 2) -> None: + """Warn if the payload is not closed. + + Callers must check that the body is a Payload before calling this method. + + Args: + payload: The payload to check + stacklevel: Stack level for the warning (default 2 for direct callers) + """ + if not payload.autoclose and not payload.consumed: + warnings.warn( + "The previous request body contains unclosed resources. " + "Use await request.update_body() instead of setting request.body " + "directly to properly close resources and avoid leaks.", + ResourceWarning, + stacklevel=stacklevel, + ) + + class ClientRequest: GET_METHODS = { hdrs.METH_GET, @@ -206,7 +225,7 @@ class ClientRequest: } # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. - body: Any = b"" + _body: Union[None, payload.Payload] = None auth = None response = None @@ -373,6 +392,36 @@ def host(self) -> str: def port(self) -> Optional[int]: return self.url.port + @property + def body(self) -> Union[bytes, payload.Payload]: + """Request body.""" + # empty body is represented as bytes for backwards compatibility + return self._body or b"" + + @body.setter + def body(self, value: Any) -> None: + """Set request body with warning for non-autoclose payloads. + + WARNING: This setter must be called from within an event loop and is not + thread-safe. Setting body outside of an event loop may raise RuntimeError + when closing file-based payloads. + + DEPRECATED: Direct assignment to body is deprecated and will be removed + in a future version. Use await update_body() instead for proper resource + management. + """ + # Close existing payload if present + if self._body is not None: + # Warn if the payload needs manual closing + # stacklevel=3: user code -> body setter -> _warn_if_unclosed_payload + _warn_if_unclosed_payload(self._body, stacklevel=3) + # NOTE: In the future, when we remove sync close support, + # this setter will need to be removed and only the async + # update_body() method will be available. For now, we call + # _close() for backwards compatibility. 
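The ``autoclose``/``consumed`` pair that drives this warning is easiest to see on concrete payloads; a minimal sketch, using only the public ``aiohttp.payload`` classes this patch touches:

.. code-block:: python

    import io

    from aiohttp import payload

    # In-memory payloads can close themselves; no warning is needed.
    bp = payload.BytesPayload(b"data")
    assert bp.autoclose and not bp.consumed

    # File-backed payloads hold a handle, so autoclose is False and the
    # warning fires if they are replaced before being consumed or closed.
    fp = payload.BufferedReaderPayload(io.BufferedReader(io.BytesIO(b"data")))
    assert not fp.autoclose and not fp.consumed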
+        self._body._close()
+        self._update_body(value)
+
     @property
     def request_info(self) -> RequestInfo:
         headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers)
@@ -522,9 +571,12 @@ def update_transfer_encoding(self) -> None:
             )
             self.headers[hdrs.TRANSFER_ENCODING] = "chunked"
-        else:
-            if hdrs.CONTENT_LENGTH not in self.headers:
-                self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body))
+        elif (
+            self._body is not None
+            and hdrs.CONTENT_LENGTH not in self.headers
+            and (size := self._body.size) is not None
+        ):
+            self.headers[hdrs.CONTENT_LENGTH] = str(size)
 
     def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None:
         """Set basic auth."""
@@ -542,42 +594,125 @@ def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> Non
 
         self.headers[hdrs.AUTHORIZATION] = auth.encode()
 
-    def update_body_from_data(self, body: Any) -> None:
+    def update_body_from_data(self, body: Any, _stacklevel: int = 3) -> None:
+        """Update request body from data."""
+        if self._body is not None:
+            _warn_if_unclosed_payload(self._body, stacklevel=_stacklevel)
         if body is None:
+            self._body = None
             return
 
         # FormData
-        if isinstance(body, FormData):
-            body = body()
+        maybe_payload = body() if isinstance(body, FormData) else body
 
         try:
-            body = payload.PAYLOAD_REGISTRY.get(body, disposition=None)
+            body_payload = payload.PAYLOAD_REGISTRY.get(maybe_payload, disposition=None)
         except payload.LookupError:
-            boundary = None
+            boundary: Optional[str] = None
             if CONTENT_TYPE in self.headers:
                 boundary = parse_mimetype(self.headers[CONTENT_TYPE]).parameters.get(
                     "boundary"
                 )
-            body = FormData(body, boundary=boundary)()
-
-        self.body = body
+            body_payload = FormData(maybe_payload, boundary=boundary)()  # type: ignore[arg-type]
+        self._body = body_payload
 
         # enable chunked encoding if needed
         if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers:
-            if (size := body.size) is not None:
+            if (size := body_payload.size) is not None:
                 self.headers[hdrs.CONTENT_LENGTH] = str(size)
             else:
                 self.chunked = True
 
         # copy payload headers
-        assert body.headers
+        assert body_payload.headers
         headers = self.headers
         skip_headers = self._skip_auto_headers
-        for key, value in body.headers.items():
+        for key, value in body_payload.headers.items():
             if key in headers or (skip_headers is not None and key in skip_headers):
                 continue
             headers[key] = value
 
+    def _update_body(self, body: Any) -> None:
+        """Update request body after it has already been set."""
+        # Remove existing Content-Length header since body is changing
+        if hdrs.CONTENT_LENGTH in self.headers:
+            del self.headers[hdrs.CONTENT_LENGTH]
+
+        # Remove existing Transfer-Encoding header to avoid conflicts
+        if self.chunked and hdrs.TRANSFER_ENCODING in self.headers:
+            del self.headers[hdrs.TRANSFER_ENCODING]
+
+        # Now update the body using the existing method
+        # Called from _update_body, add 1 to stacklevel from caller
+        self.update_body_from_data(body, _stacklevel=4)
+
+        # Update transfer encoding headers if needed (same logic as __init__)
+        if body is not None or self.method not in self.GET_METHODS:
+            self.update_transfer_encoding()
+
+    async def update_body(self, body: Any) -> None:
+        """
+        Update request body and close previous payload if needed.
+
+        This method safely updates the request body by first closing any existing
+        payload to prevent resource leaks, then setting the new body.
+
+        IMPORTANT: Always use this method instead of setting request.body directly.
+ Direct assignment to request.body will leak resources if the previous body + contains file handles, streams, or other resources that need cleanup. + + Args: + body: The new body content. Can be: + - bytes/bytearray: Raw binary data + - str: Text data (will be encoded using charset from Content-Type) + - FormData: Form data that will be encoded as multipart/form-data + - Payload: A pre-configured payload object + - AsyncIterable: An async iterable of bytes chunks + - File-like object: Will be read and sent as binary data + - None: Clears the body + + Usage: + # CORRECT: Use update_body + await request.update_body(b"new request data") + + # WRONG: Don't set body directly + # request.body = b"new request data" # This will leak resources! + + # Update with form data + form_data = FormData() + form_data.add_field('field', 'value') + await request.update_body(form_data) + + # Clear body + await request.update_body(None) + + Note: + This method is async because it may need to close file handles or + other resources associated with the previous payload. Always await + this method to ensure proper cleanup. + + Warning: + Setting request.body directly is highly discouraged and can lead to: + - Resource leaks (unclosed file handles, streams) + - Memory leaks (unreleased buffers) + - Unexpected behavior with streaming payloads + + It is not recommended to change the payload type in middleware. If the + body was already set (e.g., as bytes), it's best to keep the same type + rather than converting it (e.g., to str) as this may result in unexpected + behavior. + + See Also: + - update_body_from_data: Synchronous body update without cleanup + - body property: Direct body access (STRONGLY DISCOURAGED) + + """ + # Close existing payload if it exists and needs closing + if self._body is not None: + await self._body.close() + self._update_body(body) + def update_expect_continue(self, expect: bool = False) -> None: if expect: self.headers[hdrs.EXPECT] = "100-continue" @@ -654,27 +789,14 @@ async def write_bytes( protocol = conn.protocol assert protocol is not None try: - if isinstance(self.body, payload.Payload): - # Specialized handling for Payload objects that know how to write themselves - await self.body.write_with_length(writer, content_length) - else: - # Handle bytes/bytearray by converting to an iterable for consistent handling - if isinstance(self.body, (bytes, bytearray)): - self.body = (self.body,) - - if content_length is None: - # Write the entire body without length constraint - for chunk in self.body: - await writer.write(chunk) - else: - # Write with length constraint, respecting content_length limit - # If the body is larger than content_length, we truncate it - remaining_bytes = content_length - for chunk in self.body: - await writer.write(chunk[:remaining_bytes]) - remaining_bytes -= len(chunk) - if remaining_bytes <= 0: - break + # This should be a rare case but the + # self._body can be set to None while + # the task is being started or we wait above + # for the 100-continue response. + # The more likely case is we have an empty + # payload, but 100-continue is still expected. 
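A minimal sketch of the ``write_with_length()`` contract that ``write_bytes`` now delegates to: the payload writes at most ``content_length`` bytes, mirroring how the client enforces Content-Length. ``SinkWriter`` is an illustrative stand-in for ``AbstractStreamWriter``, not aiohttp API:

.. code-block:: python

    import asyncio

    from aiohttp.payload import BytesPayload

    class SinkWriter:
        """Illustrative stand-in for AbstractStreamWriter (write() only)."""

        def __init__(self) -> None:
            self.buf = bytearray()

        async def write(self, chunk: bytes) -> None:
            self.buf += chunk

    async def main() -> None:
        body = BytesPayload(b"0123456789")
        writer = SinkWriter()
        # Write with a 4-byte Content-Length constraint; the payload
        # truncates its output rather than overrunning the header.
        await body.write_with_length(writer, 4)
        assert bytes(writer.buf) == b"0123"

    asyncio.run(main())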
+ if self._body is not None: + await self._body.write_with_length(writer, content_length) except OSError as underlying_exc: reraised_exc = underlying_exc @@ -770,7 +892,7 @@ async def send(self, conn: "Connection") -> "ClientResponse": await writer.write_headers(status_line, self.headers) task: Optional["asyncio.Task[None]"] - if self.body or self._continue is not None or protocol.writing_paused: + if self._body or self._continue is not None or protocol.writing_paused: coro = self.write_bytes(writer, conn, self._get_content_length()) if sys.version_info >= (3, 12): # Optimization for Python 3.12, try to write diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py index 994df482e00..1b74cf8da02 100644 --- a/aiohttp/formdata.py +++ b/aiohttp/formdata.py @@ -30,7 +30,6 @@ def __init__( self._writer = multipart.MultipartWriter("form-data", boundary=self._boundary) self._fields: List[Any] = [] self._is_multipart = default_to_multipart - self._is_processed = False self._quote_fields = quote_fields self._charset = charset @@ -121,8 +120,6 @@ def _gen_form_urlencoded(self) -> payload.BytesPayload: def _gen_form_data(self) -> multipart.MultipartWriter: """Encode a list of fields using the multipart/form-data MIME format""" - if self._is_processed: - raise RuntimeError("Form data has been processed already") for dispparams, headers, value in self._fields: try: if hdrs.CONTENT_TYPE in headers: @@ -153,7 +150,7 @@ def _gen_form_data(self) -> multipart.MultipartWriter: self._writer.append_payload(part) - self._is_processed = True + self._fields.clear() return self._writer def __call__(self) -> Payload: diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 5c437248ee4..fe5499e1144 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -36,6 +36,7 @@ ) from .helpers import CHAR, TOKEN, parse_mimetype, reify from .http import HeadersParser +from .log import internal_logger from .payload import ( JsonPayload, LookupError, @@ -560,6 +561,7 @@ def filename(self) -> Optional[str]: @payload_type(BodyPartReader, order=Order.try_first) class BodyPartReaderPayload(Payload): _value: BodyPartReader + # _autoclose = False (inherited) - Streaming reader that may have resources def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: super().__init__(value, *args, **kwargs) @@ -576,6 +578,16 @@ def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: raise TypeError("Unable to decode.") + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """Raises TypeError as body parts should be consumed via write(). + + This is intentional: BodyPartReader payloads are designed for streaming + large data (potentially gigabytes) and must be consumed only once via + the write() method to avoid memory exhaustion. They cannot be buffered + in memory for reuse. + """ + raise TypeError("Unable to read body part as bytes. 
Use write() to consume.") + async def write(self, writer: Any) -> None: field = self._value chunk = await field.read_chunk(size=2**16) @@ -793,6 +805,8 @@ class MultipartWriter(Payload): """Multipart body writer.""" _value: None + # _consumed = False (inherited) - Can be encoded multiple times + _autoclose = True # No file handles, just collects parts in memory def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None: boundary = boundary if boundary is not None else uuid.uuid4().hex @@ -979,6 +993,11 @@ def size(self) -> Optional[int]: return total def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + """Return string representation of the multipart data. + + WARNING: This method may do blocking I/O if parts contain file payloads. + It should not be called in the event loop. Use as_bytes().decode() instead. + """ return "".join( "--" + self.boundary @@ -988,6 +1007,33 @@ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: for part, _e, _te in self._parts ) + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """Return bytes representation of the multipart data. + + This method is async-safe and calls as_bytes on underlying payloads. + """ + parts: List[bytes] = [] + + # Process each part + for part, _e, _te in self._parts: + # Add boundary + parts.append(b"--" + self._boundary + b"\r\n") + + # Add headers + parts.append(part._binary_headers) + + # Add payload content using as_bytes for async safety + part_bytes = await part.as_bytes(encoding, errors) + parts.append(part_bytes) + + # Add trailing CRLF + parts.append(b"\r\n") + + # Add closing boundary + parts.append(b"--" + self._boundary + b"--\r\n") + + return b"".join(parts) + async def write(self, writer: Any, close_boundary: bool = True) -> None: """Write body.""" for part, encoding, te_encoding in self._parts: @@ -1015,6 +1061,31 @@ async def write(self, writer: Any, close_boundary: bool = True) -> None: if close_boundary: await writer.write(b"--" + self._boundary + b"--\r\n") + async def close(self) -> None: + """ + Close all part payloads that need explicit closing. + + IMPORTANT: This method must not await anything that might not finish + immediately, as it may be called during cleanup/cancellation. Schedule + any long-running operations without awaiting them. 
+ """ + if self._consumed: + return + self._consumed = True + + # Close all parts that need explicit closing + # We catch and log exceptions to ensure all parts get a chance to close + # we do not use asyncio.gather() here because we are not allowed + # to suspend given we may be called during cleanup + for idx, (part, _, _) in enumerate(self._parts): + if not part.autoclose and not part.consumed: + try: + await part.close() + except Exception as exc: + internal_logger.error( + "Failed to close multipart part %d: %s", idx, exc, exc_info=True + ) + class MultipartPayloadWriter: def __init__(self, writer: Any) -> None: diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 8b16d16aa87..dbe12ed639c 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -15,6 +15,7 @@ Dict, Final, Iterable, + List, Optional, Set, TextIO, @@ -58,12 +59,8 @@ _CLOSE_FUTURES: Set[asyncio.Future[None]] = set() -if TYPE_CHECKING: - from typing import List - - class LookupError(Exception): - pass + """Raised when no payload factory is found for the given data type.""" class Order(str, enum.Enum): @@ -154,6 +151,8 @@ def register( class Payload(ABC): _default_content_type: str = "application/octet-stream" _size: Optional[int] = None + _consumed: bool = False # Default: payload has not been consumed yet + _autoclose: bool = False # Default: assume resource needs explicit closing def __init__( self, @@ -189,7 +188,12 @@ def __init__( @property def size(self) -> Optional[int]: - """Size of the payload.""" + """Size of the payload in bytes. + + Returns the number of bytes that will be transmitted when the payload + is written. For string payloads, this is the size after encoding to bytes, + not the length of the string. + """ return self._size @property @@ -221,6 +225,21 @@ def content_type(self) -> str: """Content type""" return self._headers[hdrs.CONTENT_TYPE] + @property + def consumed(self) -> bool: + """Whether the payload has been consumed and cannot be reused.""" + return self._consumed + + @property + def autoclose(self) -> bool: + """ + Whether the payload can close itself automatically. + + Returns True if the payload has no file handles or resources that need + explicit closing. If False, callers must await close() to release resources. + """ + return self._autoclose + def set_content_disposition( self, disptype: str, @@ -235,14 +254,16 @@ def set_content_disposition( @abstractmethod def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: - """Return string representation of the value. + """ + Return string representation of the value. This is named decode() to allow compatibility with bytes objects. """ @abstractmethod async def write(self, writer: AbstractStreamWriter) -> None: - """Write payload to the writer stream. + """ + Write payload to the writer stream. Args: writer: An AbstractStreamWriter instance that handles the actual writing @@ -256,6 +277,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: All payload subclasses must override this method for backwards compatibility, but new code should use write_with_length for more flexibility and control. + """ # write_with_length is new in aiohttp 3.12 @@ -283,9 +305,52 @@ async def write_with_length( # and for the default implementation await self.write(writer) + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """ + Return bytes representation of the value. + + This is a convenience method that calls decode() and encodes the result + to bytes using the specified encoding. 
+ """ + # Use instance encoding if available, otherwise use parameter + actual_encoding = self._encoding or encoding + return self.decode(actual_encoding, errors).encode(actual_encoding) + + def _close(self) -> None: + """ + Async safe synchronous close operations for backwards compatibility. + + This method exists only for backwards compatibility with code that + needs to clean up payloads synchronously. In the future, we will + drop this method and only support the async close() method. + + WARNING: This method must be safe to call from within the event loop + without blocking. Subclasses should not perform any blocking I/O here. + + WARNING: This method must be called from within an event loop for + certain payload types (e.g., IOBasePayload). Calling it outside an + event loop may raise RuntimeError. + """ + # This is a no-op by default, but subclasses can override it + # for non-blocking cleanup operations. + + async def close(self) -> None: + """ + Close the payload if it holds any resources. + + IMPORTANT: This method must not await anything that might not finish + immediately, as it may be called during cleanup/cancellation. Schedule + any long-running operations without awaiting them. + + In the future, this will be the only close method supported. + """ + self._close() + class BytesPayload(Payload): _value: bytes + # _consumed = False (inherited) - Bytes are immutable and can be reused + _autoclose = True # No file handle, just bytes in memory def __init__( self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any @@ -314,8 +379,18 @@ def __init__( def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: return self._value.decode(encoding, errors) + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """ + Return bytes representation of the value. + + This method returns the raw bytes content of the payload. + It is equivalent to accessing the _value attribute directly. + """ + return self._value + async def write(self, writer: AbstractStreamWriter) -> None: - """Write the entire bytes payload to the writer stream. + """ + Write the entire bytes payload to the writer stream. Args: writer: An AbstractStreamWriter instance that handles the actual writing @@ -326,6 +401,7 @@ async def write(self, writer: AbstractStreamWriter) -> None: For new implementations that need length control, use write_with_length(). This method is maintained for backwards compatibility and is equivalent to write_with_length(writer, None). 
+ """ await writer.write(self._value) @@ -387,6 +463,9 @@ def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None: class IOBasePayload(Payload): _value: io.IOBase + # _consumed = False (inherited) - File can be re-read from the same position + _start_position: Optional[int] = None + # _autoclose = False (inherited) - Has file handle that needs explicit closing def __init__( self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any @@ -400,6 +479,16 @@ def __init__( if hdrs.CONTENT_DISPOSITION not in self.headers: self.set_content_disposition(disposition, filename=self._filename) + def _set_or_restore_start_position(self) -> None: + """Set or restore the start position of the file-like object.""" + if self._start_position is None: + try: + self._start_position = self._value.tell() + except OSError: + self._consumed = True # Cannot seek, mark as consumed + return + self._value.seek(self._start_position) + def _read_and_available_len( self, remaining_content_len: Optional[int] ) -> Tuple[Optional[int], bytes]: @@ -420,6 +509,7 @@ def _read_and_available_len( context switches and file operations when streaming content. """ + self._set_or_restore_start_position() size = self.size # Call size only once since it does I/O return size, self._value.read( min(size or READ_SIZE, remaining_content_len or READ_SIZE) @@ -445,6 +535,12 @@ def _read(self, remaining_content_len: Optional[int]) -> bytes: @property def size(self) -> Optional[int]: + """ + Size of the payload in bytes. + + Returns the number of bytes remaining to be read from the file. + Returns None if the size cannot be determined (e.g., for unseekable streams). + """ try: return os.fstat(self._value.fileno()).st_size - self._value.tell() except (AttributeError, OSError): @@ -495,38 +591,31 @@ async def write_with_length( total_written_len = 0 remaining_content_len = content_length - try: - # Get initial data and available length - available_len, chunk = await loop.run_in_executor( - None, self._read_and_available_len, remaining_content_len - ) - # Process data chunks until done - while chunk: - chunk_len = len(chunk) + # Get initial data and available length + available_len, chunk = await loop.run_in_executor( + None, self._read_and_available_len, remaining_content_len + ) + # Process data chunks until done + while chunk: + chunk_len = len(chunk) - # Write data with or without length constraint - if remaining_content_len is None: - await writer.write(chunk) - else: - await writer.write(chunk[:remaining_content_len]) - remaining_content_len -= chunk_len + # Write data with or without length constraint + if remaining_content_len is None: + await writer.write(chunk) + else: + await writer.write(chunk[:remaining_content_len]) + remaining_content_len -= chunk_len - total_written_len += chunk_len + total_written_len += chunk_len - # Check if we're done writing - if self._should_stop_writing( - available_len, total_written_len, remaining_content_len - ): - return + # Check if we're done writing + if self._should_stop_writing( + available_len, total_written_len, remaining_content_len + ): + return - # Read next chunk - chunk = await loop.run_in_executor( - None, self._read, remaining_content_len - ) - finally: - # Handle closing the file without awaiting to prevent cancellation issues - # when the StreamReader reaches EOF - self._schedule_file_close(loop) + # Read next chunk + chunk = await loop.run_in_executor(None, self._read, remaining_content_len) def _should_stop_writing( self, @@ -552,20 +641,67 @@ def 
_should_stop_writing( remaining_content_len is not None and remaining_content_len <= 0 ) - def _schedule_file_close(self, loop: asyncio.AbstractEventLoop) -> None: - """Schedule file closing without awaiting to prevent cancellation issues.""" + def _close(self) -> None: + """ + Async safe synchronous close operations for backwards compatibility. + + This method exists only for backwards + compatibility. Use the async close() method instead. + + WARNING: This method MUST be called from within an event loop. + Calling it outside an event loop will raise RuntimeError. + """ + # Skip if already consumed + if self._consumed: + return + self._consumed = True # Mark as consumed to prevent further writes + # Schedule file closing without awaiting to prevent cancellation issues + loop = asyncio.get_running_loop() close_future = loop.run_in_executor(None, self._value.close) # Hold a strong reference to the future to prevent it from being # garbage collected before it completes. _CLOSE_FUTURES.add(close_future) close_future.add_done_callback(_CLOSE_FUTURES.remove) + async def close(self) -> None: + """ + Close the payload if it holds any resources. + + IMPORTANT: This method must not await anything that might not finish + immediately, as it may be called during cleanup/cancellation. Schedule + any long-running operations without awaiting them. + """ + self._close() + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: - return "".join(r.decode(encoding, errors) for r in self._value.readlines()) + """ + Return string representation of the value. + + WARNING: This method does blocking I/O and should not be called in the event loop. + """ + return self._read_all().decode(encoding, errors) + + def _read_all(self) -> bytes: + """Read the entire file-like object and return its content as bytes.""" + self._set_or_restore_start_position() + # Use readlines() to ensure we get all content + return b"".join(self._value.readlines()) + + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """ + Return bytes representation of the value. + + This method reads the entire file content and returns it as bytes. + It is equivalent to reading the file-like object directly. + The file reading is performed in an executor to avoid blocking the event loop. + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self._read_all) class TextIOPayload(IOBasePayload): _value: io.TextIOBase + # _autoclose = False (inherited) - Has text file handle that needs explicit closing def __init__( self, @@ -618,6 +754,7 @@ def _read_and_available_len( to the stream. If no encoding is specified, UTF-8 is used as the default. """ + self._set_or_restore_start_position() size = self.size chunk = self._value.read( min(size or READ_SIZE, remaining_content_len or READ_SIZE) @@ -646,20 +783,56 @@ def _read(self, remaining_content_len: Optional[int]) -> bytes: return chunk.encode(self._encoding) if self._encoding else chunk.encode() def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + """ + Return string representation of the value. + + WARNING: This method does blocking I/O and should not be called in the event loop. + """ + self._set_or_restore_start_position() return self._value.read() + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """ + Return bytes representation of the value. + + This method reads the entire text file content and returns it as bytes. 
+ It encodes the text content using the specified encoding. + The file reading is performed in an executor to avoid blocking the event loop. + """ + loop = asyncio.get_running_loop() + + # Use instance encoding if available, otherwise use parameter + actual_encoding = self._encoding or encoding + + def _read_and_encode() -> bytes: + self._set_or_restore_start_position() + # TextIO read() always returns the full content + return self._value.read().encode(actual_encoding, errors) + + return await loop.run_in_executor(None, _read_and_encode) + class BytesIOPayload(IOBasePayload): _value: io.BytesIO + _size: int # Always initialized in __init__ + _autoclose = True # BytesIO is in-memory, safe to auto-close + + def __init__(self, value: io.BytesIO, *args: Any, **kwargs: Any) -> None: + super().__init__(value, *args, **kwargs) + # Calculate size once during initialization + self._size = len(self._value.getbuffer()) - self._value.tell() @property def size(self) -> int: - position = self._value.tell() - end = self._value.seek(0, os.SEEK_END) - self._value.seek(position) - return end - position + """Size of the payload in bytes. + + Returns the number of bytes in the BytesIO buffer that will be transmitted. + This is calculated once during initialization for efficiency. + """ + return self._size def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + self._set_or_restore_start_position() return self._value.read().decode(encoding, errors) async def write(self, writer: AbstractStreamWriter) -> None: @@ -687,32 +860,49 @@ async def write_with_length( responsiveness when processing large in-memory buffers. """ + self._set_or_restore_start_position() loop_count = 0 remaining_bytes = content_length - try: - while chunk := self._value.read(READ_SIZE): - if loop_count > 0: - # Avoid blocking the event loop - # if they pass a large BytesIO object - # and we are not in the first iteration - # of the loop - await asyncio.sleep(0) - if remaining_bytes is None: - await writer.write(chunk) - else: - await writer.write(chunk[:remaining_bytes]) - remaining_bytes -= len(chunk) - if remaining_bytes <= 0: - return - loop_count += 1 - finally: - self._value.close() + while chunk := self._value.read(READ_SIZE): + if loop_count > 0: + # Avoid blocking the event loop + # if they pass a large BytesIO object + # and we are not in the first iteration + # of the loop + await asyncio.sleep(0) + if remaining_bytes is None: + await writer.write(chunk) + else: + await writer.write(chunk[:remaining_bytes]) + remaining_bytes -= len(chunk) + if remaining_bytes <= 0: + return + loop_count += 1 + + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """ + Return bytes representation of the value. + + This method reads the entire BytesIO content and returns it as bytes. + It is equivalent to accessing the _value attribute directly. + """ + self._set_or_restore_start_position() + return self._value.read() + + async def close(self) -> None: + """ + Close the BytesIO payload. + + This does nothing since BytesIO is in-memory and does not require explicit closing. 
+ """ class BufferedReaderPayload(IOBasePayload): _value: io.BufferedIOBase + # _autoclose = False (inherited) - Has buffered file handle that needs explicit closing def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: + self._set_or_restore_start_position() return self._value.read().decode(encoding, errors) @@ -750,6 +940,9 @@ def __init__( class AsyncIterablePayload(Payload): _iter: Optional[_AsyncIterator] = None _value: _AsyncIterable + _cached_chunks: Optional[List[bytes]] = None + # _consumed stays False to allow reuse with cached content + _autoclose = True # Iterator doesn't need explicit closing def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: if not isinstance(value, AsyncIterable): @@ -795,17 +988,30 @@ async def write_with_length( This implementation handles streaming of async iterable content with length constraints: - 1. Iterates through the async iterable one chunk at a time - 2. Respects content_length constraints when specified - 3. Handles the case when the iterable might be used twice - - Since async iterables are consumed as they're iterated, there is no way to - restart the iteration if it's already in progress or completed. + 1. If cached chunks are available, writes from them + 2. Otherwise iterates through the async iterable one chunk at a time + 3. Respects content_length constraints when specified + 4. Does NOT generate cache - that's done by as_bytes() """ + # If we have cached chunks, use them + if self._cached_chunks is not None: + remaining_bytes = content_length + for chunk in self._cached_chunks: + if remaining_bytes is None: + await writer.write(chunk) + elif remaining_bytes > 0: + await writer.write(chunk[:remaining_bytes]) + remaining_bytes -= len(chunk) + else: + break + return + + # If iterator is exhausted and we don't have cached chunks, nothing to write if self._iter is None: return + # Stream from the iterator remaining_bytes = content_length try: @@ -827,9 +1033,40 @@ async def write_with_length( except StopAsyncIteration: # Iterator is exhausted self._iter = None + self._consumed = True # Mark as consumed when streamed without caching def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str: - raise TypeError("Unable to decode.") + """Decode the payload content as a string if cached chunks are available.""" + if self._cached_chunks is not None: + return b"".join(self._cached_chunks).decode(encoding, errors) + raise TypeError("Unable to decode - content not cached. Call as_bytes() first.") + + async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> bytes: + """ + Return bytes representation of the value. + + This method reads the entire async iterable content and returns it as bytes. + It generates and caches the chunks for future reuse. 
+ """ + # If we have cached chunks, return them joined + if self._cached_chunks is not None: + return b"".join(self._cached_chunks) + + # If iterator is exhausted and no cache, return empty + if self._iter is None: + return b"" + + # Read all chunks and cache them + chunks: List[bytes] = [] + async for chunk in self._iter: + chunks.append(chunk) + + # Iterator is exhausted, cache the chunks + self._iter = None + self._cached_chunks = chunks + # Keep _consumed as False to allow reuse with cached chunks + + return b"".join(chunks) class StreamReaderPayload(AsyncIterablePayload): @@ -847,5 +1084,5 @@ def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None: PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader) # try_last for giving a chance to more specialized async interables like -# multidict.BodyPartReaderPayload override the default +# multipart.BodyPartReaderPayload override the default PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last) diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 21fcde45968..dc4c16804d7 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -625,6 +625,9 @@ def body(self, body: Any) -> None: def text(self) -> Optional[str]: if self._body is None: return None + # Note: When _body is a Payload (e.g. FilePayload), this may do blocking I/O + # This is generally safe as most common payloads (BytesPayload, StringPayload) + # don't do blocking I/O, but be careful with file-based payloads return self._body.decode(self.charset or "utf-8") @text.setter @@ -676,6 +679,7 @@ async def write_eof(self, data: bytes = b"") -> None: await super().write_eof() elif isinstance(self._body, Payload): await self._body.write(self._payload_writer) + await self._body.close() await super().write_eof() else: await super().write_eof(cast(bytes, body)) diff --git a/docs/client_reference.rst b/docs/client_reference.rst index faae389f95c..b08df9c05ba 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1873,6 +1873,26 @@ ClientRequest - A :class:`Payload` object for raw data (default is empty bytes ``b""``) - A :class:`FormData` object for form submissions + .. danger:: + + **DO NOT set this attribute directly!** Direct assignment will cause resource + leaks. Always use :meth:`update_body` instead: + + .. code-block:: python + + # WRONG - This will leak resources! + request.body = b"new data" + + # CORRECT - Use update_body + await request.update_body(b"new data") + + Setting body directly bypasses cleanup of the previous payload, which can + leave file handles open, streams unclosed, and buffers unreleased. + + Additionally, setting body directly must be done from within an event loop + and is not thread-safe. Setting body outside of an event loop may raise + RuntimeError when closing file-based payloads. + .. attribute:: chunked :type: bool | None @@ -1974,6 +1994,77 @@ ClientRequest The HTTP version to use for the request (e.g., ``HttpVersion(1, 1)`` for HTTP/1.1). + .. method:: update_body(body) + + Update the request body and close any existing payload to prevent resource leaks. + + **This is the ONLY correct way to modify a request body.** Never set the + :attr:`body` attribute directly. + + This method is particularly useful in middleware when you need to modify the + request body after the request has been created but before it's sent. + + :param body: The new body content. 
Can be: + + - ``bytes``/``bytearray``: Raw binary data + - ``str``: Text data (encoded using charset from Content-Type) + - :class:`FormData`: Form data encoded as multipart/form-data + - :class:`Payload`: A pre-configured payload object + - ``AsyncIterable[bytes]``: Async iterable of bytes chunks + - File-like object: Will be read and sent as binary data + - ``None``: Clears the body + + .. code-block:: python + + async def middleware(request, handler): + # Modify request body in middleware + if request.method == 'POST': + # CORRECT: Always use update_body + await request.update_body(b'{"modified": true}') + + # WRONG: Never set body directly! + # request.body = b'{"modified": true}' # This leaks resources! + + # Or add authentication data to form + if isinstance(request.body, FormData): + form = FormData() + # Copy existing fields and add auth token + form.add_field('auth_token', 'secret123') + await request.update_body(form) + + return await handler(request) + + .. note:: + + This method is async because it may need to close file handles or + other resources associated with the previous payload. Always await + this method to ensure proper cleanup. + + .. danger:: + + **Never set :attr:`ClientRequest.body` directly!** Direct assignment will cause resource + leaks. Always use this method instead. Setting the body attribute directly: + + - Bypasses cleanup of the previous payload + - Leaves file handles and streams open + - Can cause memory leaks + - May result in unexpected behavior with async iterables + + .. warning:: + + When updating the body, ensure that the Content-Type header is + appropriate for the new body content. The Content-Length header + will be updated automatically. When using :class:`FormData` or + :class:`Payload` objects, headers are updated automatically, + but you may need to set Content-Type manually for raw bytes or text. + + It is not recommended to change the payload type in middleware. If the + body was already set (e.g., as bytes), it's best to keep the same type + rather than converting it (e.g., to str) as this may result in unexpected + behavior. + + .. 
versionadded:: 3.12 + Utilities diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 44009094431..9433ad2f2bb 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -33,7 +33,7 @@ from yarl import URL import aiohttp -from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web +from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, payload, web from aiohttp.abc import AbstractResolver, ResolveResult from aiohttp.client_exceptions import ( ClientResponseError, @@ -48,6 +48,14 @@ from aiohttp.client_reqrep import ClientRequest from aiohttp.connector import Connection from aiohttp.http_writer import StreamWriter +from aiohttp.payload import ( + AsyncIterablePayload, + BufferedReaderPayload, + BytesIOPayload, + BytesPayload, + StringIOPayload, + StringPayload, +) from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.test_utils import TestClient, TestServer, unused_port from aiohttp.typedefs import Handler, Query @@ -619,6 +627,61 @@ async def handler(request: web.Request) -> web.Response: assert resp.status == 200 +async def test_post_bytes_data_content_length_from_body( + aiohttp_client: AiohttpClient, +) -> None: + """Test that Content-Length is set from body payload size when sending bytes.""" + data = b"test payload data" + + async def handler(request: web.Request) -> web.Response: + # Verify Content-Length header was set correctly + assert request.content_length == len(data) + assert request.headers.get("Content-Length") == str(len(data)) + + # Verify we can read the data + val = await request.read() + assert data == val + return web.Response() + + app = web.Application() + app.router.add_route("POST", "/", handler) + client = await aiohttp_client(app) + + # Send bytes data - this should trigger the code path where + # Content-Length is set from body.size in update_transfer_encoding + async with client.post("/", data=data) as resp: + assert resp.status == 200 + + +async def test_post_custom_payload_without_content_length( + aiohttp_client: AiohttpClient, +) -> None: + """Test that Content-Length is set from payload.size when not explicitly provided.""" + data = b"custom payload data" + + async def handler(request: web.Request) -> web.Response: + # Verify Content-Length header was set from payload size + assert request.content_length == len(data) + assert request.headers.get("Content-Length") == str(len(data)) + + # Verify we can read the data + val = await request.read() + assert data == val + return web.Response() + + app = web.Application() + app.router.add_route("POST", "/", handler) + client = await aiohttp_client(app) + + # Create a BytesPayload directly - this ensures we test the path + # where update_transfer_encoding sets Content-Length from body.size + bytes_payload = payload.BytesPayload(data) + + # Don't set Content-Length header explicitly + async with client.post("/", data=bytes_payload) as resp: + assert resp.status == 200 + + async def test_ssl_client( aiohttp_server: AiohttpServer, ssl_ctx: ssl.SSLContext, @@ -2107,6 +2170,51 @@ async def expect_handler(request: web.Request) -> None: assert expect_called +async def test_expect100_with_no_body(aiohttp_client: AiohttpClient) -> None: + """Test expect100 with GET request that has no body.""" + + async def handler(request: web.Request) -> web.Response: + return web.Response(text="OK") + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + # GET request with expect100=True but 
no body + async with client.get("/", expect100=True) as resp: + assert resp.status == 200 + assert await resp.text() == "OK" + + +async def test_expect100_continue_with_none_payload( + aiohttp_client: AiohttpClient, +) -> None: + """Test expect100 continue handling when payload is None from the start.""" + expect_received = False + + async def handler(request: web.Request) -> web.Response: + return web.Response(body=b"OK") + + async def expect_handler(request: web.Request) -> None: + nonlocal expect_received + expect_received = True + # Send 100 Continue + assert request.transport is not None + request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") + + app = web.Application() + app.router.add_post("/", handler, expect_handler=expect_handler) + client = await aiohttp_client(app) + + # POST request with expect100=True but no body (data=None) + async with client.post("/", expect100=True, data=None) as resp: + assert resp.status == 200 + assert await resp.read() == b"OK" + + # Expect handler should still be called even with no body + assert expect_received + + @pytest.mark.usefixtures("parametrize_zlib_backend") async def test_encoding_deflate( aiohttp_client: AiohttpClient, @@ -4573,3 +4681,439 @@ async def handler(request: web.Request) -> web.Response: data = await resp.read() assert data == b"" resp.close() + + +async def test_bytes_payload_redirect(aiohttp_client: AiohttpClient) -> None: + """Test that BytesPayload can be reused across redirects.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("final", data)) + return web.Response(text=f"Received: {data.decode()}") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + payload_data = b"test payload data" + payload = BytesPayload(payload_data) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + assert text == "Received: test payload data" + # Both endpoints should have received the data + assert data_received == [("redirect", payload_data), ("final", payload_data)] + + +async def test_string_payload_redirect(aiohttp_client: AiohttpClient) -> None: + """Test that StringPayload can be reused across redirects.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.text() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.text() + data_received.append(("final", data)) + return web.Response(text=f"Received: {data}") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + payload_data = "test string payload" + payload = StringPayload(payload_data) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + assert text == "Received: test string payload" + # Both endpoints should have received the data + assert data_received 
== [("redirect", payload_data), ("final", payload_data)] + + +async def test_async_iterable_payload_redirect(aiohttp_client: AiohttpClient) -> None: + """Test that AsyncIterablePayload cannot be reused across redirects.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("final", data)) + return web.Response(text=f"Received: {data.decode()}") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + chunks = [b"chunk1", b"chunk2", b"chunk3"] + + async def async_gen() -> AsyncIterator[bytes]: + for chunk in chunks: + yield chunk + + payload = AsyncIterablePayload(async_gen()) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + # AsyncIterablePayload is consumed after first use, so redirect gets empty body + assert text == "Received: " + + # Only the first endpoint should have received data + expected_data = b"".join(chunks) + assert len(data_received) == 2 + assert data_received[0] == ("redirect", expected_data) + assert data_received[1] == ("final", b"") # Empty after being consumed + + +async def test_buffered_reader_payload_redirect(aiohttp_client: AiohttpClient) -> None: + """Test that BufferedReaderPayload can be reused across redirects.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("final", data)) + return web.Response(text=f"Received: {data.decode()}") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + payload_data = b"buffered reader payload" + buffer = io.BufferedReader(io.BytesIO(payload_data)) # type: ignore[arg-type] + payload = BufferedReaderPayload(buffer) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + assert text == "Received: buffered reader payload" + # Both endpoints should have received the data + assert data_received == [("redirect", payload_data), ("final", payload_data)] + + +async def test_string_io_payload_redirect(aiohttp_client: AiohttpClient) -> None: + """Test that StringIOPayload can be reused across redirects.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.text() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.text() + data_received.append(("final", data)) + return web.Response(text=f"Received: {data}") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + 
payload_data = "string io payload" + string_io = io.StringIO(payload_data) + payload = StringIOPayload(string_io) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + assert text == "Received: string io payload" + # Both endpoints should have received the data + assert data_received == [("redirect", payload_data), ("final", payload_data)] + + +async def test_bytes_io_payload_redirect(aiohttp_client: AiohttpClient) -> None: + """Test that BytesIOPayload can be reused across redirects.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("final", data)) + return web.Response(text=f"Received: {data.decode()}") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + payload_data = b"bytes io payload" + bytes_io = io.BytesIO(payload_data) + payload = BytesIOPayload(bytes_io) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + assert text == "Received: bytes io payload" + # Both endpoints should have received the data + assert data_received == [("redirect", payload_data), ("final", payload_data)] + + +async def test_multiple_redirects_with_bytes_payload( + aiohttp_client: AiohttpClient, +) -> None: + """Test BytesPayload with multiple redirects.""" + data_received = [] + + async def redirect1_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect1", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/redirect2") + + async def redirect2_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect2", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("final", data)) + return web.Response(text=f"Received after 2 redirects: {data.decode()}") + + app = web.Application() + app.router.add_post("/redirect", redirect1_handler) + app.router.add_post("/redirect2", redirect2_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + payload_data = b"multi-redirect-test" + payload = BytesPayload(payload_data) + + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + text = await resp.text() + assert text == f"Received after 2 redirects: {payload_data.decode()}" + # All 3 endpoints should have received the same data + assert data_received == [ + ("redirect1", payload_data), + ("redirect2", payload_data), + ("final", payload_data), + ] + + +async def test_redirect_with_empty_payload(aiohttp_client: AiohttpClient) -> None: + """Test redirects with empty payloads.""" + data_received = [] + + async def redirect_handler(request: web.Request) -> web.Response: + data = await request.read() + data_received.append(("redirect", data)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> 
web.Response: + data = await request.read() + data_received.append(("final", data)) + return web.Response(text="Done") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + # Test with empty BytesPayload + payload = BytesPayload(b"") + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + assert data_received == [("redirect", b""), ("final", b"")] + + +async def test_redirect_preserves_content_type(aiohttp_client: AiohttpClient) -> None: + """Test that content-type is preserved across redirects.""" + content_types = [] + + async def redirect_handler(request: web.Request) -> web.Response: + content_types.append(("redirect", request.content_type)) + # Use 307 to preserve POST method + raise web.HTTPTemporaryRedirect("/final_destination") + + async def final_handler(request: web.Request) -> web.Response: + content_types.append(("final", request.content_type)) + return web.Response(text="Done") + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + app.router.add_post("/final_destination", final_handler) + + client = await aiohttp_client(app) + + # StringPayload should set content-type with charset + payload = StringPayload("test data") + resp = await client.post("/redirect", data=payload) + assert resp.status == 200 + # Both requests should have the same content type + assert len(content_types) == 2 + assert content_types[0][1] == "text/plain" + assert content_types[1][1] == "text/plain" + + +class MockedBytesPayload(BytesPayload): + """A BytesPayload that tracks whether close() was called.""" + + def __init__(self, data: bytes) -> None: + super().__init__(data) + self.close_called = False + + async def close(self) -> None: + self.close_called = True + await super().close() + + +async def test_too_many_redirects_closes_payload(aiohttp_client: AiohttpClient) -> None: + """Test that TooManyRedirects exception closes the request payload.""" + + async def redirect_handler(request: web.Request) -> web.Response: + # Read the payload to simulate server processing + await request.read() + count = int(request.match_info.get("count", 0)) + # Use 307 to preserve POST method + return web.Response( + status=307, headers={hdrs.LOCATION: f"/redirect/{count + 1}"} + ) + + app = web.Application() + app.router.add_post(r"/redirect/{count:\d+}", redirect_handler) + + client = await aiohttp_client(app) + + # Create a mocked payload to verify close() is called + payload = MockedBytesPayload(b"test payload") + + with pytest.raises(TooManyRedirects): + await client.post("/redirect/0", data=payload, max_redirects=2) + + assert ( + payload.close_called + ), "Payload.close() was not called when TooManyRedirects was raised" + + +async def test_invalid_url_redirect_closes_payload( + aiohttp_client: AiohttpClient, +) -> None: + """Test that InvalidUrlRedirectClientError exception closes the request payload.""" + + async def redirect_handler(request: web.Request) -> web.Response: + # Read the payload to simulate server processing + await request.read() + # Return an invalid URL that will cause ValueError in URL parsing + # Using a URL with invalid port that's out of range + return web.Response( + status=307, headers={hdrs.LOCATION: "http://example.com:999999/path"} + ) + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + + client = await aiohttp_client(app) + + # Create a mocked payload to verify close() is 
called + payload = MockedBytesPayload(b"test payload") + + with pytest.raises( + InvalidUrlRedirectClientError, + match="Server attempted redirecting to a location that does not look like a URL", + ): + await client.post("/redirect", data=payload) + + assert ( + payload.close_called + ), "Payload.close() was not called when InvalidUrlRedirectClientError was raised" + + +async def test_non_http_redirect_closes_payload(aiohttp_client: AiohttpClient) -> None: + """Test that NonHttpUrlRedirectClientError exception closes the request payload.""" + + async def redirect_handler(request: web.Request) -> web.Response: + # Read the payload to simulate server processing + await request.read() + # Return a non-HTTP scheme URL + return web.Response( + status=307, headers={hdrs.LOCATION: "ftp://example.com/file"} + ) + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + + client = await aiohttp_client(app) + + # Create a mocked payload to verify close() is called + payload = MockedBytesPayload(b"test payload") + + with pytest.raises(NonHttpUrlRedirectClientError): + await client.post("/redirect", data=payload) + + assert ( + payload.close_called + ), "Payload.close() was not called when NonHttpUrlRedirectClientError was raised" + + +async def test_invalid_redirect_origin_closes_payload( + aiohttp_client: AiohttpClient, +) -> None: + """Test that InvalidUrlRedirectClientError exception (invalid origin) closes the request payload.""" + + async def redirect_handler(request: web.Request) -> web.Response: + # Read the payload to simulate server processing + await request.read() + # Return a URL that will fail origin() check - using a relative URL without host + return web.Response(status=307, headers={hdrs.LOCATION: "http:///path"}) + + app = web.Application() + app.router.add_post("/redirect", redirect_handler) + + client = await aiohttp_client(app) + + # Create a mocked payload to verify close() is called + payload = MockedBytesPayload(b"test payload") + + with pytest.raises( + InvalidUrlRedirectClientError, match="Invalid redirect URL origin" + ): + await client.post("/redirect", data=payload) + + assert ( + payload.close_called + ), "Payload.close() was not called when InvalidUrlRedirectClientError (invalid origin) was raised" diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py index e698e8ee825..217877759c0 100644 --- a/tests/test_client_middleware.py +++ b/tests/test_client_middleware.py @@ -1161,3 +1161,111 @@ async def __call__( assert received_bodies[1] == json_str2 assert received_bodies[2] == "" # GET request has no body assert received_bodies[3] == text_data + + +async def test_client_middleware_update_shorter_body( + aiohttp_server: AiohttpServer, +) -> None: + """Test that middleware can update request body using update_body method.""" + + async def handler(request: web.Request) -> web.Response: + body = await request.text() + return web.Response(text=body) + + app = web.Application() + app.router.add_post("/", handler) + server = await aiohttp_server(app) + + async def update_body_middleware( + request: ClientRequest, handler: ClientHandlerType + ) -> ClientResponse: + # Update the request body + await request.update_body(b"short body") + return await handler(request) + + async with ClientSession(middlewares=(update_body_middleware,)) as session: + async with session.post(server.make_url("/"), data=b"original body") as resp: + assert resp.status == 200 + text = await resp.text() + assert text == "short body" + + +async def 
test_client_middleware_update_longer_body(
+    aiohttp_server: AiohttpServer,
+) -> None:
+    """Test that middleware can replace the body with a longer one via update_body."""
+
+    async def handler(request: web.Request) -> web.Response:
+        body = await request.text()
+        return web.Response(text=body)
+
+    app = web.Application()
+    app.router.add_post("/", handler)
+    server = await aiohttp_server(app)
+
+    async def update_body_middleware(
+        request: ClientRequest, handler: ClientHandlerType
+    ) -> ClientResponse:
+        # Update the request body
+        await request.update_body(b"much much longer body")
+        return await handler(request)
+
+    async with ClientSession(middlewares=(update_body_middleware,)) as session:
+        async with session.post(server.make_url("/"), data=b"original body") as resp:
+            assert resp.status == 200
+            text = await resp.text()
+            assert text == "much much longer body"
+
+
+async def test_client_middleware_update_string_body(
+    aiohttp_server: AiohttpServer,
+) -> None:
+    """Test that middleware can replace a string body via update_body."""
+
+    async def handler(request: web.Request) -> web.Response:
+        body = await request.text()
+        return web.Response(text=body)
+
+    app = web.Application()
+    app.router.add_post("/", handler)
+    server = await aiohttp_server(app)
+
+    async def update_body_middleware(
+        request: ClientRequest, handler: ClientHandlerType
+    ) -> ClientResponse:
+        # Update the request body
+        await request.update_body("this is a string")
+        return await handler(request)
+
+    async with ClientSession(middlewares=(update_body_middleware,)) as session:
+        async with session.post(server.make_url("/"), data="original string") as resp:
+            assert resp.status == 200
+            text = await resp.text()
+            assert text == "this is a string"
+
+
+async def test_client_middleware_switch_types(
+    aiohttp_server: AiohttpServer,
+) -> None:
+    """Test that middleware can switch the body from bytes to str via update_body."""
+
+    async def handler(request: web.Request) -> web.Response:
+        body = await request.text()
+        return web.Response(text=body)
+
+    app = web.Application()
+    app.router.add_post("/", handler)
+    server = await aiohttp_server(app)
+
+    async def update_body_middleware(
+        request: ClientRequest, handler: ClientHandlerType
+    ) -> ClientResponse:
+        # Update the request body
+        await request.update_body("now a string")
+        return await handler(request)
+
+    async with ClientSession(middlewares=(update_body_middleware,)) as session:
+        async with session.post(server.make_url("/"), data=b"original bytes") as resp:
+            assert resp.status == 200
+            text = await resp.text()
+            assert text == "now a string"
diff --git a/tests/test_client_middleware_digest_auth.py b/tests/test_client_middleware_digest_auth.py
index 26118288913..6da6850bafc 100644
--- a/tests/test_client_middleware_digest_auth.py
+++ b/tests/test_client_middleware_digest_auth.py
@@ -1,5 +1,6 @@
 """Test digest authentication middleware for aiohttp client."""
 
+import io
 from hashlib import md5, sha1
 from typing import Generator, Union
 from unittest import mock
@@ -18,6 +19,7 @@
     unescape_quotes,
 )
 from aiohttp.client_reqrep import ClientResponse
+from aiohttp.payload import BytesIOPayload
 from aiohttp.pytest_plugin import AiohttpServer
 from aiohttp.web import Application, Request, Response
 
@@ -154,7 +156,7 @@ async def test_authenticate_scenarios(
         ),
     ],
 )
-def test_encode_validation_errors(
+async def test_encode_validation_errors(
    digest_auth_mw: DigestAuthMiddleware,
    challenge: DigestAuthChallenge,
    expected_error: str,
@@ -162,12 +164,14 @@ def 
test_encode_validation_errors( """Test validation errors when encoding digest auth headers.""" digest_auth_mw._challenge = challenge with pytest.raises(ClientError, match=expected_error): - digest_auth_mw._encode("GET", URL("http://example.com/resource"), "") + await digest_auth_mw._encode("GET", URL("http://example.com/resource"), b"") -def test_encode_digest_with_md5(auth_mw_with_challenge: DigestAuthMiddleware) -> None: - header = auth_mw_with_challenge._encode( - "GET", URL("http://example.com/resource"), "" +async def test_encode_digest_with_md5( + auth_mw_with_challenge: DigestAuthMiddleware, +) -> None: + header = await auth_mw_with_challenge._encode( + "GET", URL("http://example.com/resource"), b"" ) assert header.startswith("Digest ") assert 'username="user"' in header @@ -177,7 +181,7 @@ def test_encode_digest_with_md5(auth_mw_with_challenge: DigestAuthMiddleware) -> @pytest.mark.parametrize( "algorithm", ["MD5-SESS", "SHA-SESS", "SHA-256-SESS", "SHA-512-SESS"] ) -def test_encode_digest_with_sess_algorithms( +async def test_encode_digest_with_sess_algorithms( digest_auth_mw: DigestAuthMiddleware, qop_challenge: DigestAuthChallenge, algorithm: str, @@ -188,11 +192,13 @@ def test_encode_digest_with_sess_algorithms( challenge["algorithm"] = algorithm digest_auth_mw._challenge = challenge - header = digest_auth_mw._encode("GET", URL("http://example.com/resource"), "") + header = await digest_auth_mw._encode( + "GET", URL("http://example.com/resource"), b"" + ) assert f"algorithm={algorithm}" in header -def test_encode_unsupported_algorithm( +async def test_encode_unsupported_algorithm( digest_auth_mw: DigestAuthMiddleware, basic_challenge: DigestAuthChallenge ) -> None: """Test that unsupported algorithm raises ClientError.""" @@ -202,10 +208,10 @@ def test_encode_unsupported_algorithm( digest_auth_mw._challenge = challenge with pytest.raises(ClientError, match="Unsupported hash algorithm"): - digest_auth_mw._encode("GET", URL("http://example.com/resource"), "") + await digest_auth_mw._encode("GET", URL("http://example.com/resource"), b"") -def test_invalid_qop_rejected( +async def test_invalid_qop_rejected( digest_auth_mw: DigestAuthMiddleware, basic_challenge: DigestAuthChallenge ) -> None: """Test that invalid Quality of Protection values are rejected.""" @@ -217,7 +223,7 @@ def test_invalid_qop_rejected( # This should raise an error about unsupported QoP with pytest.raises(ClientError, match="Unsupported Quality of Protection"): - digest_auth_mw._encode("GET", URL("http://example.com"), "") + await digest_auth_mw._encode("GET", URL("http://example.com"), b"") def compute_expected_digest( @@ -264,14 +270,17 @@ def KD(secret: str, data: str) -> str: @pytest.mark.parametrize( ("body", "body_str"), [ - ("this is a body", "this is a body"), # String case (b"this is a body", "this is a body"), # Bytes case + ( + BytesIOPayload(io.BytesIO(b"this is a body")), + "this is a body", + ), # BytesIOPayload case ], ) -def test_digest_response_exact_match( +async def test_digest_response_exact_match( qop: str, algorithm: str, - body: Union[str, bytes], + body: Union[bytes, BytesIOPayload], body_str: str, mock_sha1_digest: mock.MagicMock, ) -> None: @@ -295,7 +304,7 @@ def test_digest_response_exact_match( auth._last_nonce_bytes = nonce.encode("utf-8") auth._nonce_count = nc - header = auth._encode(method, URL(f"http://host{uri}"), body) + header = await auth._encode(method, URL(f"http://host{uri}"), body) # Get expected digest expected = compute_expected_digest( @@ -402,7 +411,7 @@ def 
test_middleware_invalid_login() -> None: DigestAuthMiddleware("user:name", "pass") -def test_escaping_quotes_in_auth_header() -> None: +async def test_escaping_quotes_in_auth_header() -> None: """Test that double quotes are properly escaped in auth header.""" auth = DigestAuthMiddleware('user"with"quotes', "pass") auth._challenge = DigestAuthChallenge( @@ -413,7 +422,7 @@ def test_escaping_quotes_in_auth_header() -> None: opaque='opaque"with"quotes', ) - header = auth._encode("GET", URL("http://example.com/path"), "") + header = await auth._encode("GET", URL("http://example.com/path"), b"") # Check that quotes are escaped in the header assert 'username="user\\"with\\"quotes"' in header @@ -422,13 +431,15 @@ def test_escaping_quotes_in_auth_header() -> None: assert 'opaque="opaque\\"with\\"quotes"' in header -def test_template_based_header_construction( +async def test_template_based_header_construction( auth_mw_with_challenge: DigestAuthMiddleware, mock_sha1_digest: mock.MagicMock, mock_md5_digest: mock.MagicMock, ) -> None: """Test that the template-based header construction works correctly.""" - header = auth_mw_with_challenge._encode("GET", URL("http://example.com/test"), "") + header = await auth_mw_with_challenge._encode( + "GET", URL("http://example.com/test"), b"" + ) # Split the header into scheme and parameters scheme, params_str = header.split(" ", 1) diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 6b094171012..361163c87a0 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -3,6 +3,7 @@ import io import pathlib import sys +import warnings from http.cookies import BaseCookie, Morsel, SimpleCookie from typing import ( Any, @@ -756,7 +757,7 @@ async def test_formdata_boundary_from_headers( ) async with await req.send(conn): await asyncio.sleep(0) - assert req.body._boundary == boundary.encode() + assert req.body._boundary == boundary.encode() # type: ignore[union-attr] async def test_post_data(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: @@ -766,7 +767,7 @@ async def test_post_data(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> No ) resp = await req.send(conn) assert "/" == req.url.path - assert b"life=42" == req.body._value + assert b"life=42" == req.body._value # type: ignore[union-attr] assert "application/x-www-form-urlencoded" == req.headers["CONTENT-TYPE"] await req.close() resp.close() @@ -805,7 +806,7 @@ async def test_get_with_data(loop: asyncio.AbstractEventLoop) -> None: meth, URL("http://python.org/"), data={"life": "42"}, loop=loop ) assert "/" == req.url.path - assert b"life=42" == req.body._value + assert b"life=42" == req.body._value # type: ignore[union-attr] await req.close() @@ -942,6 +943,7 @@ async def test_chunked_explicit( req = ClientRequest("post", URL("http://python.org/"), chunked=True, loop=loop) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: m_writer.return_value.write_headers = mock.AsyncMock() + m_writer.return_value.write_eof = mock.AsyncMock() resp = await req.send(conn) assert "chunked" == req.headers["TRANSFER-ENCODING"] @@ -1002,6 +1004,64 @@ async def test_precompressed_data_stays_intact( await req.close() +async def test_body_with_size_sets_content_length( + loop: asyncio.AbstractEventLoop, +) -> None: + """Test that when body has a size and no Content-Length header is set, it gets added.""" + # Create a BytesPayload which has a size property + data = b"test data" + + # Create request with data that will create a BytesPayload + req = ClientRequest( 
+ "post", + URL("http://python.org/"), + data=data, + loop=loop, + ) + + # Verify Content-Length was set from body.size + assert req.headers["CONTENT-LENGTH"] == str(len(data)) + assert req.body is not None + assert req._body is not None # When _body is set, body returns it + assert req._body.size == len(data) + await req.close() + + +async def test_body_payload_with_size_no_content_length( + loop: asyncio.AbstractEventLoop, +) -> None: + """Test that when a body payload with size is set directly, Content-Length is added.""" + # Create a payload with a known size + data = b"payload data" + bytes_payload = payload.BytesPayload(data) + + # Create request with no data initially + req = ClientRequest( + "post", + URL("http://python.org/"), + loop=loop, + ) + + # Set body directly (bypassing update_body_from_data to avoid it setting Content-Length) + req._body = bytes_payload + + # Ensure conditions for the code path we want to test + assert req._body is not None + assert hdrs.CONTENT_LENGTH not in req.headers + assert req._body.size is not None + assert not req.chunked + + # Now trigger update_transfer_encoding which should set Content-Length + req.update_transfer_encoding() + + # Verify Content-Length was set from body.size + assert req.headers["CONTENT-LENGTH"] == str(len(data)) + assert req.body is bytes_payload + assert req._body is bytes_payload # Access _body which is the Payload + assert req._body.size == len(data) + await req.close() + + async def test_file_upload_not_chunked_seek(loop: asyncio.AbstractEventLoop) -> None: file_path = pathlib.Path(__file__).parent / "aiohttp.png" with file_path.open("rb") as f: @@ -1260,6 +1320,7 @@ async def test_oserror_on_write_bytes( loop: asyncio.AbstractEventLoop, conn: mock.Mock ) -> None: req = ClientRequest("POST", URL("http://python.org/"), loop=loop) + req.body = b"test data" writer = WriterMock() writer.write.side_effect = OSError @@ -1634,7 +1695,17 @@ async def test_write_bytes_with_iterable_content_length_limit( """Test that write_bytes respects content_length limit for iterable data.""" # Test with iterable data req = ClientRequest("post", URL("http://python.org/"), loop=loop) - req.body = data + + # Convert list to async generator if needed + if isinstance(data, list): + + async def gen() -> AsyncIterator[bytes]: + for chunk in data: + yield chunk + + req.body = gen() # type: ignore[assignment] # https://github.com/python/mypy/issues/12892 + else: + req.body = data writer = StreamWriter(protocol=conn.protocol, loop=loop) # Use content_length=7 to truncate at the middle of Part2 @@ -1649,7 +1720,13 @@ async def test_write_bytes_empty_iterable_with_content_length( ) -> None: """Test that write_bytes handles empty iterable body with content_length.""" req = ClientRequest("post", URL("http://python.org/"), loop=loop) - req.body = [] # Empty iterable + + # Create an empty async generator + async def gen() -> AsyncIterator[bytes]: + return + yield # pragma: no cover # This makes it a generator but never executes + + req.body = gen() # type: ignore[assignment] # https://github.com/python/mypy/issues/12892 writer = StreamWriter(protocol=conn.protocol, loop=loop) # Use content_length=10 with empty body @@ -1658,3 +1735,392 @@ async def test_write_bytes_empty_iterable_with_content_length( # Verify nothing was written assert len(buf) == 0 await req.close() + + +async def test_warn_if_unclosed_payload_via_body_setter( + make_request: _RequestMaker, +) -> None: + """Test that _warn_if_unclosed_payload is called when setting body with unclosed 
payload.""" + req = make_request("POST", "http://python.org/") + + # First set a payload that needs manual closing (autoclose=False) + file_payload = payload.BufferedReaderPayload( + io.BufferedReader(io.BytesIO(b"test data")), # type: ignore[arg-type] + encoding="utf-8", + ) + req.body = file_payload + + # Setting body again should trigger the warning for the previous payload + with pytest.warns( + ResourceWarning, + match="The previous request body contains unclosed resources", + ): + req.body = b"new data" + + await req.close() + + +async def test_no_warn_for_autoclose_payload_via_body_setter( + make_request: _RequestMaker, +) -> None: + """Test that no warning is issued for payloads with autoclose=True.""" + req = make_request("POST", "http://python.org/") + + # First set BytesIOPayload which has autoclose=True + bytes_payload = payload.BytesIOPayload(io.BytesIO(b"test data")) + req.body = bytes_payload + + # Setting body again should not trigger warning since previous payload has autoclose=True + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + req.body = b"new data" + + # Filter out any non-ResourceWarning warnings + resource_warnings = [ + w for w in warning_list if issubclass(w.category, ResourceWarning) + ] + assert len(resource_warnings) == 0 + + await req.close() + + +async def test_no_warn_for_consumed_payload_via_body_setter( + make_request: _RequestMaker, +) -> None: + """Test that no warning is issued for already consumed payloads.""" + req = make_request("POST", "http://python.org/") + + # Create a payload that needs manual closing + file_payload = payload.BufferedReaderPayload( + io.BufferedReader(io.BytesIO(b"test data")), # type: ignore[arg-type] + encoding="utf-8", + ) + req.body = file_payload + + # Properly close the payload to mark it as consumed + await file_payload.close() + + # Setting body again should not trigger warning since previous payload is consumed + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + req.body = b"new data" + + # Filter out any non-ResourceWarning warnings + resource_warnings = [ + w for w in warning_list if issubclass(w.category, ResourceWarning) + ] + assert len(resource_warnings) == 0 + + await req.close() + + +async def test_warn_if_unclosed_payload_via_update_body_from_data( + make_request: _RequestMaker, +) -> None: + """Test that _warn_if_unclosed_payload is called via update_body_from_data.""" + req = make_request("POST", "http://python.org/") + + # First set a payload that needs manual closing + file_payload = payload.BufferedReaderPayload( + io.BufferedReader(io.BytesIO(b"initial data")), # type: ignore[arg-type] + encoding="utf-8", + ) + req.update_body_from_data(file_payload) + + # Create FormData for second update + form = aiohttp.FormData() + form.add_field("test", "value") + + # update_body_from_data should trigger the warning for the previous payload + with pytest.warns( + ResourceWarning, + match="The previous request body contains unclosed resources", + ): + req.update_body_from_data(form) + + await req.close() + + +async def test_warn_via_update_with_file_payload( + make_request: _RequestMaker, +) -> None: + """Test warning via update_body_from_data with file-like object.""" + req = make_request("POST", "http://python.org/") + + # First create a file-like object that results in BufferedReaderPayload + buffered1 = io.BufferedReader(io.BytesIO(b"file content 1")) # type: ignore[arg-type] + req.update_body_from_data(buffered1) + + # 
Second update should warn about the first payload + buffered2 = io.BufferedReader(io.BytesIO(b"file content 2")) # type: ignore[arg-type] + + with pytest.warns( + ResourceWarning, + match="The previous request body contains unclosed resources", + ): + req.update_body_from_data(buffered2) + + await req.close() + + +async def test_no_warn_for_simple_data_via_update_body_from_data( + make_request: _RequestMaker, +) -> None: + """Test that no warning is issued for simple data types.""" + req = make_request("POST", "http://python.org/") + + # Simple bytes data should not trigger warning + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + req.update_body_from_data(b"simple data") + + # Filter out any non-ResourceWarning warnings + resource_warnings = [ + w for w in warning_list if issubclass(w.category, ResourceWarning) + ] + assert len(resource_warnings) == 0 + + await req.close() + + +async def test_update_body_closes_previous_payload( + make_request: _RequestMaker, +) -> None: + """Test that update_body properly closes the previous payload.""" + req = make_request("POST", "http://python.org/") + + # Create a mock payload that tracks if it was closed + mock_payload = mock.Mock(spec=payload.Payload) + mock_payload.close = mock.AsyncMock() + + # Set initial payload + req._body = mock_payload + + # Update body with new data + await req.update_body(b"new body data") + + # Verify the previous payload was closed + mock_payload.close.assert_called_once() + + # Verify new body is set (it's a BytesPayload now) + assert isinstance(req.body, payload.BytesPayload) + + await req.close() + + +async def test_body_setter_closes_previous_payload( + make_request: _RequestMaker, +) -> None: + """Test that body setter properly closes the previous payload.""" + req = make_request("POST", "http://python.org/") + + # Create a mock payload that tracks if it was closed + # We need to use create_autospec to ensure all methods are available + mock_payload = mock.create_autospec(payload.Payload, instance=True) + + # Set initial payload + req._body = mock_payload + + # Update body with new data using setter + req.body = b"new body data" + + # Verify the previous payload was closed using _close + mock_payload._close.assert_called_once() + + # Verify new body is set (it's a BytesPayload now) + assert isinstance(req.body, payload.BytesPayload) + + await req.close() + + +async def test_update_body_with_different_types( + make_request: _RequestMaker, +) -> None: + """Test update_body with various data types.""" + req = make_request("POST", "http://python.org/") + + # Test with bytes + await req.update_body(b"bytes data") + assert isinstance(req.body, payload.BytesPayload) + + # Test with string + await req.update_body("string data") + assert isinstance(req.body, payload.BytesPayload) + + # Test with None (clears body) + await req.update_body(None) + assert req.body == b"" # type: ignore[comparison-overlap] # empty body is represented as b"" + + await req.close() + + +async def test_update_body_with_chunked_encoding( + make_request: _RequestMaker, +) -> None: + """Test that update_body properly handles chunked transfer encoding.""" + # Create request with chunked=True + req = make_request("POST", "http://python.org/", chunked=True) + + # Verify Transfer-Encoding header is set + assert req.headers["Transfer-Encoding"] == "chunked" + assert "Content-Length" not in req.headers + + # Update body - should maintain chunked encoding + await req.update_body(b"chunked data") + assert 
req.headers["Transfer-Encoding"] == "chunked" + assert "Content-Length" not in req.headers + assert isinstance(req.body, payload.BytesPayload) + + # Update with different body - chunked should remain + await req.update_body(b"different chunked data") + assert req.headers["Transfer-Encoding"] == "chunked" + assert "Content-Length" not in req.headers + + # Clear body - chunked header should remain + await req.update_body(None) + assert req.headers["Transfer-Encoding"] == "chunked" + assert "Content-Length" not in req.headers + + await req.close() + + +async def test_update_body_get_method_with_none_body( + make_request: _RequestMaker, +) -> None: + """Test that update_body with GET method and None body doesn't call update_transfer_encoding.""" + # Create GET request + req = make_request("GET", "http://python.org/") + + # GET requests shouldn't have Transfer-Encoding or Content-Length initially + assert "Transfer-Encoding" not in req.headers + assert "Content-Length" not in req.headers + + # Update body to None - should not trigger update_transfer_encoding + # This covers the branch where body is None AND method is in GET_METHODS + await req.update_body(None) + + # Headers should remain unchanged + assert "Transfer-Encoding" not in req.headers + assert "Content-Length" not in req.headers + + await req.close() + + +async def test_update_body_updates_content_length( + make_request: _RequestMaker, +) -> None: + """Test that update_body properly updates Content-Length header when body size changes.""" + req = make_request("POST", "http://python.org/") + + # Set initial body with known size + await req.update_body(b"initial data") + initial_content_length = req.headers.get("Content-Length") + assert initial_content_length == "12" # len(b"initial data") = 12 + + # Update body with different size + await req.update_body(b"much longer data than before") + new_content_length = req.headers.get("Content-Length") + assert new_content_length == "28" # len(b"much longer data than before") = 28 + + # Update body with shorter data + await req.update_body(b"short") + assert req.headers.get("Content-Length") == "5" # len(b"short") = 5 + + # Clear body + await req.update_body(None) + # For None body, Content-Length should not be set + assert "Content-Length" not in req.headers + + await req.close() + + +async def test_warn_stacklevel_points_to_user_code( + make_request: _RequestMaker, +) -> None: + """Test that the warning stacklevel correctly points to user code.""" + req = make_request("POST", "http://python.org/") + + # First set a payload that needs manual closing (autoclose=False) + file_payload = payload.BufferedReaderPayload( + io.BufferedReader(io.BytesIO(b"test data")), # type: ignore[arg-type] + encoding="utf-8", + ) + req.body = file_payload + + # Capture warnings with their details + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always", ResourceWarning) + # This line should be reported as the warning source + req.body = b"new data" # LINE TO BE REPORTED + + # Find the ResourceWarning + resource_warnings = [ + w for w in warning_list if issubclass(w.category, ResourceWarning) + ] + assert len(resource_warnings) == 1 + + warning = resource_warnings[0] + # The warning should point to the line where we set req.body, not inside the library + # Call chain: user code -> body setter -> _warn_if_unclosed_payload + # stacklevel=3 is used in body setter to skip the setter and _warn_if_unclosed_payload + assert warning.filename == __file__ + # The line number should be the 
line with "req.body = b'new data'" + # We can't hardcode the line number, but we can verify it's not pointing + # to client_reqrep.py (the library code) + assert "client_reqrep.py" not in warning.filename + + await req.close() + + +async def test_warn_stacklevel_update_body_from_data( + make_request: _RequestMaker, +) -> None: + """Test that warning stacklevel is correct when called from update_body_from_data.""" + req = make_request("POST", "http://python.org/") + + # First set a payload that needs manual closing (autoclose=False) + file_payload = payload.BufferedReaderPayload( + io.BufferedReader(io.BytesIO(b"test data")), # type: ignore[arg-type] + encoding="utf-8", + ) + req.update_body_from_data(file_payload) + + # Capture warnings with their details + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always", ResourceWarning) + # This line should be reported as the warning source + req.update_body_from_data(b"new data") # LINE TO BE REPORTED + + # Find the ResourceWarning + resource_warnings = [ + w for w in warning_list if issubclass(w.category, ResourceWarning) + ] + assert len(resource_warnings) == 1 + + warning = resource_warnings[0] + # For update_body_from_data, stacklevel=3 points to this test file + # Call chain: user code -> update_body_from_data -> _warn_if_unclosed_payload + assert warning.filename == __file__ + assert "client_reqrep.py" not in warning.filename + + await req.close() + + +async def test_expect100_with_body_becomes_none() -> None: + """Test that write_bytes handles body becoming None after expect100 handling.""" + # Create a mock writer and connection + mock_writer = mock.AsyncMock() + mock_conn = mock.Mock() + + # Create a request + req = ClientRequest( + "POST", URL("http://test.example.com/"), loop=asyncio.get_event_loop() + ) + req._body = mock.Mock() # Start with a body + + # Now set body to None to simulate a race condition + # where req._body is set to None after expect100 handling + req._body = None + + await req.write_bytes(mock_writer, mock_conn, None) + await req.close() diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 1fc05b04a4e..4f22b6a3851 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -461,7 +461,9 @@ async def test_reraise_os_error( err = OSError(1, "permission error") req = mock.Mock() req_factory = mock.Mock(return_value=req) - req.send = mock.Mock(side_effect=err) + req.send = mock.AsyncMock(side_effect=err) + req._body = mock.Mock() + req._body.close = mock.AsyncMock() session = await create_session(request_class=req_factory) async def create_connection( @@ -491,7 +493,9 @@ class UnexpectedException(BaseException): err = UnexpectedException("permission error") req = mock.Mock() req_factory = mock.Mock(return_value=req) - req.send = mock.Mock(side_effect=err) + req.send = mock.AsyncMock(side_effect=err) + req._body = mock.Mock() + req._body.close = mock.AsyncMock() session = await create_session(request_class=req_factory) connections = [] @@ -549,6 +553,7 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc] resp.start = mock.AsyncMock() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) req.send = mock.AsyncMock(return_value=resp) # BaseConnector allows all high level protocols by default @@ -611,6 +616,7 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] resp.start = mock.AsyncMock() 
req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) req.send = mock.AsyncMock(return_value=resp) # UnixConnector allows all high level protocols by default and unix sockets diff --git a/tests/test_formdata.py b/tests/test_formdata.py index 73977f4497a..bda2b754d70 100644 --- a/tests/test_formdata.py +++ b/tests/test_formdata.py @@ -111,7 +111,7 @@ async def test_formdata_field_name_is_not_quoted( assert b'name="email 1"' in buf -async def test_mark_formdata_as_processed(aiohttp_client: AiohttpClient) -> None: +async def test_formdata_is_reusable(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.Response() @@ -123,16 +123,176 @@ async def handler(request: web.Request) -> web.Response: data = FormData() data.add_field("test", "test_value", content_type="application/json") - resp = await client.post("/", data=data) - assert len(data._writer._parts) == 1 + # First request + resp1 = await client.post("/", data=data) + assert resp1.status == 200 + resp1.release() - with pytest.raises(RuntimeError): - await client.post("/", data=data) + # Second request - should work without RuntimeError + resp2 = await client.post("/", data=data) + assert resp2.status == 200 + resp2.release() - resp.release() + # Third request to ensure continued reusability + resp3 = await client.post("/", data=data) + assert resp3.status == 200 + resp3.release() async def test_formdata_boundary_param() -> None: boundary = "some_boundary" form = FormData(boundary=boundary) assert form._writer.boundary == boundary + + +async def test_formdata_reusability_multipart( + writer: StreamWriter, buf: bytearray +) -> None: + form = FormData() + form.add_field("name", "value") + form.add_field("file", b"content", filename="test.txt", content_type="text/plain") + + # First call - should generate multipart payload + payload1 = form() + assert form.is_multipart + buf.clear() + await payload1.write(writer) + result1 = bytes(buf) + + # Verify first result contains expected content + assert b"name" in result1 + assert b"value" in result1 + assert b"test.txt" in result1 + assert b"content" in result1 + assert b"text/plain" in result1 + + # Second call - should generate identical multipart payload + payload2 = form() + buf.clear() + await payload2.write(writer) + result2 = bytes(buf) + + # Results should be identical (same boundary and content) + assert result1 == result2 + + # Third call to ensure continued reusability + payload3 = form() + buf.clear() + await payload3.write(writer) + result3 = bytes(buf) + + assert result1 == result3 + + +async def test_formdata_reusability_urlencoded( + writer: StreamWriter, buf: bytearray +) -> None: + form = FormData() + form.add_field("key1", "value1") + form.add_field("key2", "value2") + + # First call - should generate urlencoded payload + payload1 = form() + assert not form.is_multipart + buf.clear() + await payload1.write(writer) + result1 = bytes(buf) + + # Verify first result contains expected content + assert b"key1=value1" in result1 + assert b"key2=value2" in result1 + + # Second call - should generate identical urlencoded payload + payload2 = form() + buf.clear() + await payload2.write(writer) + result2 = bytes(buf) + + # Results should be identical + assert result1 == result2 + + # Third call to ensure continued reusability + payload3 = form() + buf.clear() + await payload3.write(writer) + result3 = bytes(buf) + + assert 
result1 == result3 + + +async def test_formdata_reusability_after_adding_fields( + writer: StreamWriter, buf: bytearray +) -> None: + form = FormData() + form.add_field("field1", "value1") + + # First call + payload1 = form() + buf.clear() + await payload1.write(writer) + result1 = bytes(buf) + + # Add more fields after first call + form.add_field("field2", "value2") + + # Second call should include new field + payload2 = form() + buf.clear() + await payload2.write(writer) + result2 = bytes(buf) + + # Results should be different + assert result1 != result2 + assert b"field1=value1" in result2 + assert b"field2=value2" in result2 + assert b"field2=value2" not in result1 + + # Third call should be same as second + payload3 = form() + buf.clear() + await payload3.write(writer) + result3 = bytes(buf) + + assert result2 == result3 + + +async def test_formdata_reusability_with_io_fields( + writer: StreamWriter, buf: bytearray +) -> None: + form = FormData() + + # Create BytesIO and StringIO objects + bytes_io = io.BytesIO(b"bytes content") + string_io = io.StringIO("string content") + + form.add_field( + "bytes_field", + bytes_io, + filename="bytes.bin", + content_type="application/octet-stream", + ) + form.add_field( + "string_field", string_io, filename="text.txt", content_type="text/plain" + ) + + # First call + payload1 = form() + buf.clear() + await payload1.write(writer) + result1 = bytes(buf) + + assert b"bytes content" in result1 + assert b"string content" in result1 + + # Reset IO objects for reuse + bytes_io.seek(0) + string_io.seek(0) + + # Second call - should work with reset IO objects + payload2 = form() + buf.clear() + await payload2.write(writer) + result2 = bytes(buf) + + # Should produce identical results + assert result1 == result2 diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 6d9707f4a5a..f0efa7284bc 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -20,7 +20,12 @@ CONTENT_TYPE, ) from aiohttp.helpers import parse_mimetype -from aiohttp.multipart import BodyPartReader, MultipartReader, MultipartResponseWrapper +from aiohttp.multipart import ( + BodyPartReader, + BodyPartReaderPayload, + MultipartReader, + MultipartResponseWrapper, +) from aiohttp.streams import StreamReader if sys.version_info >= (3, 11): @@ -49,6 +54,22 @@ async def write(chunk: bytes) -> None: return writer +@pytest.fixture +def buf2() -> bytearray: + return bytearray() + + +@pytest.fixture +def stream2(buf2: bytearray) -> mock.Mock: + writer = mock.Mock() + + async def write(chunk: bytes) -> None: + buf2.extend(chunk) + + writer.write.side_effect = write + return writer + + @pytest.fixture def writer() -> aiohttp.MultipartWriter: return aiohttp.MultipartWriter(boundary=":") @@ -1501,3 +1522,179 @@ async def test_async_for_bodypart() -> None: part = aiohttp.BodyPartReader(boundary=b"--:", headers=h, content=stream) async for data in part: assert data == b"foobarbaz" + + +async def test_multipart_writer_reusability( + buf: bytearray, + stream: mock.Mock, + buf2: bytearray, + stream2: mock.Mock, + writer: aiohttp.MultipartWriter, +) -> None: + """Test that MultipartWriter can be written multiple times.""" + # Add some parts + writer.append("text content") + writer.append(b"binary content", {"Content-Type": "application/octet-stream"}) + writer.append_json({"key": "value"}) + + # Test as_bytes multiple times + bytes1 = await writer.as_bytes() + bytes2 = await writer.as_bytes() + bytes3 = await writer.as_bytes() + + # All as_bytes calls should return identical data + 
assert bytes1 == bytes2 == bytes3 + + # Verify content is there + assert b"text content" in bytes1 + assert b"binary content" in bytes1 + assert b'"key": "value"' in bytes1 + + # First write + buf.clear() + await writer.write(stream) + result1 = bytes(buf) + + # Second write - should produce identical output + buf2.clear() + await writer.write(stream2) + result2 = bytes(buf2) + + # Results should be identical + assert result1 == result2 + + # Third write to ensure continued reusability + buf.clear() + await writer.write(stream) + result3 = bytes(buf) + + assert result1 == result3 + + # as_bytes should still work after writes + bytes4 = await writer.as_bytes() + assert bytes1 == bytes4 + + +async def test_multipart_writer_reusability_with_io_payloads( + buf: bytearray, + stream: mock.Mock, + buf2: bytearray, + stream2: mock.Mock, + writer: aiohttp.MultipartWriter, +) -> None: + """Test that MultipartWriter with IO payloads can be reused.""" + # Create IO objects + bytes_io = io.BytesIO(b"bytes io content") + string_io = io.StringIO("string io content") + + # Add IO payloads + writer.append(bytes_io, {"Content-Type": "application/octet-stream"}) + writer.append(string_io, {"Content-Type": "text/plain"}) + + # Test as_bytes multiple times + bytes1 = await writer.as_bytes() + bytes2 = await writer.as_bytes() + + # All as_bytes calls should return identical data + assert bytes1 == bytes2 + assert b"bytes io content" in bytes1 + assert b"string io content" in bytes1 + + # First write + buf.clear() + await writer.write(stream) + result1 = bytes(buf) + + assert b"bytes io content" in result1 + assert b"string io content" in result1 + + # Reset IO objects for reuse + bytes_io.seek(0) + string_io.seek(0) + + # Second write + buf2.clear() + await writer.write(stream2) + result2 = bytes(buf2) + + # Should produce identical results + assert result1 == result2 + + # Test as_bytes after writes (IO objects should auto-reset) + bytes3 = await writer.as_bytes() + assert bytes1 == bytes3 + + +async def test_body_part_reader_payload_as_bytes() -> None: + """Test that BodyPartReaderPayload.as_bytes raises TypeError.""" + # Create a mock BodyPartReader + headers = CIMultiDictProxy(CIMultiDict({CONTENT_TYPE: "text/plain"})) + protocol = mock.Mock(_reading_paused=False) + stream = StreamReader(protocol, 2**16, loop=asyncio.get_event_loop()) + body_part = BodyPartReader(BOUNDARY, headers, stream) + + # Create the payload + payload = BodyPartReaderPayload(body_part) + + # Test that as_bytes raises TypeError + with pytest.raises(TypeError, match="Unable to read body part as bytes"): + await payload.as_bytes() + + # Test that decode also raises TypeError + with pytest.raises(TypeError, match="Unable to decode"): + payload.decode() + + +async def test_multipart_writer_close_with_exceptions() -> None: + """Test that MultipartWriter.close() continues closing all parts even if one raises.""" + writer = aiohttp.MultipartWriter() + + # Create mock payloads + # First part will raise during close + part1 = mock.Mock() + part1.autoclose = False + part1.consumed = False + part1.close = mock.AsyncMock(side_effect=RuntimeError("Part 1 close failed")) + + # Second part should still get closed + part2 = mock.Mock() + part2.autoclose = False + part2.consumed = False + part2.close = mock.AsyncMock() + + # Third part with autoclose=True should not be closed + part3 = mock.Mock() + part3.autoclose = True + part3.consumed = False + part3.close = mock.AsyncMock() + + # Fourth part already consumed should not be closed + part4 = 
mock.Mock() + part4.autoclose = False + part4.consumed = True + part4.close = mock.AsyncMock() + + # Add parts to writer's internal list + writer._parts = [ + (part1, "", ""), + (part2, "", ""), + (part3, "", ""), + (part4, "", ""), + ] + + # Close the writer - should not raise despite part1 failing + await writer.close() + + # Verify close was called on appropriate parts + part1.close.assert_called_once() + part2.close.assert_called_once() # Should still be called despite part1 failing + part3.close.assert_not_called() # autoclose=True + part4.close.assert_not_called() # consumed=True + + # Verify writer is marked as consumed + assert writer._consumed is True + + # Calling close again should do nothing + await writer.close() + assert part1.close.call_count == 1 + assert part2.close.call_count == 1 diff --git a/tests/test_payload.py b/tests/test_payload.py index 24dcbaeb819..2d80dc0c65d 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -1,8 +1,11 @@ import array +import asyncio import io +import json import unittest.mock from io import StringIO -from typing import AsyncIterator, Iterator, List, Optional, Union +from pathlib import Path +from typing import AsyncIterator, Iterator, List, Optional, TextIO, Union import pytest from multidict import CIMultiDict @@ -11,6 +14,35 @@ from aiohttp.abc import AbstractStreamWriter +class BufferWriter(AbstractStreamWriter): + """Test writer that captures written bytes in a buffer.""" + + def __init__(self) -> None: + self.buffer = bytearray() + + async def write( + self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + ) -> None: + self.buffer.extend(bytes(chunk)) + + async def write_eof(self, chunk: bytes = b"") -> None: + """No-op for test writer.""" + + async def drain(self) -> None: + """No-op for test writer.""" + + def enable_compression( + self, encoding: str = "deflate", strategy: Optional[int] = None + ) -> None: + """Compression not implemented for test writer.""" + + def enable_chunking(self) -> None: + """Chunking not implemented for test writer.""" + + async def write_headers(self, status_line: str, headers: CIMultiDict[str]) -> None: + """Headers not captured for payload tests.""" + + @pytest.fixture(autouse=True) def cleanup( cleanup_payload_pending_file_closes: None, @@ -415,6 +447,43 @@ async def test_textio_payload_with_encoding() -> None: assert writer.get_written_bytes() == b"hello wo" +async def test_textio_payload_as_bytes() -> None: + """Test TextIOPayload.as_bytes method with different encodings.""" + # Test with UTF-8 encoding + data = io.StringIO("Hello 世界") + p = payload.TextIOPayload(data, encoding="utf-8") + + # Test as_bytes() method + result = await p.as_bytes() + assert result == "Hello 世界".encode() + + # Test that position is restored for multiple reads + result2 = await p.as_bytes() + assert result2 == "Hello 世界".encode() + + # Test with different encoding parameter (should use instance encoding) + result3 = await p.as_bytes(encoding="latin-1") + assert result3 == "Hello 世界".encode() # Should still use utf-8 + + # Test with different encoding in payload + data2 = io.StringIO("Hello World") + p2 = payload.TextIOPayload(data2, encoding="latin-1") + result4 = await p2.as_bytes() + assert result4 == b"Hello World" # latin-1 encoding + + # Test with no explicit encoding (defaults to utf-8) + data3 = io.StringIO("Test データ") + p3 = payload.TextIOPayload(data3) + result5 = await p3.as_bytes() + assert result5 == "Test データ".encode() + + # Test with encoding errors parameter + data4 = 
io.StringIO("Test") + p4 = payload.TextIOPayload(data4, encoding="ascii") + result6 = await p4.as_bytes(errors="strict") + assert result6 == b"Test" + + async def test_bytesio_payload_backwards_compatibility() -> None: """Test BytesIOPayload.write() backwards compatibility delegates to write_with_length().""" data = io.BytesIO(b"test data") @@ -453,3 +522,607 @@ async def gen() -> AsyncIterator[bytes]: # Should return early without writing anything await p.write_with_length(writer, 10) assert writer.get_written_bytes() == b"" + + +async def test_async_iterable_payload_caching() -> None: + """Test AsyncIterablePayload caching behavior.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"Hello" + yield b" " + yield b"World" + + p = payload.AsyncIterablePayload(gen()) + + # First call to as_bytes should consume iterator and cache + result1 = await p.as_bytes() + assert result1 == b"Hello World" + assert p._iter is None # Iterator exhausted + assert p._cached_chunks == [b"Hello", b" ", b"World"] # Chunks cached + assert p._consumed is False # Not marked as consumed to allow reuse + + # Second call should use cache + result2 = await p.as_bytes() + assert result2 == b"Hello World" + assert p._cached_chunks == [b"Hello", b" ", b"World"] # Still cached + + # decode should work with cached chunks + decoded = p.decode() + assert decoded == "Hello World" + + # write_with_length should use cached chunks + writer = MockStreamWriter() + await p.write_with_length(writer, None) + assert writer.get_written_bytes() == b"Hello World" + + # write_with_length with limit should respect it + writer2 = MockStreamWriter() + await p.write_with_length(writer2, 5) + assert writer2.get_written_bytes() == b"Hello" + + +async def test_async_iterable_payload_decode_without_cache() -> None: + """Test AsyncIterablePayload decode raises error without cache.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"test" + + p = payload.AsyncIterablePayload(gen()) + + # decode should raise without cache + with pytest.raises(TypeError) as excinfo: + p.decode() + assert "Unable to decode - content not cached" in str(excinfo.value) + + # After as_bytes, decode should work + await p.as_bytes() + assert p.decode() == "test" + + +async def test_async_iterable_payload_write_then_cache() -> None: + """Test AsyncIterablePayload behavior when written before caching.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"Hello" + yield b"World" + + p = payload.AsyncIterablePayload(gen()) + + # First write without caching (streaming) + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == b"HelloWorld" + assert p._iter is None # Iterator exhausted + assert p._cached_chunks is None # No cache created + assert p._consumed is True # Marked as consumed + + # Subsequent operations should handle exhausted iterator + result = await p.as_bytes() + assert result == b"" # Empty since iterator exhausted without cache + + # Write should also be empty + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == b"" + + +async def test_bytes_payload_reusability() -> None: + """Test that BytesPayload can be written and read multiple times.""" + data = b"test payload data" + p = payload.BytesPayload(data) + + # First write_with_length + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == data + + # Second write_with_length (simulating redirect) + writer2 = MockStreamWriter() 
+ await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == data + + # Write with partial length + writer3 = MockStreamWriter() + await p.write_with_length(writer3, 5) + assert writer3.get_written_bytes() == b"test " + + # Test as_bytes multiple times + bytes1 = await p.as_bytes() + bytes2 = await p.as_bytes() + bytes3 = await p.as_bytes() + assert bytes1 == bytes2 == bytes3 == data + + +async def test_string_payload_reusability() -> None: + """Test that StringPayload can be written and read multiple times.""" + text = "test string data" + expected_bytes = text.encode("utf-8") + p = payload.StringPayload(text) + + # First write_with_length + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == expected_bytes + + # Second write_with_length (simulating redirect) + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == expected_bytes + + # Write with partial length + writer3 = MockStreamWriter() + await p.write_with_length(writer3, 5) + assert writer3.get_written_bytes() == b"test " + + # Test as_bytes multiple times + bytes1 = await p.as_bytes() + bytes2 = await p.as_bytes() + bytes3 = await p.as_bytes() + assert bytes1 == bytes2 == bytes3 == expected_bytes + + +async def test_bytes_io_payload_reusability() -> None: + """Test that BytesIOPayload can be written and read multiple times.""" + data = b"test bytesio payload" + bytes_io = io.BytesIO(data) + p = payload.BytesIOPayload(bytes_io) + + # First write_with_length + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == data + + # Second write_with_length (simulating redirect) + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == data + + # Write with partial length + writer3 = MockStreamWriter() + await p.write_with_length(writer3, 5) + assert writer3.get_written_bytes() == b"test " + + # Test as_bytes multiple times + bytes1 = await p.as_bytes() + bytes2 = await p.as_bytes() + bytes3 = await p.as_bytes() + assert bytes1 == bytes2 == bytes3 == data + + +async def test_string_io_payload_reusability() -> None: + """Test that StringIOPayload can be written and read multiple times.""" + text = "test stringio payload" + expected_bytes = text.encode("utf-8") + string_io = io.StringIO(text) + p = payload.StringIOPayload(string_io) + + # Note: StringIOPayload reads all content in __init__ and becomes a StringPayload + # So it should be fully reusable + + # First write_with_length + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == expected_bytes + + # Second write_with_length (simulating redirect) + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == expected_bytes + + # Write with partial length + writer3 = MockStreamWriter() + await p.write_with_length(writer3, 5) + assert writer3.get_written_bytes() == b"test " + + # Test as_bytes multiple times + bytes1 = await p.as_bytes() + bytes2 = await p.as_bytes() + bytes3 = await p.as_bytes() + assert bytes1 == bytes2 == bytes3 == expected_bytes + + +async def test_buffered_reader_payload_reusability() -> None: + """Test that BufferedReaderPayload can be written and read multiple times.""" + data = b"test buffered reader payload" + buffer = io.BufferedReader(io.BytesIO(data)) # type: ignore[arg-type] + p = 
payload.BufferedReaderPayload(buffer) + + # First write_with_length + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == data + + # Second write_with_length (simulating redirect) + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == data + + # Write with partial length + writer3 = MockStreamWriter() + await p.write_with_length(writer3, 5) + assert writer3.get_written_bytes() == b"test " + + # Test as_bytes multiple times + bytes1 = await p.as_bytes() + bytes2 = await p.as_bytes() + bytes3 = await p.as_bytes() + assert bytes1 == bytes2 == bytes3 == data + + +async def test_async_iterable_payload_reusability_with_cache() -> None: + """Test that AsyncIterablePayload can be reused when cached via as_bytes.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"async " + yield b"iterable " + yield b"payload" + + expected_data = b"async iterable payload" + p = payload.AsyncIterablePayload(gen()) + + # First call to as_bytes should cache the data + bytes1 = await p.as_bytes() + assert bytes1 == expected_data + assert p._cached_chunks is not None + assert p._iter is None # Iterator exhausted + + # Subsequent as_bytes calls should use cache + bytes2 = await p.as_bytes() + bytes3 = await p.as_bytes() + assert bytes1 == bytes2 == bytes3 == expected_data + + # Now writes should also use the cached data + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == expected_data + + # Second write should also work + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == expected_data + + # Write with partial length + writer3 = MockStreamWriter() + await p.write_with_length(writer3, 5) + assert writer3.get_written_bytes() == b"async" + + +async def test_async_iterable_payload_no_reuse_without_cache() -> None: + """Test that AsyncIterablePayload cannot be reused without caching.""" + + async def gen() -> AsyncIterator[bytes]: + yield b"test " + yield b"data" + + p = payload.AsyncIterablePayload(gen()) + + # First write exhausts the iterator + writer1 = MockStreamWriter() + await p.write_with_length(writer1, None) + assert writer1.get_written_bytes() == b"test data" + assert p._iter is None # Iterator exhausted + assert p._consumed is True + + # Second write should produce empty result + writer2 = MockStreamWriter() + await p.write_with_length(writer2, None) + assert writer2.get_written_bytes() == b"" + + +async def test_bytes_io_payload_close_does_not_close_io() -> None: + """Test that BytesIOPayload close() does not close the underlying BytesIO.""" + bytes_io = io.BytesIO(b"data") + bytes_io_payload = payload.BytesIOPayload(bytes_io) + + # Close the payload + await bytes_io_payload.close() + + # BytesIO should NOT be closed + assert not bytes_io.closed + + # Can still write after close + writer = MockStreamWriter() + await bytes_io_payload.write_with_length(writer, None) + assert writer.get_written_bytes() == b"data" + + +async def test_custom_payload_backwards_compat_as_bytes() -> None: + """Test backwards compatibility for custom Payload that only implements decode().""" + + class LegacyPayload(payload.Payload): + """A custom payload that only implements decode() like old code might do.""" + + def __init__(self, data: str) -> None: + super().__init__(data, headers=CIMultiDict()) + self._data = data + + def decode(self, encoding: str = "utf-8", errors: str = "strict") -> 
+            """Custom decode implementation."""
+            return self._data
+
+        async def write(self, writer: AbstractStreamWriter) -> None:
+            """No-op write implementation for this test."""
+
+    # Create instance with test data
+    p = LegacyPayload("Hello, World!")
+
+    # Test that as_bytes() works even though it's not explicitly implemented
+    # The base class should call decode() and encode the result
+    result = await p.as_bytes()
+    assert result == b"Hello, World!"
+
+    # Test with different text
+    p2 = LegacyPayload("Test with special chars: café")
+    result_utf8 = await p2.as_bytes(encoding="utf-8")
+    assert result_utf8 == "Test with special chars: café".encode()
+
+    # Test that decode() still works as expected
+    assert p.decode() == "Hello, World!"
+    assert p2.decode() == "Test with special chars: café"
+
+
+async def test_custom_payload_with_encoding_backwards_compat() -> None:
+    """Test that a custom Payload with an explicit encoding uses it in as_bytes()."""
+
+    class EncodedPayload(payload.Payload):
+        """A custom payload with a specific encoding."""
+
+        def __init__(self, data: str, encoding: str) -> None:
+            super().__init__(data, headers=CIMultiDict(), encoding=encoding)
+            self._data = data
+
+        def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+            """Custom decode implementation."""
+            return self._data
+
+        async def write(self, writer: AbstractStreamWriter) -> None:
+            """No-op write implementation for this test."""
+
+    # Create instance with specific encoding
+    p = EncodedPayload("Test data", encoding="latin-1")
+
+    # as_bytes() should use the instance encoding (latin-1), not the default utf-8
+    result = await p.as_bytes()
+    assert result == b"Test data"  # ASCII chars are same in latin-1
+
+    # Test with non-ASCII that differs between encodings
+    p2 = EncodedPayload("café", encoding="latin-1")
+    result_latin1 = await p2.as_bytes()
+    assert result_latin1 == "café".encode("latin-1")
+    assert result_latin1 != "café".encode()  # Should be different bytes
+
+
+async def test_iobase_payload_close_idempotent() -> None:
+    """Test that IOBasePayload.close() is idempotent and covers the _consumed check."""
+    file_like = io.BytesIO(b"test data")
+    p = payload.IOBasePayload(file_like)
+
+    # First close should set _consumed to True
+    await p.close()
+    assert p._consumed is True
+
+    # Second close should be a no-op due to _consumed check (line 621)
+    await p.close()
+    assert p._consumed is True
+
+
+def test_iobase_payload_decode() -> None:
+    """Test IOBasePayload.decode() returns correct string."""
+    # Test with UTF-8 encoded text
+    text = "Hello, 世界! 🌍"
🌍" + file_like = io.BytesIO(text.encode("utf-8")) + p = payload.IOBasePayload(file_like) + + # decode() should return the original string + assert p.decode() == text + + # Test with different encoding + latin1_text = "café" + file_like2 = io.BytesIO(latin1_text.encode("latin-1")) + p2 = payload.IOBasePayload(file_like2) + assert p2.decode("latin-1") == latin1_text + + # Test that file position is restored + file_like3 = io.BytesIO(b"test data") + file_like3.read(4) # Move position forward + p3 = payload.IOBasePayload(file_like3) + # decode() should read from the stored start position (4) + assert p3.decode() == " data" + + +def test_bytes_payload_size() -> None: + """Test BytesPayload.size property returns correct byte length.""" + # Test with bytes + bp = payload.BytesPayload(b"Hello World") + assert bp.size == 11 + + # Test with empty bytes + bp_empty = payload.BytesPayload(b"") + assert bp_empty.size == 0 + + # Test with bytearray + ba = bytearray(b"Hello World") + bp_array = payload.BytesPayload(ba) + assert bp_array.size == 11 + + +def test_string_payload_size() -> None: + """Test StringPayload.size property with different encodings.""" + # Test ASCII string with default UTF-8 encoding + sp = payload.StringPayload("Hello World") + assert sp.size == 11 + + # Test Unicode string with default UTF-8 encoding + unicode_str = "Hello 世界" + sp_unicode = payload.StringPayload(unicode_str) + assert sp_unicode.size == len(unicode_str.encode("utf-8")) + + # Test with UTF-16 encoding + sp_utf16 = payload.StringPayload("Hello World", encoding="utf-16") + assert sp_utf16.size == len("Hello World".encode("utf-16")) + + # Test with latin-1 encoding + sp_latin1 = payload.StringPayload("café", encoding="latin-1") + assert sp_latin1.size == len("café".encode("latin-1")) + + +def test_string_io_payload_size() -> None: + """Test StringIOPayload.size property.""" + # Test normal string + sio = StringIO("Hello World") + siop = payload.StringIOPayload(sio) + assert siop.size == 11 + + # Test Unicode string + sio_unicode = StringIO("Hello 世界") + siop_unicode = payload.StringIOPayload(sio_unicode) + assert siop_unicode.size == len("Hello 世界".encode()) + + # Test with custom encoding + sio_custom = StringIO("Hello") + siop_custom = payload.StringIOPayload(sio_custom, encoding="utf-16") + assert siop_custom.size == len("Hello".encode("utf-16")) + + # Test with emoji to ensure correct byte count + sio_emoji = StringIO("Hello 👋🌍") + siop_emoji = payload.StringIOPayload(sio_emoji) + assert siop_emoji.size == len("Hello 👋🌍".encode()) + # Verify it's not the string length + assert siop_emoji.size != len("Hello 👋🌍") + + +def test_all_string_payloads_size_is_bytes() -> None: + """Test that all string-like payload classes report size in bytes, not string length.""" + # Test string with multibyte characters + test_str = "Hello 👋 世界 🌍" # Contains emoji and Chinese characters + + # StringPayload + sp = payload.StringPayload(test_str) + assert sp.size == len(test_str.encode("utf-8")) + assert sp.size != len(test_str) # Ensure it's not string length + + # StringIOPayload + sio = StringIO(test_str) + siop = payload.StringIOPayload(sio) + assert siop.size == len(test_str.encode("utf-8")) + assert siop.size != len(test_str) + + # Test with different encoding + sp_utf16 = payload.StringPayload(test_str, encoding="utf-16") + assert sp_utf16.size == len(test_str.encode("utf-16")) + assert sp_utf16.size != sp.size # Different encoding = different size + + # JsonPayload (which extends BytesPayload) + json_data = {"message": test_str} 
+    jp = payload.JsonPayload(json_data)
+    # JSON escapes Unicode, so we need to check the actual encoded size
+    json_str = json.dumps(json_data)
+    assert jp.size == len(json_str.encode("utf-8"))
+
+    # Test JsonPayload with ensure_ascii=False to get actual UTF-8 encoding
+    jp_utf8 = payload.JsonPayload(
+        json_data, dumps=lambda x: json.dumps(x, ensure_ascii=False)
+    )
+    json_str_utf8 = json.dumps(json_data, ensure_ascii=False)
+    assert jp_utf8.size == len(json_str_utf8.encode("utf-8"))
+    assert jp_utf8.size != len(
+        json_str_utf8
+    )  # Now it's different due to multibyte chars
+
+
+def test_bytes_io_payload_size() -> None:
+    """Test BytesIOPayload.size property."""
+    # Test normal bytes
+    bio = io.BytesIO(b"Hello World")
+    biop = payload.BytesIOPayload(bio)
+    assert biop.size == 11
+
+    # Test empty BytesIO
+    bio_empty = io.BytesIO(b"")
+    biop_empty = payload.BytesIOPayload(bio_empty)
+    assert biop_empty.size == 0
+
+    # Test with position not at start
+    bio_pos = io.BytesIO(b"Hello World")
+    bio_pos.seek(5)
+    biop_pos = payload.BytesIOPayload(bio_pos)
+    assert biop_pos.size == 6  # Size should be from position to end
+
+
+def test_json_payload_size() -> None:
+    """Test JsonPayload.size property."""
+    # Test simple dict
+    data = {"hello": "world"}
+    jp = payload.JsonPayload(data)
+    expected_json = json.dumps(data)  # Use actual json.dumps output
+    assert jp.size == len(expected_json.encode("utf-8"))
+
+    # Test with Unicode
+    data_unicode = {"message": "Hello 世界"}
+    jp_unicode = payload.JsonPayload(data_unicode)
+    expected_unicode = json.dumps(data_unicode)
+    assert jp_unicode.size == len(expected_unicode.encode("utf-8"))
+
+    # Test with custom encoding
+    data_custom = {"test": "data"}
+    jp_custom = payload.JsonPayload(data_custom, encoding="utf-16")
+    expected_custom = json.dumps(data_custom)
+    assert jp_custom.size == len(expected_custom.encode("utf-16"))
+
+
+async def test_text_io_payload_size_matches_file_encoding(tmp_path: Path) -> None:
+    """Test TextIOPayload.size when file encoding matches payload encoding."""
+    # Create UTF-8 file
+    utf8_file = tmp_path / "test_utf8.txt"
+    content = "Hello 世界"
+
+    # Write file in executor
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, utf8_file.write_text, content, "utf-8")
+
+    # Open file in executor
+    def open_file() -> TextIO:
+        return open(utf8_file, encoding="utf-8")
+
+    f = await loop.run_in_executor(None, open_file)
+    try:
+        tiop = payload.TextIOPayload(f)
+        # Size should match the actual UTF-8 encoded size
+        assert tiop.size == len(content.encode("utf-8"))
+    finally:
+        await loop.run_in_executor(None, f.close)
+
+
+async def test_text_io_payload_size_utf16(tmp_path: Path) -> None:
+    """Test TextIOPayload.size reports correct size with utf-16."""
+    # Create UTF-16 file
+    utf16_file = tmp_path / "test_utf16.txt"
+    content = "Hello World"
+
+    loop = asyncio.get_running_loop()
+    # Write file in executor
+    await loop.run_in_executor(None, utf16_file.write_text, content, "utf-16")
+
+    # Get file size in executor
+    utf16_file_size = await loop.run_in_executor(
+        None, lambda: utf16_file.stat().st_size
+    )
+
+    # Open file in executor
+    def open_file() -> TextIO:
+        return open(utf16_file, encoding="utf-16")
+
+    f = await loop.run_in_executor(None, open_file)
+    try:
+        tiop = payload.TextIOPayload(f, encoding="utf-16")
+        # Payload reports file size on disk (UTF-16)
+        assert tiop.size == utf16_file_size
+
+        # Write to a buffer to see what actually gets sent
+        writer = BufferWriter()
+        await tiop.write(writer)
+
+        # Check that the actual written bytes match file size
+        assert len(writer.buffer) == utf16_file_size
+    finally:
+        await loop.run_in_executor(None, f.close)
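+
+
+async def test_json_payload_reusability() -> None:
+    """Sketch: JsonPayload reuse, assuming BytesPayload semantics carry over.
+
+    JsonPayload serializes once at construction and extends BytesPayload,
+    so it should support repeated writes and as_bytes() calls just like
+    the BytesPayload reusability tests above. Illustrative sketch only.
+    """
+    data = {"hello": "world"}
+    p = payload.JsonPayload(data)
+    expected = json.dumps(data).encode("utf-8")
+
+    # First write_with_length
+    writer1 = MockStreamWriter()
+    await p.write_with_length(writer1, None)
+    assert writer1.get_written_bytes() == expected
+
+    # Second write_with_length (simulating redirect)
+    writer2 = MockStreamWriter()
+    await p.write_with_length(writer2, None)
+    assert writer2.get_written_bytes() == expected
+
+    # as_bytes() should return the same bytes on every call
+    bytes1 = await p.as_bytes()
+    bytes2 = await p.as_bytes()
+    assert bytes1 == bytes2 == expected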