diff --git a/Cargo.toml b/Cargo.toml index 18cceffbd0c..02e05c10309 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,7 @@ hashbrown = { version = ">= 0.14.5, < 0.16", optional = true } indexmap = { version = ">= 2.5.0, < 3", optional = true } num-bigint = { version = "0.4.2", optional = true } num-complex = { version = ">= 0.4.6, < 0.5", optional = true } -num-rational = {version = "0.4.1", optional = true } +num-rational = { version = "0.4.1", optional = true } rust_decimal = { version = "1.15", default-features = false, optional = true } serde = { version = "1.0", optional = true } smallvec = { version = "1.0", optional = true } @@ -63,7 +63,7 @@ rayon = "1.6.1" futures = "0.3.28" tempfile = "3.12.0" static_assertions = "1.1.0" -uuid = {version = "1.10.0", features = ["v4"] } +uuid = { version = "1.10.0", features = ["v4"] } [build-dependencies] pyo3-build-config = { path = "pyo3-build-config", version = "=0.23.3", features = ["resolve-config"] } @@ -74,6 +74,9 @@ default = ["macros"] # Enables support for `async fn` for `#[pyfunction]` and `#[pymethods]`. experimental-async = ["macros", "pyo3-macros/experimental-async"] +# Switch coroutine implementation to anyio instead of asyncio +anyio = ["experimental-async"] + # Enables pyo3::inspect module and additional type information on FromPyObject # and IntoPy traits experimental-inspect = [] diff --git a/guide/src/SUMMARY.md b/guide/src/SUMMARY.md index cf987b72625..998434b6403 100644 --- a/guide/src/SUMMARY.md +++ b/guide/src/SUMMARY.md @@ -25,6 +25,7 @@ - [Mapping of Rust types to Python types](conversions/tables.md) - [Conversion traits](conversions/traits.md) - [Using `async` and `await`](async-await.md) + - [Awaiting Python awaitables](async-await/awaiting_python_awaitables) - [Parallelism](parallelism.md) - [Supporting Free-Threaded Python](free-threading.md) - [Debugging](debugging.md) diff --git a/guide/src/async-await.md b/guide/src/async-await.md index 27574181804..613aca5f97c 100644 --- a/guide/src/async-await.md +++ b/guide/src/async-await.md @@ -25,8 +25,6 @@ async fn sleep(seconds: f64, result: Option) -> Option { # } ``` -*Python awaitables instantiated with this method can only be awaited in *asyncio* context. Other Python async runtime may be supported in the future.* - ## `Send + 'static` constraint Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + 'static` to be embedded in a Python object. @@ -94,6 +92,13 @@ async fn cancellable(#[pyo3(cancel_handle)] mut cancel: CancelHandle) { # } ``` +## *asyncio* vs. *anyio* + +By default, Python awaitables instantiated with `async fn` can only be awaited in *asyncio* context. + +PyO3 can also target [*anyio*](https://github.com/agronholm/anyio) with the dedicated `anyio` Cargo feature. With it enabled, `async fn` become awaitable both in *asyncio* or [*trio*](https://github.com/python-trio/trio) context. +However, it requires to have the [*sniffio*](https://github.com/python-trio/sniffio) (or *anyio*) library installed. + ## The `Coroutine` type To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). diff --git a/guide/src/async-await/awaiting_python_awaitables.md b/guide/src/async-await/awaiting_python_awaitables.md new file mode 100644 index 00000000000..fdfa7ebbc2b --- /dev/null +++ b/guide/src/async-await/awaiting_python_awaitables.md @@ -0,0 +1,62 @@ +# Awaiting Python awaitables + +Python awaitable can be awaited on Rust side +using [`await_in_coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/fn.await_in_coroutine). + +```rust +# # ![allow(dead_code)] +# #[cfg(feature = "experimental-async")] { +use pyo3::{prelude::*, coroutine::await_in_coroutine}; + +#[pyfunction] +async fn wrap_awaitable(awaitable: PyObject) -> PyResult { + Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?.await +} +# } +``` + +Behind the scene, `await_in_coroutine` calls the `__await__` method of the Python awaitable (or `__iter__` for +generator-based coroutine). + +## Restrictions + +As the name suggests, `await_in_coroutine` resulting future can only be awaited in coroutine context. Otherwise, it +panics. + +```rust +# # ![allow(dead_code)] +# #[cfg(feature = "experimental-async")] { +use pyo3::{prelude::*, coroutine::await_in_coroutine}; + +#[pyfunction] +fn block_on(awaitable: PyObject) -> PyResult { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; + futures::executor::block_on(future) // ERROR: Python awaitable must be awaited in coroutine context +} +# } +``` + +The future must also be the only one to be awaited at a time; it means that it's forbidden to await it in a `select!`. +Otherwise, it panics. + +```rust +# # ![allow(dead_code)] +# #[cfg(feature = "experimental-async")] { +use futures::FutureExt; +use pyo3::{prelude::*, coroutine::await_in_coroutine}; + +#[pyfunction] +async fn select(awaitable: PyObject) -> PyResult { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; + futures::select_biased! { + _ = std::future::pending::<()>().fuse() => unreachable!(), + res = future.fuse() => res, // ERROR: Python awaitable mixed with Rust future + } +} +# } +``` + +These restrictions exist because awaiting a `await_in_coroutine` future strongly binds it to the +enclosing coroutine. The coroutine will then delegate its `send`/`throw`/`close` methods to the +awaited future. If it was awaited in a `select!`, `Coroutine::send` would no able to know if +the value passed would have to be delegated or not. diff --git a/guide/src/building-and-distribution.md b/guide/src/building-and-distribution.md index d3474fedaf7..9c14e5782ad 100644 --- a/guide/src/building-and-distribution.md +++ b/guide/src/building-and-distribution.md @@ -62,6 +62,7 @@ There are many ways to go about this: it is possible to use `cargo` to build the PyO3 has some Cargo features to configure projects for building Python extension modules: - The `extension-module` feature, which must be enabled when building Python extension modules. - The `abi3` feature and its version-specific `abi3-pyXY` companions, which are used to opt-in to the limited Python API in order to support multiple Python versions in a single wheel. +- The `anyio` feature, making PyO3 coroutines target [*anyio*](https://github.com/agronholm/anyio) instead of *asyncio*; either [*sniffio*](https://github.com/python-trio/sniffio) or *anyio* should be added as dependency of the Python extension. This section describes each of these packaging tools before describing how to build manually without them. It then proceeds with an explanation of the `extension-module` feature. Finally, there is a section describing PyO3's `abi3` features. diff --git a/newsfragments/3611.added.md b/newsfragments/3611.added.md new file mode 100644 index 00000000000..75aa9aee8fd --- /dev/null +++ b/newsfragments/3611.added.md @@ -0,0 +1 @@ +Add `coroutine::await_in_coroutine` to await awaitables in coroutine context diff --git a/newsfragments/3612.added.md b/newsfragments/3612.added.md new file mode 100644 index 00000000000..4f5f2f24014 --- /dev/null +++ b/newsfragments/3612.added.md @@ -0,0 +1 @@ +Support anyio with a Cargo feature \ No newline at end of file diff --git a/pyo3-ffi/src/abstract_.rs b/pyo3-ffi/src/abstract_.rs index 82eecce05bd..0db4e726a0f 100644 --- a/pyo3-ffi/src/abstract_.rs +++ b/pyo3-ffi/src/abstract_.rs @@ -1,8 +1,9 @@ -use crate::object::*; -use crate::pyport::Py_ssize_t; +use std::os::raw::{c_char, c_int}; + #[cfg(any(Py_3_12, all(Py_3_8, not(Py_LIMITED_API))))] use libc::size_t; -use std::os::raw::{c_char, c_int}; + +use crate::{object::*, pyport::Py_ssize_t}; #[inline] #[cfg(all(not(Py_3_13), not(PyPy)))] // CPython exposed as a function in 3.13, in object.h @@ -143,7 +144,11 @@ extern "C" { pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject; #[cfg(all(not(PyPy), Py_3_10))] #[cfg_attr(PyPy, link_name = "PyPyIter_Send")] - pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject); + pub fn PyIter_Send( + iter: *mut PyObject, + arg: *mut PyObject, + presult: *mut *mut PyObject, + ) -> c_int; #[cfg_attr(PyPy, link_name = "PyPyNumber_Check")] pub fn PyNumber_Check(o: *mut PyObject) -> c_int; diff --git a/pytests/Cargo.toml b/pytests/Cargo.toml index 1fee3093275..4c6d958dd46 100644 --- a/pytests/Cargo.toml +++ b/pytests/Cargo.toml @@ -8,7 +8,8 @@ publish = false rust-version = "1.63" [dependencies] -pyo3 = { path = "../", features = ["extension-module"] } +futures = "0.3.29" +pyo3 = { path = "../", features = ["extension-module", "anyio"] } [build-dependencies] pyo3-build-config = { path = "../pyo3-build-config" } diff --git a/pytests/pyproject.toml b/pytests/pyproject.toml index 5f78a573124..90d1867b88d 100644 --- a/pytests/pyproject.toml +++ b/pytests/pyproject.toml @@ -20,6 +20,7 @@ classifiers = [ [project.optional-dependencies] dev = [ + "anyio[trio]>=4.0", "hypothesis>=3.55", "pytest-asyncio>=0.21", "pytest-benchmark>=3.4", diff --git a/pytests/src/anyio.rs b/pytests/src/anyio.rs new file mode 100644 index 00000000000..e123a5ae2a3 --- /dev/null +++ b/pytests/src/anyio.rs @@ -0,0 +1,34 @@ +use std::{task::Poll, thread, time::Duration}; + +use futures::{channel::oneshot, future::poll_fn}; +use pyo3::prelude::*; + +#[pyfunction(signature = (seconds, result = None))] +async fn sleep(seconds: f64, result: Option) -> Option { + if seconds <= 0.0 { + let mut ready = false; + poll_fn(|cx| { + if ready { + return Poll::Ready(()); + } + ready = true; + cx.waker().wake_by_ref(); + Poll::Pending + }) + .await; + } else { + let (tx, rx) = oneshot::channel(); + thread::spawn(move || { + thread::sleep(Duration::from_secs_f64(seconds)); + tx.send(()).unwrap(); + }); + rx.await.unwrap(); + } + result +} + +#[pymodule] +pub fn anyio(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sleep, m)?)?; + Ok(()) +} diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index b6c32230dac..e1589125484 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; use pyo3::types::PyDict; use pyo3::wrap_pymodule; +pub mod anyio; pub mod awaitable; pub mod buf_and_str; pub mod comparisons; @@ -19,6 +20,7 @@ pub mod subclassing; #[pymodule(gil_used = false)] fn pyo3_pytests(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_wrapped(wrap_pymodule!(anyio::anyio))?; m.add_wrapped(wrap_pymodule!(awaitable::awaitable))?; #[cfg(not(Py_LIMITED_API))] m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?; @@ -41,6 +43,7 @@ fn pyo3_pytests(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { let sys = PyModule::import(py, "sys")?; let sys_modules = sys.getattr("modules")?.downcast_into::()?; + sys_modules.set_item("pyo3_pytests.anyio", m.getattr("anyio")?)?; sys_modules.set_item("pyo3_pytests.awaitable", m.getattr("awaitable")?)?; sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?; sys_modules.set_item("pyo3_pytests.comparisons", m.getattr("comparisons")?)?; diff --git a/pytests/tests/test_anyio.py b/pytests/tests/test_anyio.py new file mode 100644 index 00000000000..c48435bd2fc --- /dev/null +++ b/pytests/tests/test_anyio.py @@ -0,0 +1,14 @@ +import asyncio + +from pyo3_pytests.anyio import sleep +import trio + + +def test_asyncio(): + assert asyncio.run(sleep(0)) is None + assert asyncio.run(sleep(0.1, 42)) == 42 + + +def test_trio(): + assert trio.run(sleep, 0) is None + assert trio.run(sleep, 0.1, 42) == 42 diff --git a/src/coroutine.rs b/src/coroutine.rs index 671defb1770..8dbb4bd9512 100644 --- a/src/coroutine.rs +++ b/src/coroutine.rs @@ -11,20 +11,33 @@ use std::{ use pyo3_macros::{pyclass, pymethods}; use crate::{ - coroutine::{cancel::ThrowCallback, waker::AsyncioWaker}, - exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration}, + coroutine::waker::CoroutineWaker, + exceptions::{PyAttributeError, PyGeneratorExit, PyRuntimeError, PyStopIteration}, panic::PanicException, - types::{string::PyStringMethods, PyIterator, PyString}, - Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyObject, PyResult, Python, + types::{string::PyStringMethods, PyString}, + Bound, IntoPyObject, IntoPyObjectExt, Py, PyErr, PyObject, PyResult, Python, }; -pub(crate) mod cancel; +#[cfg(feature = "anyio")] +mod anyio; +mod asyncio; +mod awaitable; +mod cancel; +#[cfg(feature = "anyio")] +mod trio; mod waker; -pub use cancel::CancelHandle; +pub use awaitable::await_in_coroutine; +pub use cancel::{CancelHandle, ThrowCallback}; const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine"; +pub(crate) enum CoroOp { + Send(PyObject), + Throw(PyObject), + Close, +} + /// Python coroutine wrapping a [`Future`]. #[pyclass(crate = "crate")] pub struct Coroutine { @@ -32,7 +45,7 @@ pub struct Coroutine { qualname_prefix: Option<&'static str>, throw_callback: Option, future: Option> + Send>>>, - waker: Option>, + waker: Option>, } // Safety: `Coroutine` is allowed to be `Sync` even though the future is not, @@ -71,55 +84,58 @@ impl Coroutine { } } - fn poll(&mut self, py: Python<'_>, throw: Option) -> PyResult { + fn poll_inner(&mut self, py: Python<'_>, mut op: CoroOp) -> PyResult { // raise if the coroutine has already been run to completion let future_rs = match self.future { Some(ref mut fut) => fut, None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)), }; - // reraise thrown exception it - match (throw, &self.throw_callback) { - (Some(exc), Some(cb)) => cb.throw(exc), - (Some(exc), None) => { - self.close(); - return Err(PyErr::from_value(exc.into_bound(py))); - } - (None, _) => {} + // if the future is not pending on a Python awaitable, + // execute throw callback or complete on close + if !matches!(self.waker, Some(ref w) if w.is_delegated(py)) { + match op { + send @ CoroOp::Send(_) => op = send, + CoroOp::Throw(exc) => match &self.throw_callback { + Some(cb) => { + cb.throw(exc.clone_ref(py)); + op = CoroOp::Send(py.None()); + } + None => return Err(PyErr::from_value(exc.into_bound(py))), + }, + CoroOp::Close => return Err(PyGeneratorExit::new_err(py.None())), + }; } // create a new waker, or try to reset it in place if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) { - waker.reset(); + waker.reset(op); } else { - self.waker = Some(Arc::new(AsyncioWaker::new())); + self.waker = Some(Arc::new(CoroutineWaker::new(op))); } - let waker = Waker::from(self.waker.clone().unwrap()); - // poll the Rust future and forward its results if ready + // poll the future and forward its results if ready; otherwise, yield from waker // polling is UnwindSafe because the future is dropped in case of panic + let waker = Waker::from(self.waker.clone().unwrap()); let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker)); match panic::catch_unwind(panic::AssertUnwindSafe(poll)) { - Ok(Poll::Ready(res)) => { - self.close(); - return Err(PyStopIteration::new_err((res?,))); - } - Err(err) => { - self.close(); - return Err(PanicException::from_panic_payload(err)); - } - _ => {} + Err(err) => Err(PanicException::from_panic_payload(err)), + // See #4407, `PyStopIteration::new_err` argument must be wrap in a tuple, + // otherwise, when a tuple is returned, its fields would be expanded as error + // arguments + Ok(Poll::Ready(res)) => Err(PyStopIteration::new_err((res?,))), + Ok(Poll::Pending) => match self.waker.as_ref().unwrap().yield_(py) { + Ok(to_yield) => Ok(to_yield), + Err(err) => Err(err), + }, } - // otherwise, initialize the waker `asyncio.Future` - if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? { - // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__` - // and will yield itself if its result has not been set in polling above - if let Some(future) = PyIterator::from_object(future).unwrap().next() { - // future has not been leaked into Python for now, and Rust code can only call - // `set_result(None)` in `Wake` implementation, so it's safe to unwrap - return Ok(future.unwrap().into()); - } + } + + fn poll(&mut self, py: Python<'_>, op: CoroOp) -> PyResult { + let result = self.poll_inner(py, op); + if result.is_err() { + // the Rust future is dropped, and the field set to `None` + // to indicate the coroutine has been run to completion + drop(self.future.take()); } - // if waker has been waken during future polling, this is roughly equivalent to - // `await asyncio.sleep(0)`, so just yield `None`. - Ok(py.None()) + result } } @@ -145,18 +161,20 @@ impl Coroutine { } } - fn send(&mut self, py: Python<'_>, _value: &Bound<'_, PyAny>) -> PyResult { - self.poll(py, None) + fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult { + self.poll(py, CoroOp::Send(value)) } fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult { - self.poll(py, Some(exc)) + self.poll(py, CoroOp::Throw(exc)) } - fn close(&mut self) { - // the Rust future is dropped, and the field set to `None` - // to indicate the coroutine has been run to completion - drop(self.future.take()); + fn close(&mut self, py: Python<'_>) -> PyResult<()> { + match self.poll(py, CoroOp::Close) { + Ok(_) => Ok(()), + Err(err) if err.is_instance_of::(py) => Ok(()), + Err(err) => Err(err), + } } fn __await__(self_: Py) -> Py { @@ -164,6 +182,6 @@ impl Coroutine { } fn __next__(&mut self, py: Python<'_>) -> PyResult { - self.poll(py, None) + self.poll(py, CoroOp::Send(py.None())) } } diff --git a/src/coroutine/anyio.rs b/src/coroutine/anyio.rs new file mode 100644 index 00000000000..bb120dd456f --- /dev/null +++ b/src/coroutine/anyio.rs @@ -0,0 +1,73 @@ +//! Coroutine implementation using sniffio to select the appropriate implementation, +//! compatible with anyio. +use crate::{ + coroutine::{asyncio::AsyncioWaker, trio::TrioWaker}, + exceptions::PyRuntimeError, + sync::GILOnceCell, + types::PyAnyMethods, + PyObject, PyResult, Python, +}; + +enum AsyncLib { + Asyncio, + Trio, +} + +fn current_async_library(py: Python<'_>) -> PyResult { + static CURRENT_ASYNC_LIBRARY: GILOnceCell> = GILOnceCell::new(); + let import = || -> PyResult<_> { + Ok(match py.import("sniffio") { + Ok(module) => Some(module.getattr("current_async_library")?.into()), + Err(_) => None, + }) + }; + let Some(func) = CURRENT_ASYNC_LIBRARY.get_or_try_init(py, import)? else { + return Ok(AsyncLib::Asyncio); + }; + match func.bind(py).call0()?.extract()? { + "asyncio" => Ok(AsyncLib::Asyncio), + "trio" => Ok(AsyncLib::Trio), + rt => Err(PyRuntimeError::new_err(format!("unsupported runtime {rt}"))), + } +} + +/// Sniffio/anyio-compatible coroutine waker. +/// +/// Polling a Rust future calls `sniffio.current_async_library` to select the appropriate +/// implementation, either asyncio or trio. +pub(super) enum AnyioWaker { + /// [`AsyncioWaker`] + Asyncio(AsyncioWaker), + /// [`TrioWaker`] + Trio(TrioWaker), +} + +impl AnyioWaker { + pub(super) fn new(py: Python<'_>) -> PyResult { + match current_async_library(py)? { + AsyncLib::Asyncio => Ok(Self::Asyncio(AsyncioWaker::new(py)?)), + AsyncLib::Trio => Ok(Self::Trio(TrioWaker::new(py)?)), + } + } + + pub(super) fn yield_(&self, py: Python<'_>) -> PyResult { + match self { + AnyioWaker::Asyncio(w) => w.yield_(py), + AnyioWaker::Trio(w) => w.yield_(py), + } + } + + pub(super) fn yield_waken(py: Python<'_>) -> PyResult { + match current_async_library(py)? { + AsyncLib::Asyncio => AsyncioWaker::yield_waken(py), + AsyncLib::Trio => TrioWaker::yield_waken(py), + } + } + + pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { + match self { + AnyioWaker::Asyncio(w) => w.wake(py), + AnyioWaker::Trio(w) => w.wake(py), + } + } +} diff --git a/src/coroutine/asyncio.rs b/src/coroutine/asyncio.rs new file mode 100644 index 00000000000..80ac4f7f95f --- /dev/null +++ b/src/coroutine/asyncio.rs @@ -0,0 +1,95 @@ +//! Coroutine implementation compatible with asyncio. +use pyo3_macros::pyfunction; + +use crate::{ + intern, + sync::GILOnceCell, + types::{PyAnyMethods, PyCFunction, PyIterator}, + wrap_pyfunction, Bound, IntoPyObjectExt, Py, PyAny, PyObject, PyResult, Python, +}; + +/// `asyncio.get_running_loop` +fn get_running_loop(py: Python<'_>) -> PyResult> { + static GET_RUNNING_LOOP: GILOnceCell = GILOnceCell::new(); + let import = || -> PyResult<_> { + let module = py.import("asyncio")?; + Ok(module.getattr("get_running_loop")?.into()) + }; + GET_RUNNING_LOOP + .get_or_try_init(py, import)? + .bind(py) + .call0() +} + +/// Asyncio-compatible coroutine waker. +/// +/// Polling a Rust future yields an `asyncio.Future`, whose `set_result` method is called +/// when `Waker::wake` is called. +pub(super) struct AsyncioWaker { + event_loop: PyObject, + future: PyObject, +} + +impl AsyncioWaker { + pub(super) fn new(py: Python<'_>) -> PyResult { + let event_loop = get_running_loop(py)?.into_py_any(py)?; + let future = event_loop.call_method0(py, "create_future")?; + Ok(Self { event_loop, future }) + } + + pub(super) fn yield_(&self, py: Python<'_>) -> PyResult { + let __await__; + // `asyncio.Future` must be awaited; in normal case, it implements `__iter__ = __await__`, + // but `create_future` may have been overriden + let mut iter = match PyIterator::from_object(self.future.bind(py)) { + Ok(iter) => iter, + Err(_) => { + __await__ = self.future.call_method0(py, intern!(py, "__await__"))?; + PyIterator::from_object(__await__.bind(py))? + } + }; + // future has not been wakened (because `yield_waken` would have been called + // otherwise), so it is expected to yield itself + Ok(iter.next().expect("future didn't yield")?.into_py_any(py)?) + } + + #[allow(clippy::unnecessary_wraps)] + pub(super) fn yield_waken(py: Python<'_>) -> PyResult { + Ok(py.None()) + } + + pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { + static RELEASE_WAITER: GILOnceCell> = GILOnceCell::new(); + let release_waiter = RELEASE_WAITER + .get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?; + // `Future.set_result` must be called in event loop thread, + // so it requires `call_soon_threadsafe` + let call_soon_threadsafe = self.event_loop.call_method1( + py, + intern!(py, "call_soon_threadsafe"), + (release_waiter, &self.future), + ); + if let Err(err) = call_soon_threadsafe { + // `call_soon_threadsafe` will raise if the event loop is closed; + // instead of catching an unspecific `RuntimeError`, check directly if it's closed. + let is_closed = self.event_loop.call_method0(py, "is_closed")?; + if !is_closed.extract(py)? { + return Err(err); + } + } + Ok(()) + } +} + +/// Call `future.set_result` if the future is not done. +/// +/// Future can be cancelled by the event loop before being wakened. +/// See +#[pyfunction(crate = "crate")] +fn release_waiter(future: &Bound<'_, PyAny>) -> PyResult<()> { + let done = future.call_method0(intern!(future.py(), "done"))?; + if !done.extract::()? { + future.call_method1(intern!(future.py(), "set_result"), (future.py().None(),))?; + } + Ok(()) +} diff --git a/src/coroutine/awaitable.rs b/src/coroutine/awaitable.rs new file mode 100644 index 00000000000..e1be029c146 --- /dev/null +++ b/src/coroutine/awaitable.rs @@ -0,0 +1,151 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use super::waker::try_delegate; +use crate::{ + coroutine::CoroOp, + exceptions::{PyAttributeError, PyTypeError}, + intern, + sync::GILOnceCell, + types::{PyAnyMethods, PyIterator, PyTypeMethods}, + Bound, PyAny, PyErr, PyObject, PyResult, Python, +}; + +const NOT_IN_COROUTINE_CONTEXT: &str = "Python awaitable must be awaited in coroutine context"; + +fn is_awaitable(obj: &Bound<'_, PyAny>) -> PyResult { + static IS_AWAITABLE: GILOnceCell = GILOnceCell::new(); + let import = || PyResult::Ok(obj.py().import("inspect")?.getattr("isawaitable")?.into()); + IS_AWAITABLE + .get_or_try_init(obj.py(), import)? + .call1(obj.py(), (obj,))? + .extract(obj.py()) +} + +pub(crate) enum YieldOrReturn { + Return(PyObject), + Yield(PyObject), +} + +pub(crate) fn delegate( + py: Python<'_>, + await_impl: PyObject, + op: &CoroOp, +) -> PyResult { + match op { + CoroOp::Send(obj) => { + cfg_if::cfg_if! { + if #[cfg(all(Py_3_10, not(PyPy), not(Py_LIMITED_API)))] { + let mut result = std::ptr::null_mut(); + match unsafe { crate::ffi::PyIter_Send(await_impl.as_ptr(), obj.as_ptr(), &mut result) } + { + -1 => Err(PyErr::take(py).unwrap()), + 0 => Ok(YieldOrReturn::Return(unsafe { + PyObject::from_owned_ptr(py, result) + })), + 1 => Ok(YieldOrReturn::Yield(unsafe { + PyObject::from_owned_ptr(py, result) + })), + _ => unreachable!(), + } + } else { + let send = intern!(py, "send"); + if obj.is_none(py) || !await_impl.bind(py).hasattr(send).unwrap_or(false) { + await_impl.call_method0(py, intern!(py, "__next__")) + } else { + await_impl.call_method1(py, send, (obj,)) + } + .map(YieldOrReturn::Yield) + } + } + } + CoroOp::Throw(exc) => { + let throw = intern!(py, "throw"); + if await_impl.bind(py).hasattr(throw).unwrap_or(false) { + await_impl + .call_method1(py, throw, (exc,)) + .map(YieldOrReturn::Yield) + } else { + Err(PyErr::from_value(exc.bind(py).clone())) + } + } + CoroOp::Close => { + let close = intern!(py, "close"); + if await_impl.bind(py).hasattr(close).unwrap_or(false) { + await_impl + .call_method0(py, close) + .map(YieldOrReturn::Return) + } else { + Ok(YieldOrReturn::Return(py.None())) + } + } + } +} + +struct AwaitImpl(PyObject); + +impl AwaitImpl { + fn new(obj: &Bound<'_, PyAny>) -> PyResult { + let __await__ = intern!(obj.py(), "__await__"); + match obj.call_method0(__await__) { + Ok(iter) => Ok(Self(iter.unbind())), + Err(err) if err.is_instance_of::(obj.py()) => { + if obj.hasattr(__await__)? { + Err(err) + } else if is_awaitable(obj)? { + Ok(Self(PyIterator::from_object(obj)?.unbind().into_any())) + } else { + Err(PyTypeError::new_err(format!( + "object {tp} can't be used in 'await' expression", + tp = obj.get_type().name()? + ))) + } + } + Err(err) => Err(err), + } + } +} + +impl Future for AwaitImpl { + type Output = PyResult; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match try_delegate(cx.waker(), Python::with_gil(|py| self.0.clone_ref(py))) { + Some(poll) => poll, + None => panic!("{}", NOT_IN_COROUTINE_CONTEXT), + } + } +} + +/// Allows awaiting arbitrary Python awaitable inside PyO3 coroutine context, e.g. async pyfunction. +/// +/// Awaiting the resulting future will panic if it's not done in coroutine context. +/// However, the future can be instantiated outside of coroutine context. +/// +/// ```rust +/// use pyo3::{coroutine::await_in_coroutine, prelude::*, py_run, wrap_pyfunction}; +/// +/// # fn main() { +/// #[pyfunction] +/// async fn wrap_awaitable(awaitable: PyObject) -> PyResult { +/// let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; +/// future.await +/// } +/// Python::with_gil(|py| { +/// let wrap_awaitable = wrap_pyfunction!(wrap_awaitable, py).unwrap(); +/// let test = r#" +/// import asyncio +/// assert asyncio.run(wrap_awaitable(asyncio.sleep(1, result=42))) == 42 +/// "#; +/// py_run!(py, wrap_awaitable, test); +/// }) +/// # } +/// ``` +pub fn await_in_coroutine( + obj: &Bound<'_, PyAny>, +) -> PyResult> + Send + Sync + 'static> { + AwaitImpl::new(obj) +} diff --git a/src/coroutine/trio.rs b/src/coroutine/trio.rs new file mode 100644 index 00000000000..725e3ba56ef --- /dev/null +++ b/src/coroutine/trio.rs @@ -0,0 +1,88 @@ +//! Coroutine implementation compatible with trio. +use pyo3_macros::pyfunction; + +use crate::{ + intern, + sync::GILOnceCell, + types::{PyAnyMethods, PyCFunction, PyIterator}, + wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python, +}; + +struct Trio { + cancel_shielded_checkpoint: PyObject, + current_task: PyObject, + current_trio_token: PyObject, + reschedule: PyObject, + succeeded: PyObject, + wait_task_rescheduled: PyObject, +} +impl Trio { + fn get(py: Python<'_>) -> PyResult<&Self> { + static TRIO: GILOnceCell = GILOnceCell::new(); + TRIO.get_or_try_init(py, || { + let module = py.import("trio.lowlevel")?; + Ok(Self { + cancel_shielded_checkpoint: module.getattr("cancel_shielded_checkpoint")?.into(), + current_task: module.getattr("current_task")?.into(), + current_trio_token: module.getattr("current_trio_token")?.into(), + reschedule: module.getattr("reschedule")?.into(), + succeeded: module.getattr("Abort")?.getattr("SUCCEEDED")?.into(), + wait_task_rescheduled: module.getattr("wait_task_rescheduled")?.into(), + }) + }) + } +} + +fn yield_from(coro_func: &Bound<'_, PyAny>) -> PyResult { + PyIterator::from_object(&coro_func.call_method0("__await__")?)? + .next() + .expect("cancel_shielded_checkpoint didn't yield") + .map(Into::into) +} + +/// Asyncio-compatible coroutine waker. +/// +/// Polling a Rust future yields `trio.lowlevel.wait_task_rescheduled()`, while `Waker::wake` +/// reschedule the current task. +pub(super) struct TrioWaker { + task: PyObject, + token: PyObject, +} + +impl TrioWaker { + pub(super) fn new(py: Python<'_>) -> PyResult { + let trio = Trio::get(py)?; + let task = trio.current_task.call0(py)?; + let token = trio.current_trio_token.call0(py)?; + Ok(Self { task, token }) + } + + pub(super) fn yield_(&self, py: Python<'_>) -> PyResult { + static ABORT_FUNC: GILOnceCell> = GILOnceCell::new(); + let abort_func = + ABORT_FUNC.get_or_try_init(py, || wrap_pyfunction!(abort_func, py).map(Into::into))?; + let wait_task_rescheduled = Trio::get(py)? + .wait_task_rescheduled + .call1(py, (abort_func,))?; + yield_from(wait_task_rescheduled.bind(py)) + } + + pub(super) fn yield_waken(py: Python<'_>) -> PyResult { + let checkpoint = Trio::get(py)?.cancel_shielded_checkpoint.call0(py)?; + yield_from(checkpoint.bind(py)) + } + + pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { + self.token.call_method1( + py, + intern!(py, "run_sync_soon"), + (&Trio::get(py)?.reschedule, &self.task), + )?; + Ok(()) + } +} + +#[pyfunction(crate = "crate")] +fn abort_func(py: Python<'_>, _arg: &Bound<'_, PyAny>) -> PyResult { + Ok(Trio::get(py)?.succeeded.clone_ref(py)) +} diff --git a/src/coroutine/waker.rs b/src/coroutine/waker.rs index 1600f56d9c6..4132eeb76af 100644 --- a/src/coroutine/waker.rs +++ b/src/coroutine/waker.rs @@ -1,106 +1,126 @@ -use crate::sync::GILOnceCell; -use crate::types::any::PyAnyMethods; -use crate::types::PyCFunction; -use crate::{intern, wrap_pyfunction, Bound, Py, PyAny, PyObject, PyResult, Python}; -use pyo3_macros::pyfunction; -use std::sync::Arc; -use std::task::Wake; +use std::{ + cell::Cell, + sync::Arc, + task::{Poll, Wake, Waker}, +}; -/// Lazy `asyncio.Future` wrapper, implementing [`Wake`] by calling `Future.set_result`. -/// -/// asyncio future is let uninitialized until [`initialize_future`][1] is called. -/// If [`wake`][2] is called before future initialization (during Rust future polling), -/// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`) -/// -/// [1]: AsyncioWaker::initialize_future -/// [2]: AsyncioWaker::wake -pub struct AsyncioWaker(GILOnceCell>); +use crate::{ + coroutine::{ + awaitable::{delegate, YieldOrReturn}, + CoroOp, + }, + exceptions::PyStopIteration, + intern, + sync::GILOnceCell, + types::PyAnyMethods, + Bound, PyObject, PyResult, Python, +}; -impl AsyncioWaker { - pub(super) fn new() -> Self { - Self(GILOnceCell::new()) +cfg_if::cfg_if! { + if #[cfg(feature = "anyio")] { + type WakerImpl = crate::coroutine::anyio::AnyioWaker; + } else { + type WakerImpl = crate::coroutine::asyncio::AsyncioWaker; + } +} + +const MIXED_AWAITABLE_AND_FUTURE_ERROR: &str = "Python awaitable mixed with Rust future"; + +enum State { + Pending(WakerImpl), + Waken, + Delegated(PyObject), +} + +pub(super) struct CoroutineWaker { + state: GILOnceCell, + op: CoroOp, +} + +impl CoroutineWaker { + pub(super) fn new(op: CoroOp) -> Self { + Self { + state: GILOnceCell::new(), + op, + } + } + + pub(super) fn reset(&mut self, op: CoroOp) { + self.state.take(); + self.op = op; } - pub(super) fn reset(&mut self) { - self.0.take(); + pub(super) fn is_delegated(&self, py: Python<'_>) -> bool { + matches!(self.state.get(py), Some(State::Delegated(_))) } - pub(super) fn initialize_future<'py>( - &self, - py: Python<'py>, - ) -> PyResult>> { - let init = || LoopAndFuture::new(py).map(Some); - let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref(); - Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.bind(py))) + pub(super) fn yield_(&self, py: Python<'_>) -> PyResult { + let init = || PyResult::Ok(State::Pending(WakerImpl::new(py)?)); + let state = self.state.get_or_try_init(py, init)?; + match state { + State::Pending(waker) => waker.yield_(py), + State::Waken => WakerImpl::yield_waken(py), + State::Delegated(obj) => Ok(obj.clone_ref(py)), + } + } + + fn delegate(&self, py: Python<'_>, await_impl: PyObject) -> Poll> { + match delegate(py, await_impl, &self.op) { + Ok(YieldOrReturn::Yield(obj)) => { + let delegated = self.state.set(py, State::Delegated(obj)); + assert!(delegated.is_ok(), "{}", MIXED_AWAITABLE_AND_FUTURE_ERROR); + Poll::Pending + } + Ok(YieldOrReturn::Return(obj)) => Poll::Ready(Ok(obj)), + Err(err) if err.is_instance_of::(py) => Poll::Ready( + err.value(py) + .getattr(intern!(py, "value")) + .map(Bound::unbind), + ), + Err(err) => Poll::Ready(Err(err)), + } } } -impl Wake for AsyncioWaker { +impl Wake for CoroutineWaker { fn wake(self: Arc) { self.wake_by_ref() } fn wake_by_ref(self: &Arc) { - Python::with_gil(|gil| { - if let Some(loop_and_future) = self.0.get_or_init(gil, || None) { - loop_and_future - .set_result(gil) - .expect("unexpected error in coroutine waker"); - } - }); + Python::with_gil(|gil| match WAKER_HACK.with(|cell| cell.take()) { + Some(WakerHack::Argument(await_impl)) => WAKER_HACK.with(|cell| { + let res = self.delegate(gil, await_impl); + cell.set(Some(WakerHack::Result(res))) + }), + Some(WakerHack::Result(_)) => unreachable!(), + None => match self.state.get_or_init(gil, || State::Waken) { + State::Pending(waker) => waker.wake(gil).expect("wake error"), + State::Waken => {} + State::Delegated(_) => panic!("{}", MIXED_AWAITABLE_AND_FUTURE_ERROR), + }, + }) } } -struct LoopAndFuture { - event_loop: PyObject, - future: PyObject, +enum WakerHack { + Argument(PyObject), + Result(Poll>), } -impl LoopAndFuture { - fn new(py: Python<'_>) -> PyResult { - static GET_RUNNING_LOOP: GILOnceCell = GILOnceCell::new(); - let import = || -> PyResult<_> { - let module = py.import("asyncio")?; - Ok(module.getattr("get_running_loop")?.into()) - }; - let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?; - let future = event_loop.call_method0(py, "create_future")?; - Ok(Self { event_loop, future }) - } - - fn set_result(&self, py: Python<'_>) -> PyResult<()> { - static RELEASE_WAITER: GILOnceCell> = GILOnceCell::new(); - let release_waiter = RELEASE_WAITER.get_or_try_init(py, || { - wrap_pyfunction!(release_waiter, py).map(Bound::unbind) - })?; - // `Future.set_result` must be called in event loop thread, - // so it requires `call_soon_threadsafe` - let call_soon_threadsafe = self.event_loop.call_method1( - py, - intern!(py, "call_soon_threadsafe"), - (release_waiter, self.future.bind(py)), - ); - if let Err(err) = call_soon_threadsafe { - // `call_soon_threadsafe` will raise if the event loop is closed; - // instead of catching an unspecific `RuntimeError`, check directly if it's closed. - let is_closed = self.event_loop.call_method0(py, "is_closed")?; - if !is_closed.extract(py)? { - return Err(err); - } - } - Ok(()) - } +thread_local! { + static WAKER_HACK: Cell> = Cell::new(None); } -/// Call `future.set_result` if the future is not done. -/// -/// Future can be cancelled by the event loop before being waken. -/// See -#[pyfunction(crate = "crate")] -fn release_waiter(future: &Bound<'_, PyAny>) -> PyResult<()> { - let done = future.call_method0(intern!(future.py(), "done"))?; - if !done.extract::()? { - future.call_method1(intern!(future.py(), "set_result"), (future.py().None(),))?; +pub(crate) fn try_delegate( + waker: &Waker, + await_impl: PyObject, +) -> Option>> { + WAKER_HACK.with(|cell| cell.set(Some(WakerHack::Argument(await_impl)))); + waker.wake_by_ref(); + match WAKER_HACK.with(|cell| cell.take()) { + Some(WakerHack::Result(poll)) => Some(poll), + Some(WakerHack::Argument(_)) => None, + None => unreachable!(), } - Ok(()) } diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index f893a2c2fe9..e2a30e331c9 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - coroutine::{cancel::ThrowCallback, Coroutine}, + coroutine::{Coroutine, ThrowCallback}, instance::Bound, pycell::impl_::PyClassBorrowChecker, pyclass::boolean_struct::False, diff --git a/src/lib.rs b/src/lib.rs index e5146c81c00..39f8b45b79c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -333,21 +333,23 @@ //! [Rust from Python]: https://github.com/PyO3/pyo3#using-rust-from-python #![doc = concat!("[Features chapter of the guide]: https://pyo3.rs/v", env!("CARGO_PKG_VERSION"), "/features.html#features-reference \"Features Reference - PyO3 user guide\"")] //! [`Ungil`]: crate::marker::Ungil -pub use crate::class::*; -pub use crate::conversion::{AsPyPointer, FromPyObject, IntoPyObject, IntoPyObjectExt}; #[allow(deprecated)] pub use crate::conversion::{IntoPy, ToPyObject}; -pub use crate::err::{DowncastError, DowncastIntoError, PyErr, PyErrArguments, PyResult, ToPyErr}; #[cfg(not(any(PyPy, GraalPy)))] pub use crate::gil::{prepare_freethreaded_python, with_embedded_python_interpreter}; -pub use crate::instance::{Borrowed, Bound, BoundObject, Py, PyObject}; -pub use crate::marker::Python; -pub use crate::pycell::{PyRef, PyRefMut}; -pub use crate::pyclass::PyClass; -pub use crate::pyclass_init::PyClassInitializer; -pub use crate::type_object::{PyTypeCheck, PyTypeInfo}; -pub use crate::types::PyAny; -pub use crate::version::PythonVersionInfo; +pub use crate::{ + class::*, + conversion::{AsPyPointer, FromPyObject, IntoPyObject, IntoPyObjectExt}, + err::{DowncastError, DowncastIntoError, PyErr, PyErrArguments, PyResult, ToPyErr}, + instance::{Borrowed, Bound, BoundObject, Py, PyObject}, + marker::Python, + pycell::{PyRef, PyRefMut}, + pyclass::PyClass, + pyclass_init::PyClassInitializer, + type_object::{PyTypeCheck, PyTypeInfo}, + types::PyAny, + version::PythonVersionInfo, +}; pub(crate) mod ffi_ptr_ext; pub(crate) mod py_result_ext; @@ -361,9 +363,10 @@ pub(crate) mod sealed; /// For compatibility reasons this has not yet been removed, however will be done so /// once is resolved. pub mod class { - pub use self::gc::{PyTraverseError, PyVisit}; - - pub use self::methods::*; + pub use self::{ + gc::{PyTraverseError, PyVisit}, + methods::*, + }; #[doc(hidden)] pub mod methods { @@ -406,16 +409,15 @@ pub mod class { } } +#[cfg(all(feature = "macros", feature = "multiple-pymethods"))] +#[doc(hidden)] +pub use inventory; #[cfg(feature = "macros")] #[doc(hidden)] pub use { indoc, // Re-exported for py_run unindent, // Re-exported for py_run -}; - -#[cfg(all(feature = "macros", feature = "multiple-pymethods"))] -#[doc(hidden)] -pub use inventory; // Re-exported for `#[pyclass]` and `#[pymethods]` with `multiple-pymethods`. +}; // Re-exported for `#[pyclass]` and `#[pymethods]` with `multiple-pymethods`. /// Tests and helpers which reside inside PyO3's main library. Declared first so that macros /// are available in unit tests. @@ -454,14 +456,6 @@ pub mod type_object; pub mod types; mod version; -#[allow(unused_imports)] // with no features enabled this module has no public exports -pub use crate::conversions::*; - -#[cfg(feature = "macros")] -pub use pyo3_macros::{ - pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject, IntoPyObjectRef, -}; - /// A proc macro used to expose Rust structs and fieldless enums as Python objects. /// #[doc = include_str!("../guide/pyclass-parameters.md")] @@ -472,6 +466,13 @@ pub use pyo3_macros::{ #[doc = concat!("[1]: https://pyo3.rs/v", env!("CARGO_PKG_VERSION"), "/class.html")] #[cfg(feature = "macros")] pub use pyo3_macros::pyclass; +#[cfg(feature = "macros")] +pub use pyo3_macros::{ + pyfunction, pymethods, pymodule, FromPyObject, IntoPyObject, IntoPyObjectRef, +}; + +#[allow(unused_imports)] // with no features enabled this module has no public exports +pub use crate::conversions::*; #[cfg(feature = "macros")] #[macro_use] @@ -500,6 +501,7 @@ pub mod doc_test { "README.md" => readme_md, "guide/src/advanced.md" => guide_advanced_md, "guide/src/async-await.md" => guide_async_await_md, + "guide/src/async-await/awaiting_python_awaitables.md" => guide_async_await_awaiting_python_awaitable_md, "guide/src/building-and-distribution.md" => guide_building_and_distribution_md, "guide/src/building-and-distribution/multiple-python-versions.md" => guide_bnd_multiple_python_versions_md, "guide/src/class.md" => guide_class_md, diff --git a/src/tests/common.rs b/src/tests/common.rs index cd0374e9019..6bb4d9cd2e8 100644 --- a/src/tests/common.rs +++ b/src/tests/common.rs @@ -175,6 +175,17 @@ mod inner { let uuid = Uuid::new_v4().simple().to_string(); std::ffi::CString::new(format!("{base}_{uuid}")).unwrap() } + + // see https://stackoverflow.com/questions/60359157/valueerror-set-wakeup-fd-only-works-in-main-thread-on-windows-on-python-3-8-wit + #[cfg(feature = "macros")] + pub fn asyncio_windows(test: &str) -> String { + let set_event_loop_policy = r#" + import asyncio, sys + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + "#; + pyo3::unindent::unindent(set_event_loop_policy) + &pyo3::unindent::unindent(test) + } } #[allow(unused_imports)] // some tests use just the macros and none of the other functionality diff --git a/tests/test_await_in_coroutine.rs b/tests/test_await_in_coroutine.rs new file mode 100644 index 00000000000..5fe40194ecc --- /dev/null +++ b/tests/test_await_in_coroutine.rs @@ -0,0 +1,189 @@ +#![cfg(feature = "experimental-async")] + +use std::{ffi::CString, task::Poll}; + +use futures::{future::poll_fn, FutureExt}; +use pyo3::{ + coroutine::{await_in_coroutine, CancelHandle}, + exceptions::{PyAttributeError, PyTypeError}, + prelude::*, + py_run, +}; + +#[path = "../src/tests/common.rs"] +mod common; + +#[pyfunction] +async fn wrap_awaitable(awaitable: PyObject) -> PyResult { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; + future.await +} + +#[test] +fn awaitable() { + Python::with_gil(|gil| { + let wrap_awaitable = wrap_pyfunction!(wrap_awaitable, gil).unwrap(); + let test = r#" + import types + import asyncio; + + class BadAwaitable: + def __await__(self): + raise AttributeError("__await__") + + @types.coroutine + def gen_coro(): + yield None + + async def main(): + await wrap_awaitable(...) + asyncio.run(main()) + "#; + let globals = gil.import("__main__").unwrap().dict(); + globals.set_item("wrap_awaitable", wrap_awaitable).unwrap(); + let run = |awaitable| { + gil.run( + &CString::new(common::asyncio_windows(test).replace("...", awaitable)).unwrap(), + Some(&globals), + None, + ) + }; + run("asyncio.sleep(0.001)").unwrap(); + run("gen_coro()").unwrap(); + assert!(run("None").unwrap_err().is_instance_of::(gil)); + assert!(run("BadAwaitable()") + .unwrap_err() + .is_instance_of::(gil)); + }) +} + +#[test] +fn cancel_delegation() { + #[pyfunction] + async fn wrap_cancellable(awaitable: PyObject, #[pyo3(cancel_handle)] cancel: CancelHandle) { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil))).unwrap(); + let result = future.await; + Python::with_gil(|gil| { + assert_eq!( + result.unwrap_err().get_type(gil).name().unwrap(), + "CancelledError" + ) + }); + assert!(!cancel.is_cancelled()); + } + Python::with_gil(|gil| { + let wrap_cancellable = wrap_pyfunction!(wrap_cancellable, gil).unwrap(); + let test = r#" + import asyncio; + + async def main(): + task = asyncio.create_task(wrap_cancellable(asyncio.sleep(0.001))) + await asyncio.sleep(0) + task.cancel() + await task + asyncio.run(main()) + "#; + let globals = gil.import("__main__").unwrap().dict(); + globals + .set_item("wrap_cancellable", wrap_cancellable) + .unwrap(); + gil.run( + &CString::new(common::asyncio_windows(test)).unwrap(), + Some(&globals), + None, + ) + .unwrap(); + }) +} + +#[test] +#[should_panic(expected = "Python awaitable must be awaited in coroutine context")] +fn awaitable_without_coroutine() { + #[pyfunction] + fn block_on(awaitable: PyObject) -> PyResult { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; + futures::executor::block_on(future) + } + Python::with_gil(|gil| { + let block_on = wrap_pyfunction!(block_on, gil).unwrap(); + let test = r#" + async def coro(): + ... + block_on(coro()) + "#; + py_run!(gil, block_on, &common::asyncio_windows(test)); + }) +} + +async fn checkpoint() { + let mut ready = false; + poll_fn(|cx| { + if ready { + return Poll::Ready(()); + } + ready = true; + cx.waker().wake_by_ref(); + Poll::Pending + }) + .await +} + +#[test] +#[should_panic(expected = "Python awaitable mixed with Rust future")] +fn awaitable_in_select() { + #[pyfunction] + async fn select(awaitable: PyObject) -> PyResult { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; + futures::select_biased! { + _ = checkpoint().fuse() => unreachable!(), + res = future.fuse() => res, + } + } + Python::with_gil(|gil| { + let select = wrap_pyfunction!(select, gil).unwrap(); + let test = r#" + import asyncio; + async def main(): + return await select(asyncio.sleep(1)) + asyncio.run(main()) + "#; + let globals = gil.import("__main__").unwrap().dict(); + globals.set_item("select", select).unwrap(); + gil.run( + &CString::new(common::asyncio_windows(test)).unwrap(), + Some(&globals), + None, + ) + .unwrap(); + }) +} + +#[test] +#[should_panic(expected = "Python awaitable mixed with Rust future")] +fn awaitable_in_select2() { + #[pyfunction] + async fn select2(awaitable: PyObject) -> PyResult { + let future = Python::with_gil(|gil| await_in_coroutine(awaitable.bind(gil)))?; + futures::select_biased! { + res = future.fuse() => res, + _ = checkpoint().fuse() => unreachable!(), + } + } + Python::with_gil(|gil| { + let select2 = wrap_pyfunction!(select2, gil).unwrap(); + let test = r#" + import asyncio; + async def main(): + return await select2(asyncio.sleep(1)) + asyncio.run(main()) + "#; + let globals = gil.import("__main__").unwrap().dict(); + globals.set_item("select2", select2).unwrap(); + gil.run( + &CString::new(common::asyncio_windows(test)).unwrap(), + Some(&globals), + None, + ) + .unwrap(); + }) +} diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 89ab3d64a4b..a4352c51020 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -1,5 +1,7 @@ #![cfg(feature = "experimental-async")] #![cfg(not(target_arch = "wasm32"))] +#[cfg(target_has_atomic = "64")] +use std::sync::atomic::{AtomicBool, Ordering}; use std::{ffi::CString, task::Poll, thread, time::Duration}; use futures::{channel::oneshot, future::poll_fn, FutureExt}; @@ -11,21 +13,10 @@ use pyo3::{ py_run, types::{IntoPyDict, PyType}, }; -#[cfg(target_has_atomic = "64")] -use std::sync::atomic::{AtomicBool, Ordering}; #[path = "../src/tests/common.rs"] mod common; -fn handle_windows(test: &str) -> String { - let set_event_loop_policy = r#" - import asyncio, sys - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - "#; - pyo3::unindent::unindent(set_event_loop_policy) + &pyo3::unindent::unindent(test) -} - #[test] fn noop_coroutine() { #[pyfunction] @@ -35,7 +26,7 @@ fn noop_coroutine() { Python::with_gil(|gil| { let noop = wrap_pyfunction!(noop, gil).unwrap(); let test = "import asyncio; assert asyncio.run(noop()) == 42"; - py_run!(gil, noop, &handle_windows(test)); + py_run!(gil, noop, &common::asyncio_windows(test)); }) } @@ -74,7 +65,7 @@ fn test_coroutine_qualname() { ] .into_py_dict(gil) .unwrap(); - py_run!(gil, *locals, &handle_windows(test)); + py_run!(gil, *locals, &common::asyncio_windows(test)); }) } @@ -96,7 +87,7 @@ fn sleep_0_like_coroutine() { Python::with_gil(|gil| { let sleep_0 = wrap_pyfunction!(sleep_0, gil).unwrap(); let test = "import asyncio; assert asyncio.run(sleep_0()) == 42"; - py_run!(gil, sleep_0, &handle_windows(test)); + py_run!(gil, sleep_0, &common::asyncio_windows(test)); }) } @@ -115,7 +106,7 @@ fn sleep_coroutine() { Python::with_gil(|gil| { let sleep = wrap_pyfunction!(sleep, gil).unwrap(); let test = r#"import asyncio; assert asyncio.run(sleep(0.1)) == 42"#; - py_run!(gil, sleep, &handle_windows(test)); + py_run!(gil, sleep, &common::asyncio_windows(test)); }) } @@ -129,7 +120,7 @@ fn tuple_coroutine() { Python::with_gil(|gil| { let func = wrap_pyfunction!(return_tuple, gil).unwrap(); let test = r#"import asyncio; assert asyncio.run(func()) == (42, 43)"#; - py_run!(gil, func, &handle_windows(test)); + py_run!(gil, func, &common::asyncio_windows(test)); }) } @@ -150,7 +141,7 @@ fn cancelled_coroutine() { globals.set_item("sleep", sleep).unwrap(); let err = gil .run( - &CString::new(pyo3::unindent::unindent(&handle_windows(test))).unwrap(), + &CString::new(common::asyncio_windows(test)).unwrap(), Some(&globals), None, ) @@ -190,7 +181,7 @@ fn coroutine_cancel_handle() { .set_item("cancellable_sleep", cancellable_sleep) .unwrap(); gil.run( - &CString::new(pyo3::unindent::unindent(&handle_windows(test))).unwrap(), + &CString::new(common::asyncio_windows(test)).unwrap(), Some(&globals), None, ) @@ -220,7 +211,7 @@ fn coroutine_is_cancelled() { let globals = gil.import("__main__").unwrap().dict(); globals.set_item("sleep_loop", sleep_loop).unwrap(); gil.run( - &CString::new(pyo3::unindent::unindent(&handle_windows(test))).unwrap(), + &CString::new(common::asyncio_windows(test)).unwrap(), Some(&globals), None, ) @@ -253,7 +244,7 @@ fn coroutine_panic() { else: assert False; "#; - py_run!(gil, panic, &handle_windows(test)); + py_run!(gil, panic, &common::asyncio_windows(test)); }) } @@ -354,6 +345,6 @@ fn test_async_method_receiver_with_other_args() { let locals = [("Value", gil.get_type::())] .into_py_dict(gil) .unwrap(); - py_run!(gil, *locals, test); + py_run!(gil, *locals, &common::asyncio_windows(test)); }); }