diff --git a/requirements.in b/requirements.in index 417afeb5d..00920aed4 100644 --- a/requirements.in +++ b/requirements.in @@ -1,3 +1,4 @@ +async-asgi-testclient black flake8 mypy>=0.941 diff --git a/requirements.txt b/requirements.txt index e496bc72e..93db5afc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,13 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.14 # by the following command: # # pip-compile # alabaster==1.0.0 # via sphinx +async-asgi-testclient==1.4.11 + # via -r requirements.in babel==2.17.0 # via sphinx black==25.1.0 @@ -48,6 +50,8 @@ markupsafe==3.0.2 # via jinja2 mccabe==0.7.0 # via flake8 +multidict==6.7.0 + # via async-asgi-testclient mypy==1.15.0 # via -r requirements.in mypy-extensions==1.1.0 @@ -85,7 +89,9 @@ pyproject-hooks==1.2.0 # build # pip-tools requests==2.32.4 - # via sphinx + # via + # async-asgi-testclient + # sphinx roman-numerals-py==3.1.0 # via sphinx snowballstemmer==3.0.1 diff --git a/tornado/asgi.py b/tornado/asgi.py new file mode 100644 index 000000000..1d11775e2 --- /dev/null +++ b/tornado/asgi.py @@ -0,0 +1,171 @@ +from asyncio import create_task, Future, Task, wait +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Optional, Union + +from tornado.httputil import ( + HTTPConnection, + HTTPHeaders, + RequestStartLine, + ResponseStartLine, +) +from tornado.web import Application + +ReceiveCallable = Callable[[], Awaitable[dict]] +SendCallable = Callable[[dict], Awaitable[None]] + + +@dataclass +class ASGIHTTPRequestContext: + """To convey connection details to the HTTPServerRequest object""" + + protocol: str + address: Optional[tuple] = None + remote_ip: str = "0.0.0.0" + + +class ASGIHTTPConnection(HTTPConnection): + """Represents the connection for 1 request/response pair + + This provides the API for sending the response. + """ + + def __init__(self, send_cb: SendCallable, context: ASGIHTTPRequestContext): + self.send_cb = send_cb + self.context = context + self.task_holder: set[Task] = set() + self._close_callback: Callable[[], None] | None = None + self._request_finished: Future[None] = Future() + + # Various tornado APIs (e.g. RequestHandler.flush()) return a Future which + # application code does not need to await. The operations these represent + # are expected to complete even if the Future is discarded. ASGI is based + # on 'awaitable callables', which do not guarantee this. So we need to hold + # references to tasks until they complete + def _bg_task(self, coro) -> Future: # type: ignore + task = create_task(coro) + self.task_holder.add(task) + task.add_done_callback(self.task_holder.discard) + return task + + async def _write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> None: + assert isinstance(start_line, ResponseStartLine) + await self.send_cb( + { + "type": "http.response.start", + "status": start_line.code, + "headers": [ + [k.lower().encode("latin1"), v.encode("latin1")] + for k, v in headers.get_all() + ], + } + ) + if chunk is not None: + await self._write(chunk) + + def write_headers( + self, + start_line: Union["RequestStartLine", "ResponseStartLine"], + headers: HTTPHeaders, + chunk: Optional[bytes] = None, + ) -> "Future[None]": + return self._bg_task(self._write_headers(start_line, headers, chunk)) + + async def _write(self, chunk: bytes) -> None: + await self.send_cb( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + + def write(self, chunk: bytes) -> "Future[None]": + return self._bg_task(self._write(chunk)) + + def finish(self) -> None: + self._bg_task( + self.send_cb( + { + "type": "http.response.body", + "body": b"", + "more_body": False, + } + ) + ) + self._request_finished.set_result(None) + + def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None: + self._close_callback = callback + + def _on_connection_close(self) -> None: + if self._close_callback is not None: + callback = self._close_callback + self._close_callback = None + callback() + self._request_finished.set_result(None) + + async def wait_finish(self) -> None: + """For the ASGI interface: wait for all input & output to finish""" + await self._request_finished + await wait(self.task_holder) + + +class ASGIAdapter: + """Wrap a tornado application object to use with an ASGI server""" + + def __init__(self, application: Application): + self.application = application + + async def __call__( + self, scope: dict, receive: ReceiveCallable, send: SendCallable + ) -> None: + if scope["type"] == "http": + return await self.http_scope(scope, receive, send) + raise KeyError(scope["type"]) + + async def http_scope( + self, scope: dict, receive: ReceiveCallable, send: SendCallable + ) -> None: + """Handles one HTTP request""" + ctx = ASGIHTTPRequestContext(scope["scheme"]) + if client_addr := scope.get("client", None): + ctx.address = tuple(client_addr) + ctx.remote_ip = client_addr[0] + + conn = ASGIHTTPConnection(send, ctx) + msg_delegate = self.application.start_request(None, conn) + start_line, req_headers = self._http_convert_req(scope) + if (fut := msg_delegate.headers_received(start_line, req_headers)) is not None: + await fut + + while True: + event = await receive() + if event["type"] == "http.request": + if chunk := event.get("body", b""): + if (fut := msg_delegate.data_received(chunk)) is not None: + await fut + if not event.get("more_body", False): + msg_delegate.finish() + break + elif event["type"] == "http.disconnect": + msg_delegate.on_connection_close() + conn._on_connection_close() + break + + await conn.wait_finish() + + @staticmethod + def _http_convert_req(scope: dict) -> tuple[RequestStartLine, HTTPHeaders]: + req_target = scope["path"] + if qs := scope["query_string"]: + req_target += "?" + qs.decode("latin1") + req_start_line = RequestStartLine( + scope["method"], req_target, scope["http_version"] + ) + req_headers = HTTPHeaders() + for k, v in scope["headers"]: + req_headers.add(k.decode("latin1"), v.decode("latin1")) + + return req_start_line, req_headers diff --git a/tornado/httputil.py b/tornado/httputil.py index 74dfb87f1..1224be92a 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -778,6 +778,22 @@ def finish(self) -> None: """Indicates that the last body data has been written.""" raise NotImplementedError() + def set_close_callback( + self, callback: Optional[collections.abc.Callable[[], None]] + ) -> None: + """Sets a callback that will be run when the connection is closed. + + Note that this callback is slightly different from + `.HTTPMessageDelegate.on_connection_close`: The + `.HTTPMessageDelegate` method is called when the connection is + closed while receiving a message. This callback is used when + there is not an active delegate (for example, on the server + side this callback is used if the client closes the connection + after sending its request but before receiving all the + response. + """ + raise NotImplementedError() + def url_concat( url: str, diff --git a/tornado/test/asgi_test.py b/tornado/test/asgi_test.py new file mode 100644 index 000000000..5a6f314e2 --- /dev/null +++ b/tornado/test/asgi_test.py @@ -0,0 +1,74 @@ +import unittest + +try: + import async_asgi_testclient # type: ignore +except ImportError: + async_asgi_testclient = None + +from tornado.asgi import ASGIAdapter +from tornado.web import Application, RequestHandler +from tornado.testing import AsyncTestCase, gen_test + + +class BasicHandler(RequestHandler): + def get(self): + name = self.get_argument("name", "world") + self.write(f"Hello, {name}") + + +class InspectHandler(RequestHandler): + def make_response(self, path_var): + # Send the response as JSON + self.finish( + { + "method": self.request.method, + "path": self.request.path, + "path_var": path_var, + "query_params": { + k: self.get_query_arguments(k) for k in self.request.query_arguments + }, + "body": self.request.body.decode("latin1"), + } + ) + + def get(self, path_var): + return self.make_response(path_var) + + def post(self, path_var): + return self.make_response(path_var) + + +@unittest.skipIf( + async_asgi_testclient is None, "async_asgi_testclient module not present" +) +class AsyncASGITestCase(AsyncTestCase): + def setUp(self) -> None: + super().setUp() + self.asgi_app = ASGIAdapter( + Application([(r"/", BasicHandler), (r"/inspect(/.*)", InspectHandler)]) + ) + self.client = async_asgi_testclient.TestClient(self.asgi_app) + + @gen_test(timeout=10) + async def test_basic_request(self): + resp = await self.client.get("/?name=foo") + assert resp.status_code == 200 + assert resp.text == "Hello, foo" + + @gen_test(timeout=10) + async def test_get_request_details(self): + resp = await self.client.get("/inspect/foo/?bar=baz") + d = resp.json() + assert d["method"] == "GET" + assert d["path"] == "/inspect/foo/" + assert d["query_params"] == {"bar": ["baz"]} + assert d["body"] == "" + + @gen_test(timeout=10) + async def test_post_request_details(self): + resp = await self.client.post("/inspect/foo/?bar=baz", data=b"123") + d = resp.json() + assert d["method"] == "POST" + assert d["path"] == "/inspect/foo/" + assert d["query_params"] == {"bar": ["baz"]} + assert d["body"] == "123" diff --git a/tornado/test/runtests.py b/tornado/test/runtests.py index d7eb51f9a..761cb9311 100644 --- a/tornado/test/runtests.py +++ b/tornado/test/runtests.py @@ -20,6 +20,7 @@ "tornado.httputil.doctests", "tornado.iostream.doctests", "tornado.util.doctests", + "tornado.test.asgi_test", "tornado.test.asyncio_test", "tornado.test.auth_test", "tornado.test.autoreload_test", diff --git a/tox.ini b/tox.ini index db0a4b604..e7496947f 100644 --- a/tox.ini +++ b/tox.ini @@ -50,6 +50,7 @@ deps = # And since CaresResolver is deprecated, I do not expect to fix it, so just # pin the previous version. (This should really be in requirements.{in,txt} instead) full: pycares<5 + full: async-asgi-testclient docs: -r{toxinidir}/requirements.txt lint: -r{toxinidir}/requirements.txt