diff --git a/README.md b/README.md index 5f68f51..ff8f95b 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ - **TLS/HTTPS**: OpenSSL-based with ALPN and certificate verification - **Header-Only Library** for easy integration - **Debugging Tools**: GDB/LLDB extensions and pstack-like CLI -- **Comprehensive Testing** with Catch2 and ASAN +- **Comprehensive Testing** with Catch2, ASAN, and TSAN - **Integrated Logging** with fmtlib - **CI/CD** with GitHub Actions @@ -59,6 +59,9 @@ ctest --output-on-failure # ASAN tests (memory safety) ./build/tests/elio_tests_asan + +# TSAN tests (thread safety) +./build/tests/elio_tests_tsan ``` ### Your First Coroutine @@ -179,6 +182,7 @@ outer() -> middle() -> inner() The scheduler manages a pool of worker threads, each with a local task queue. Key features: - **Lock-free operations**: Chase-Lev deques for optimal performance - **Work stealing**: Idle threads steal tasks from busy threads +- **Per-worker I/O context**: Each worker has its own io_uring/epoll backend for thread-safe I/O - **Dynamic sizing**: Adjust thread count at runtime - **Load balancing**: Automatic task distribution @@ -343,7 +347,8 @@ Elio includes comprehensive tests: - **Unit Tests**: Test each component in isolation - **Integration Tests**: Test components working together -- **ASAN Tests**: Detect memory errors +- **ASAN Tests**: Detect memory errors (use-after-free, buffer overflow, etc.) +- **TSAN Tests**: Detect data races and thread safety issues ```bash # Run all tests @@ -354,7 +359,8 @@ ctest --test-dir build --output-on-failure ./build/tests/elio_tests "[integration]" # Run with sanitizers -./build/tests/elio_tests_asan +./build/tests/elio_tests_asan # Memory safety +./build/tests/elio_tests_tsan # Thread safety ``` ## Performance diff --git a/examples/async_file_io.cpp b/examples/async_file_io.cpp index 7f8ff2c..e797300 100644 --- a/examples/async_file_io.cpp +++ b/examples/async_file_io.cpp @@ -22,8 +22,6 @@ using namespace elio::runtime; /// Async file copy using io_uring/epoll task async_copy_file(const std::string& src_path, const std::string& dst_path) { - auto& ctx = io::default_io_context(); - // Open source file int src_fd = open(src_path.c_str(), O_RDONLY); if (src_fd < 0) { @@ -62,7 +60,7 @@ task async_copy_file(const std::string& src_path, const std::string& dst_p size_t to_read = std::min(BUFFER_SIZE, file_size - total_copied); // Async read from source - auto read_result = co_await io::async_read(ctx, src_fd, buffer.data(), to_read, offset); + auto read_result = co_await io::async_read(src_fd, buffer.data(), to_read, offset); if (read_result.result <= 0) { if (read_result.result == 0) { @@ -77,7 +75,7 @@ task async_copy_file(const std::string& src_path, const std::string& dst_p size_t bytes_read = read_result.result; // Async write to destination - auto write_result = co_await io::async_write(ctx, dst_fd, buffer.data(), bytes_read, offset); + auto write_result = co_await io::async_write(dst_fd, buffer.data(), bytes_read, offset); if (write_result.result <= 0) { std::cerr << "Write error: " << strerror(-write_result.result) << std::endl; @@ -98,8 +96,8 @@ task async_copy_file(const std::string& src_path, const std::string& dst_p auto duration_ms = std::chrono::duration_cast(end - start).count(); // Close files - co_await io::async_close(ctx, src_fd); - co_await io::async_close(ctx, dst_fd); + co_await io::async_close(src_fd); + co_await io::async_close(dst_fd); std::cout << std::endl; std::cout << "Copy completed in " << duration_ms << " ms" << std::endl; 
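The hunk above shows the pattern that the rest of this patch repeats across the examples: drop the explicit `io_context&` argument (and the `io::default_io_context()` lookup) and let each awaitable locate the backend for the worker it is running on. As a quick illustration of the resulting call shape — a sketch only, not part of the patch; the include path and the `coro::task` alias are inferred from the repository layout, and `io_result::result` carries the byte count or negative errno as in the example code:

```cpp
// Sketch of the context-free API after this change (assumptions noted above).
#include <elio/io/io_awaitables.hpp>  // path inferred from include/elio/io/ in this patch
#include <fcntl.h>
#include <vector>

elio::coro::task<int> read_first_block(const char* path) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) co_return -1;

    std::vector<char> buf(4096);

    // No io_context parameter: when this coroutine runs on a scheduler worker,
    // the awaitable routes the request to that worker's io_uring/epoll backend
    // via io::current_io_context(); outside a worker it falls back to
    // default_io_context(), per the new helper introduced in io_awaitables.hpp.
    auto r = co_await elio::io::async_read(fd, buf.data(), buf.size(), /*offset=*/0);
    co_await elio::io::async_close(fd);

    // r.result is the number of bytes read, or a negative errno on failure.
    co_return static_cast<int>(r.result);
}
```

Because the backend is chosen at suspension time, the only setup the callers below still need is constructing and starting the `scheduler`; the removed `sched.set_io_context(&io::default_io_context())` lines and manual `ctx.poll(...)` loops become unnecessary.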
@@ -114,8 +112,6 @@ task async_copy_file(const std::string& src_path, const std::string& dst_p /// Concurrent file read demonstration task concurrent_read_demo(const std::vector& files) { - auto& ctx = io::default_io_context(); - std::cout << "Reading " << files.size() << " files concurrently..." << std::endl; struct FileInfo { @@ -148,7 +144,7 @@ task concurrent_read_demo(const std::vector& files) { // Read all files (would be more concurrent with multiple coroutines) size_t total_bytes = 0; for (auto& info : file_infos) { - auto result = co_await io::async_read(ctx, info.fd, info.buffer.data(), info.size, 0); + auto result = co_await io::async_read(info.fd, info.buffer.data(), info.size, 0); if (result.result > 0) { total_bytes += result.result; std::cout << " Read " << result.result << " bytes from " << info.path << std::endl; @@ -166,13 +162,11 @@ task concurrent_read_demo(const std::vector& files) { /// Benchmark async vs sync file I/O task benchmark_io(size_t file_size_mb) { - auto& ctx = io::default_io_context(); - const std::string test_file = "/tmp/elio_benchmark_test.dat"; size_t file_size = file_size_mb * 1024 * 1024; std::cout << "Benchmark: " << file_size_mb << " MB file" << std::endl; - std::cout << "I/O Backend: " << ctx.get_backend_name() << std::endl; + std::cout << "I/O Backend: " << io::current_io_context().get_backend_name() << std::endl; std::cout << std::endl; // Create test file @@ -201,7 +195,7 @@ task benchmark_io(size_t file_size_mb) { size_t total_read = 0; int64_t offset = 0; while (total_read < file_size) { - auto result = co_await io::async_read(ctx, fd, buffer.data(), BUFFER_SIZE, offset); + auto result = co_await io::async_read(fd, buffer.data(), BUFFER_SIZE, offset); if (result.result <= 0) break; total_read += result.result; offset += result.result; @@ -261,9 +255,6 @@ int main(int argc, char* argv[]) { scheduler sched(2); - // Set the I/O context so workers can poll for I/O completions - sched.set_io_context(&io::default_io_context()); - sched.start(); std::atomic done{false}; diff --git a/examples/http2_client.cpp b/examples/http2_client.cpp index fe9a00d..ec7d55d 100644 --- a/examples/http2_client.cpp +++ b/examples/http2_client.cpp @@ -151,7 +151,6 @@ int main(int argc, char* argv[]) { // Create scheduler scheduler sched(2); - sched.set_io_context(&io::default_io_context()); sched.start(); // Run appropriate mode diff --git a/examples/http_client.cpp b/examples/http_client.cpp index ac1c708..7b20316 100644 --- a/examples/http_client.cpp +++ b/examples/http_client.cpp @@ -236,7 +236,6 @@ int main(int argc, char* argv[]) { // Create scheduler scheduler sched(2); - sched.set_io_context(&io::default_io_context()); sched.start(); // Run appropriate mode diff --git a/examples/http_server.cpp b/examples/http_server.cpp index d1bb7d7..756b37d 100644 --- a/examples/http_server.cpp +++ b/examples/http_server.cpp @@ -258,7 +258,6 @@ int main(int argc, char* argv[]) { // Create scheduler scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); // Spawn signal handler coroutine diff --git a/examples/rpc_client_example.cpp b/examples/rpc_client_example.cpp index e2e73ba..c1c8ffb 100644 --- a/examples/rpc_client_example.cpp +++ b/examples/rpc_client_example.cpp @@ -336,11 +336,9 @@ task run_demo(tcp_rpc_client::ptr client) { } task client_main(const char* host, uint16_t port) { - auto& ctx = io::default_io_context(); - std::cout << "Connecting to " << host << ":" << port << "..." 
<< std::endl; - auto client = co_await tcp_rpc_client::connect(ctx, host, port); + auto client = co_await tcp_rpc_client::connect(host, port); if (!client) { std::cerr << "Failed to connect to server" << std::endl; co_return; @@ -368,7 +366,6 @@ int main(int argc, char* argv[]) { // Create and start scheduler scheduler sched(2); - sched.set_io_context(&io::default_io_context()); sched.start(); // Run client diff --git a/examples/rpc_server_example.cpp b/examples/rpc_server_example.cpp index 2bd9451..25bee20 100644 --- a/examples/rpc_server_example.cpp +++ b/examples/rpc_server_example.cpp @@ -318,7 +318,6 @@ int main(int argc, char* argv[]) { // Create and start scheduler scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); // Spawn signal handler coroutine diff --git a/examples/signal_handling.cpp b/examples/signal_handling.cpp index 3d34cce..dfba213 100644 --- a/examples/signal_handling.cpp +++ b/examples/signal_handling.cpp @@ -139,10 +139,6 @@ int main() { // Create scheduler with worker threads scheduler sched(4); - // Set up I/O context for async operations - io::io_context ctx; - sched.set_io_context(&ctx); - sched.start(); // Spawn main task @@ -151,18 +147,12 @@ int main() { // Run until shutdown while (g_running) { - ctx.poll(std::chrono::milliseconds(10)); std::this_thread::sleep_for(std::chrono::milliseconds(10)); } // Give coroutines time to clean up std::this_thread::sleep_for(std::chrono::milliseconds(200)); - // Poll any remaining I/O - for (int i = 0; i < 10 && ctx.has_pending(); ++i) { - ctx.poll(std::chrono::milliseconds(10)); - } - sched.shutdown(); ELIO_LOG_INFO("Application shutdown complete"); diff --git a/examples/sse_client.cpp b/examples/sse_client.cpp index 8aa3035..a3d26d2 100644 --- a/examples/sse_client.cpp +++ b/examples/sse_client.cpp @@ -235,7 +235,6 @@ int main(int argc, char* argv[]) { // Create scheduler runtime::scheduler sched(2); - sched.set_io_context(&io::default_io_context()); sched.start(); // Run client based on mode diff --git a/examples/sse_server.cpp b/examples/sse_server.cpp index 8f45f30..7ad7fe6 100644 --- a/examples/sse_server.cpp +++ b/examples/sse_server.cpp @@ -366,7 +366,6 @@ int main(int argc, char* argv[]) { // Create scheduler runtime::scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); // Spawn signal handler coroutine diff --git a/examples/tcp_echo_client.cpp b/examples/tcp_echo_client.cpp index cb94ac9..7cd3471 100644 --- a/examples/tcp_echo_client.cpp +++ b/examples/tcp_echo_client.cpp @@ -28,13 +28,10 @@ using namespace elio::net; /// Client coroutine - connects, sends messages, receives responses task client_main(std::string_view host, uint16_t port) { - // Use the default io_context which is polled by scheduler workers - auto& ctx = io::default_io_context(); - ELIO_LOG_INFO("Connecting to {}:{}...", host, port); // Connect to server - auto stream_result = co_await tcp_connect(ctx, host, port); + auto stream_result = co_await tcp_connect(host, port); if (!stream_result) { ELIO_LOG_ERROR("Connection failed: {}", strerror(errno)); @@ -81,12 +78,9 @@ task client_main(std::string_view host, uint16_t port) { /// Non-interactive benchmark mode task benchmark_main(std::string_view host, uint16_t port, int iterations) { - // Use the default io_context which is polled by scheduler workers - auto& ctx = io::default_io_context(); - ELIO_LOG_INFO("Connecting to {}:{} for benchmark...", host, port); - auto stream_result = co_await tcp_connect(ctx, host, port); + auto 
stream_result = co_await tcp_connect(host, port); if (!stream_result) { ELIO_LOG_ERROR("Connection failed: {}", strerror(errno)); co_return 1; @@ -189,9 +183,6 @@ int main(int argc, char* argv[]) { // Create scheduler scheduler sched(2); - // Set the I/O context so workers can poll for I/O completions - sched.set_io_context(&io::default_io_context()); - sched.start(); int result = 0; diff --git a/examples/tcp_echo_server.cpp b/examples/tcp_echo_server.cpp index 5f0ec87..fe5eaa2 100644 --- a/examples/tcp_echo_server.cpp +++ b/examples/tcp_echo_server.cpp @@ -199,9 +199,6 @@ int main(int argc, char* argv[]) { // Create scheduler with worker threads scheduler sched(4); - // Set the I/O context so workers can poll for I/O completions - sched.set_io_context(&io::default_io_context()); - sched.start(); // Spawn signal handler coroutine diff --git a/examples/uds_echo_client.cpp b/examples/uds_echo_client.cpp index 9dcbf67..45a3436 100644 --- a/examples/uds_echo_client.cpp +++ b/examples/uds_echo_client.cpp @@ -21,13 +21,10 @@ using namespace elio::net; /// Client coroutine - connects, sends messages, receives responses task client_main(const unix_address& addr) { - // Use the default io_context which is polled by scheduler workers - auto& ctx = io::default_io_context(); - ELIO_LOG_INFO("Connecting to {}...", addr.to_string()); // Connect to server - auto stream_result = co_await uds_connect(ctx, addr); + auto stream_result = co_await uds_connect(addr); if (!stream_result) { ELIO_LOG_ERROR("Connection failed: {}", strerror(errno)); @@ -74,12 +71,9 @@ task client_main(const unix_address& addr) { /// Non-interactive benchmark mode task benchmark_main(const unix_address& addr, int iterations) { - // Use the default io_context which is polled by scheduler workers - auto& ctx = io::default_io_context(); - ELIO_LOG_INFO("Connecting to {} for benchmark...", addr.to_string()); - auto stream_result = co_await uds_connect(ctx, addr); + auto stream_result = co_await uds_connect(addr); if (!stream_result) { ELIO_LOG_ERROR("Connection failed: {}", strerror(errno)); co_return 1; @@ -179,9 +173,6 @@ int main(int argc, char* argv[]) { // Create scheduler scheduler sched(2); - // Set the I/O context so workers can poll for I/O completions - sched.set_io_context(&io::default_io_context()); - sched.start(); int result = 0; diff --git a/examples/uds_echo_server.cpp b/examples/uds_echo_server.cpp index 6a46a14..f71b3a1 100644 --- a/examples/uds_echo_server.cpp +++ b/examples/uds_echo_server.cpp @@ -162,9 +162,6 @@ int main(int argc, char* argv[]) { // Create scheduler with worker threads scheduler sched(4); - // Set the I/O context so workers can poll for I/O completions - sched.set_io_context(&io::default_io_context()); - sched.start(); // Spawn signal handler coroutine diff --git a/examples/websocket_client.cpp b/examples/websocket_client.cpp index 38da4b3..ac6c0da 100644 --- a/examples/websocket_client.cpp +++ b/examples/websocket_client.cpp @@ -258,7 +258,6 @@ int main(int argc, char* argv[]) { // Create scheduler runtime::scheduler sched(2); - sched.set_io_context(&io::default_io_context()); sched.start(); // Run client diff --git a/examples/websocket_server.cpp b/examples/websocket_server.cpp index e7885fc..e004755 100644 --- a/examples/websocket_server.cpp +++ b/examples/websocket_server.cpp @@ -309,7 +309,6 @@ int main(int argc, char* argv[]) { // Create scheduler runtime::scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); // Spawn signal handler coroutine diff --git 
a/include/elio/http/http.hpp b/include/elio/http/http.hpp index 2e0fc7f..6449d38 100644 --- a/include/elio/http/http.hpp +++ b/include/elio/http/http.hpp @@ -52,7 +52,6 @@ namespace http { /// server srv(std::move(r)); /// /// runtime::scheduler sched(4); -/// sched.set_io_context(&io::default_io_context()); /// sched.start(); /// /// auto task = srv.listen(net::ipv4_address(8080), io::default_io_context(), sched); diff --git a/include/elio/http/http2_client.hpp b/include/elio/http/http2_client.hpp index 197e741..90fa637 100644 --- a/include/elio/http/http2_client.hpp +++ b/include/elio/http/http2_client.hpp @@ -197,7 +197,7 @@ class h2_client { // Create new HTTP/2 connection // First establish TCP connection - auto tcp_result = co_await net::tcp_connect(*io_ctx_, host, port); + auto tcp_result = co_await net::tcp_connect(host, port); if (!tcp_result) { ELIO_LOG_ERROR("Failed to connect to {}:{}", host, port); co_return std::nullopt; diff --git a/include/elio/http/http_client.hpp b/include/elio/http/http_client.hpp index 4e48da7..d58c8d7 100644 --- a/include/elio/http/http_client.hpp +++ b/include/elio/http/http_client.hpp @@ -120,8 +120,7 @@ class connection_pool { : config_(config) {} /// Get or create a connection to host - coro::task> acquire(io::io_context& io_ctx, - const std::string& host, + coro::task> acquire(const std::string& host, uint16_t port, bool secure, tls::tls_context* tls_ctx = nullptr) { @@ -152,7 +151,7 @@ class connection_pool { co_return std::nullopt; } - auto result = co_await tls::tls_connect(*tls_ctx, io_ctx, host, port); + auto result = co_await tls::tls_connect(*tls_ctx, host, port); if (!result) { ELIO_LOG_ERROR("Failed to connect to {}:{}: {}", host, port, strerror(errno)); co_return std::nullopt; @@ -160,7 +159,7 @@ class connection_pool { co_return connection(std::move(*result)); } else { - auto result = co_await net::tcp_connect(io_ctx, host, port); + auto result = co_await net::tcp_connect(host, port); if (!result) { ELIO_LOG_ERROR("Failed to connect to {}:{}: {}", host, port, strerror(errno)); co_return std::nullopt; @@ -361,7 +360,7 @@ class client { } // Get connection from pool - auto conn_opt = co_await pool_.acquire(*io_ctx_, target.host, target.effective_port(), + auto conn_opt = co_await pool_.acquire(target.host, target.effective_port(), target.is_secure(), &tls_ctx_); if (!conn_opt) { errno = ECONNREFUSED; diff --git a/include/elio/http/sse_client.hpp b/include/elio/http/sse_client.hpp index c6b5590..2a37299 100644 --- a/include/elio/http/sse_client.hpp +++ b/include/elio/http/sse_client.hpp @@ -363,7 +363,7 @@ class sse_client { // Establish TCP connection if (url_.is_secure()) { - auto result = co_await tls::tls_connect(tls_ctx_, *io_ctx_, + auto result = co_await tls::tls_connect(tls_ctx_, url_.host, url_.effective_port()); if (!result) { ELIO_LOG_ERROR("Failed to connect to {}:{}: {}", @@ -373,7 +373,7 @@ class sse_client { } stream_ = std::move(*result); } else { - auto result = co_await net::tcp_connect(*io_ctx_, url_.host, + auto result = co_await net::tcp_connect(url_.host, url_.effective_port()); if (!result) { ELIO_LOG_ERROR("Failed to connect to {}:{}: {}", diff --git a/include/elio/http/websocket.hpp b/include/elio/http/websocket.hpp index ededd7b..19b7d30 100644 --- a/include/elio/http/websocket.hpp +++ b/include/elio/http/websocket.hpp @@ -39,7 +39,6 @@ /// ws_server srv(std::move(router)); /// /// runtime::scheduler sched(4); -/// sched.set_io_context(&io::default_io_context()); /// sched.start(); /// /// auto task = 
srv.listen(net::ipv4_address(8080), diff --git a/include/elio/http/websocket_client.hpp b/include/elio/http/websocket_client.hpp index f7196b9..55e127e 100644 --- a/include/elio/http/websocket_client.hpp +++ b/include/elio/http/websocket_client.hpp @@ -257,14 +257,14 @@ class ws_client { // Establish TCP connection if (secure_) { - auto result = co_await tls::tls_connect(tls_ctx_, *io_ctx_, host_, port); + auto result = co_await tls::tls_connect(tls_ctx_, host_, port); if (!result) { ELIO_LOG_ERROR("Failed to connect to {}:{}: {}", host_, port, strerror(errno)); co_return false; } stream_ = std::move(*result); } else { - auto result = co_await net::tcp_connect(*io_ctx_, host_, port); + auto result = co_await net::tcp_connect(host_, port); if (!result) { ELIO_LOG_ERROR("Failed to connect to {}:{}: {}", host_, port, strerror(errno)); co_return false; diff --git a/include/elio/io/epoll_backend.hpp b/include/elio/io/epoll_backend.hpp index 4992de2..d746bc7 100644 --- a/include/elio/io/epoll_backend.hpp +++ b/include/elio/io/epoll_backend.hpp @@ -13,7 +13,6 @@ #include #include #include -#include #include namespace elio::io { @@ -53,33 +52,28 @@ class epoll_backend : public io_backend { /// Destructor ~epoll_backend() override { - // Collect handles to resume after releasing mutex - std::vector> deferred_resumes; - - { - std::lock_guard lock(mutex_); - for (auto& [fd, state] : fd_states_) { - for (auto& op : state.pending_ops) { - if (op.awaiter && !op.awaiter.done()) { - last_result_ = io_result{-ECANCELED, 0}; - deferred_resumes.push_back(op.awaiter); - } + // Collect handles to resume + std::vector deferred_resumes; + + for (auto& [fd, state] : fd_states_) { + for (auto& op : state.pending_ops) { + if (op.awaiter && !op.awaiter.done()) { + deferred_resumes.push_back({op.awaiter, io_result{-ECANCELED, 0}}); } } - fd_states_.clear(); - - // Cancel all pending timers - while (!timer_queue_.empty()) { - auto& entry = timer_queue_.top(); - if (entry.awaiter && !entry.awaiter.done()) { - last_result_ = io_result{-ECANCELED, 0}; - deferred_resumes.push_back(entry.awaiter); - } - timer_queue_.pop(); + } + fd_states_.clear(); + + // Cancel all pending timers + while (!timer_queue_.empty()) { + auto& entry = timer_queue_.top(); + if (entry.awaiter && !entry.awaiter.done()) { + deferred_resumes.push_back({entry.awaiter, io_result{-ECANCELED, 0}}); } - } // mutex released here + timer_queue_.pop(); + } - // Resume outside lock to prevent deadlock + // Resume coroutines resume_deferred(deferred_resumes); if (epoll_fd_ >= 0) { @@ -97,8 +91,6 @@ class epoll_backend : public io_backend { /// Prepare an I/O operation bool prepare(const io_request& req) override { - std::lock_guard lock(mutex_); - pending_operation op; op.req = req; op.awaiter = req.awaiter; @@ -185,7 +177,6 @@ class epoll_backend : public io_backend { /// For epoll, operations are "submitted" when they're added /// This just executes any synchronous operations int submit() override { - std::lock_guard lock(mutex_); int submitted = 0; // Execute synchronous operations (like close) @@ -215,19 +206,16 @@ class epoll_backend : public io_backend { } // Adjust timeout based on earliest timer deadline - { - std::lock_guard lock(mutex_); - if (!timer_queue_.empty()) { - auto now = std::chrono::steady_clock::now(); - auto earliest = timer_queue_.top().deadline; - if (earliest <= now) { - timeout_ms = 0; // Timer already expired - } else { - auto timer_timeout = std::chrono::duration_cast( - earliest - now).count(); - if (timeout_ms < 0 || 
timer_timeout < timeout_ms) { - timeout_ms = static_cast(timer_timeout); - } + if (!timer_queue_.empty()) { + auto now = std::chrono::steady_clock::now(); + auto earliest = timer_queue_.top().deadline; + if (earliest <= now) { + timeout_ms = 0; // Timer already expired + } else { + auto timer_timeout = std::chrono::duration_cast( + earliest - now).count(); + if (timeout_ms < 0 || timer_timeout < timeout_ms) { + timeout_ms = static_cast(timer_timeout); } } } @@ -245,42 +233,38 @@ class epoll_backend : public io_backend { int completions = 0; - // Collect handles to resume after releasing mutex (prevents deadlock - // when resumed coroutines call prepare()) - std::vector> deferred_resumes; + // Collect handles to resume after processing all completions + std::vector deferred_resumes; - { - std::lock_guard lock(mutex_); + // Process expired timers + auto now = std::chrono::steady_clock::now(); + while (!timer_queue_.empty() && timer_queue_.top().deadline <= now) { + auto entry = timer_queue_.top(); + timer_queue_.pop(); - // Process expired timers - auto now = std::chrono::steady_clock::now(); - while (!timer_queue_.empty() && timer_queue_.top().deadline <= now) { - auto entry = timer_queue_.top(); - timer_queue_.pop(); - - last_result_ = io_result{0, 0}; // Timeout completed successfully - - if (entry.awaiter && !entry.awaiter.done()) { - deferred_resumes.push_back(entry.awaiter); - } - - pending_count_--; - completions++; - - ELIO_LOG_DEBUG("Timer expired"); + io_result result{0, 0}; // Timeout completed successfully + + if (entry.awaiter && !entry.awaiter.done()) { + deferred_resumes.push_back({entry.awaiter, result}); } - // Process epoll events - for (int i = 0; i < nfds; ++i) { - int fd = events_[i].data.fd; - uint32_t revents = events_[i].events; - - auto it = fd_states_.find(fd); - if (it == fd_states_.end()) { - continue; - } - - fd_state& state = it->second; + pending_count_--; + completions++; + + ELIO_LOG_DEBUG("Timer expired"); + } + + // Process epoll events + for (int i = 0; i < nfds; ++i) { + int fd = events_[i].data.fd; + uint32_t revents = events_[i].events; + + auto it = fd_states_.find(fd); + if (it == fd_states_.end()) { + continue; + } + + fd_state& state = it->second; // Process pending operations for this fd auto op_it = state.pending_ops.begin(); @@ -325,9 +309,8 @@ class epoll_backend : public io_backend { state.events = 0; } } - } // mutex released here - // Resume coroutines outside the lock to prevent deadlock + // Resume coroutines after processing all completions resume_deferred(deferred_resumes); if (completions > 0) { @@ -349,66 +332,64 @@ class epoll_backend : public io_backend { /// Cancel a pending operation bool cancel(void* user_data) override { - std::coroutine_handle<> to_resume; + deferred_resume_entry to_resume{}; + bool found_entry = false; - { - std::lock_guard lock(mutex_); + // Search in fd_states first + for (auto& [fd, state] : fd_states_) { + auto it = std::find_if(state.pending_ops.begin(), + state.pending_ops.end(), + [user_data](const pending_operation& op) { + return op.awaiter.address() == user_data; + }); - // Search in fd_states first - for (auto& [fd, state] : fd_states_) { - auto it = std::find_if(state.pending_ops.begin(), - state.pending_ops.end(), - [user_data](const pending_operation& op) { - return op.awaiter.address() == user_data; - }); + if (it != state.pending_ops.end()) { + // Collect handle for deferred resumption + if (it->awaiter && !it->awaiter.done()) { + to_resume = {it->awaiter, io_result{-ECANCELED, 0}}; + found_entry 
= true; + } + state.pending_ops.erase(it); + pending_count_--; - if (it != state.pending_ops.end()) { - // Collect handle for deferred resumption - last_result_ = io_result{-ECANCELED, 0}; - if (it->awaiter && !it->awaiter.done()) { - to_resume = it->awaiter; - } - state.pending_ops.erase(it); - pending_count_--; - - // Cleanup if no more pending ops - if (state.pending_ops.empty() && state.registered) { - epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr); - state.registered = false; - state.events = 0; - } - goto found; + // Cleanup if no more pending ops + if (state.pending_ops.empty() && state.registered) { + epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr); + state.registered = false; + state.events = 0; } + goto found; } - - // Search in timer queue - need to rebuild queue without the cancelled entry - if (!timer_queue_.empty()) { - std::vector remaining; - while (!timer_queue_.empty()) { - auto entry = timer_queue_.top(); - timer_queue_.pop(); - if (entry.awaiter.address() == user_data) { - last_result_ = io_result{-ECANCELED, 0}; - if (entry.awaiter && !entry.awaiter.done()) { - to_resume = entry.awaiter; - } - pending_count_--; - // Don't add back to remaining - } else { - remaining.push_back(entry); + } + + // Search in timer queue - need to rebuild queue without the cancelled entry + if (!timer_queue_.empty()) { + std::vector remaining; + while (!timer_queue_.empty()) { + auto entry = timer_queue_.top(); + timer_queue_.pop(); + if (entry.awaiter.address() == user_data) { + if (entry.awaiter && !entry.awaiter.done()) { + to_resume = {entry.awaiter, io_result{-ECANCELED, 0}}; + found_entry = true; } - } - // Rebuild queue - for (auto& e : remaining) { - timer_queue_.push(std::move(e)); + pending_count_--; + // Don't add back to remaining + } else { + remaining.push_back(entry); } } - found:; - } // mutex released here + // Rebuild queue + for (auto& e : remaining) { + timer_queue_.push(std::move(e)); + } + } + found: - // Resume outside lock to prevent deadlock - if (to_resume && !to_resume.done()) { - to_resume.resume(); + // Resume the cancelled coroutine + if (found_entry && to_resume.handle && !to_resume.handle.done()) { + last_result_ = to_resume.result; + to_resume.handle.resume(); return true; } @@ -449,6 +430,12 @@ class epoll_backend : public io_backend { } }; + /// Deferred resume entry - stores handle with its result + struct deferred_resume_entry { + std::coroutine_handle<> handle; + io_result result; + }; + /// Min-heap priority queue for timers using timer_queue_t = std::priority_queue, @@ -480,9 +467,9 @@ class epoll_backend : public io_backend { /// Execute async I/O operation /// @param op The pending operation to execute /// @param revents The epoll events that triggered this operation - /// @param deferred_resumes If non-null, collect handle for later resumption (avoids deadlock) + /// @param deferred_resumes If non-null, collect handle+result for later resumption (avoids deadlock) void execute_async_op(pending_operation& op, uint32_t revents, - std::vector>* deferred_resumes = nullptr) { + std::vector* deferred_resumes = nullptr) { int result = 0; // Check for errors first @@ -579,25 +566,28 @@ class epoll_backend : public io_backend { } } - last_result_ = io_result{result, 0}; + io_result io_res{result, 0}; ELIO_LOG_DEBUG("Completed io_op::{} on fd={}: result={}", static_cast(op.req.op), op.req.fd, result); if (op.awaiter && !op.awaiter.done()) { if (deferred_resumes) { - deferred_resumes->push_back(op.awaiter); + deferred_resumes->push_back({op.awaiter, 
io_res}); } else { + last_result_ = io_res; op.awaiter.resume(); } } } /// Resume collected coroutine handles (call outside of lock) - static void resume_deferred(std::vector>& handles) { - for (auto& h : handles) { - if (h && !h.done()) { - h.resume(); + /// Sets last_result_ before resuming each coroutine + static void resume_deferred(std::vector& entries) { + for (auto& entry : entries) { + if (entry.handle && !entry.handle.done()) { + last_result_ = entry.result; + entry.handle.resume(); } } } @@ -608,7 +598,6 @@ class epoll_backend : public io_backend { std::unordered_map fd_states_; ///< Per-fd state timer_queue_t timer_queue_; ///< Timer queue for timeouts size_t pending_count_ = 0; ///< Number of pending operations - mutable std::mutex mutex_; ///< Protects fd_states_ and timer_queue_ static inline thread_local io_result last_result_{}; }; diff --git a/include/elio/io/io_awaitables.hpp b/include/elio/io/io_awaitables.hpp index ae14475..dfdab5e 100644 --- a/include/elio/io/io_awaitables.hpp +++ b/include/elio/io/io_awaitables.hpp @@ -2,6 +2,9 @@ #include "io_context.hpp" #include +#include +#include +#include #include #include #include @@ -9,12 +12,21 @@ namespace elio::io { +/// Get the io_context for the current worker thread +/// Falls back to default_io_context if not running in a worker +inline io_context& current_io_context() noexcept { + auto* worker = runtime::worker_thread::current(); + if (worker) { + return worker->io_context(); + } + return default_io_context(); +} + /// Base class for I/O awaitables /// Provides common functionality for all async I/O operations class io_awaitable_base { public: - explicit io_awaitable_base(io_context& ctx) noexcept - : ctx_(ctx) {} + io_awaitable_base() noexcept = default; /// Never ready immediately - always suspend bool await_ready() const noexcept { @@ -27,22 +39,54 @@ class io_awaitable_base { } protected: - io_context& ctx_; io_result result_{}; + size_t saved_affinity_ = coro::NO_AFFINITY; + void* handle_address_ = nullptr; + + /// Save current affinity and bind to current worker + template + void bind_to_worker(std::coroutine_handle handle) { + handle_address_ = handle.address(); + if constexpr (std::is_base_of_v) { + saved_affinity_ = handle.promise().affinity(); + auto* worker = runtime::worker_thread::current(); + if (worker) { + handle.promise().set_affinity(worker->worker_id()); + } + } + } + + /// Restore previous affinity (called from await_resume) + void restore_affinity() { + if (handle_address_) { + auto* promise = coro::get_promise_base(handle_address_); + if (promise) { + if (saved_affinity_ == coro::NO_AFFINITY) { + promise->clear_affinity(); + } else { + promise->set_affinity(saved_affinity_); + } + } + } + } }; /// Awaitable for async read operations class async_read_awaitable : public io_awaitable_base { public: - async_read_awaitable(io_context& ctx, int fd, void* buffer, size_t length, + async_read_awaitable(int fd, void* buffer, size_t length, int64_t offset = -1) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , fd_(fd) , buffer_(buffer) , length_(length) , offset_(offset) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::read; req.fd = fd_; @@ -51,16 +95,17 @@ class async_read_awaitable : public io_awaitable_base { req.offset = offset_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = 
io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -74,15 +119,19 @@ class async_read_awaitable : public io_awaitable_base { /// Awaitable for async write operations class async_write_awaitable : public io_awaitable_base { public: - async_write_awaitable(io_context& ctx, int fd, const void* buffer, + async_write_awaitable(int fd, const void* buffer, size_t length, int64_t offset = -1) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , fd_(fd) , buffer_(buffer) , length_(length) , offset_(offset) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::write; req.fd = fd_; @@ -91,16 +140,17 @@ class async_write_awaitable : public io_awaitable_base { req.offset = offset_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -114,15 +164,19 @@ class async_write_awaitable : public io_awaitable_base { /// Awaitable for async recv operations class async_recv_awaitable : public io_awaitable_base { public: - async_recv_awaitable(io_context& ctx, int fd, void* buffer, size_t length, + async_recv_awaitable(int fd, void* buffer, size_t length, int flags = 0) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , fd_(fd) , buffer_(buffer) , length_(length) , flags_(flags) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::recv; req.fd = fd_; @@ -131,16 +185,17 @@ class async_recv_awaitable : public io_awaitable_base { req.socket_flags = flags_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -154,15 +209,19 @@ class async_recv_awaitable : public io_awaitable_base { /// Awaitable for async send operations class async_send_awaitable : public io_awaitable_base { public: - async_send_awaitable(io_context& ctx, int fd, const void* buffer, + async_send_awaitable(int fd, const void* buffer, size_t length, int flags = 0) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , fd_(fd) , buffer_(buffer) , length_(length) , flags_(flags) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::send; req.fd = fd_; @@ -171,16 +230,17 @@ class async_send_awaitable : public io_awaitable_base { req.socket_flags = flags_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -194,17 +254,21 @@ class async_send_awaitable : public 
io_awaitable_base { /// Awaitable for async accept operations class async_accept_awaitable : public io_awaitable_base { public: - async_accept_awaitable(io_context& ctx, int listen_fd, + async_accept_awaitable(int listen_fd, struct sockaddr* addr = nullptr, socklen_t* addrlen = nullptr, int flags = 0) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , listen_fd_(listen_fd) , addr_(addr) , addrlen_(addrlen) , flags_(flags) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::accept; req.fd = listen_fd_; @@ -213,16 +277,17 @@ class async_accept_awaitable : public io_awaitable_base { req.socket_flags = flags_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -241,15 +306,19 @@ class async_accept_awaitable : public io_awaitable_base { /// Awaitable for async connect operations class async_connect_awaitable : public io_awaitable_base { public: - async_connect_awaitable(io_context& ctx, int fd, + async_connect_awaitable(int fd, const struct sockaddr* addr, socklen_t addrlen) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , fd_(fd) , addr_(addr) , addrlen_(addrlen) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::connect; req.fd = fd_; @@ -257,16 +326,17 @@ class async_connect_awaitable : public io_awaitable_base { req.addrlen = &addrlen_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -279,26 +349,31 @@ class async_connect_awaitable : public io_awaitable_base { /// Awaitable for async close operations class async_close_awaitable : public io_awaitable_base { public: - async_close_awaitable(io_context& ctx, int fd) noexcept - : io_awaitable_base(ctx) + async_close_awaitable(int fd) noexcept + : io_awaitable_base() , fd_(fd) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::close; req.fd = fd_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -309,14 +384,18 @@ class async_close_awaitable : public io_awaitable_base { /// Awaitable for async writev operations (scatter-gather write) class async_writev_awaitable : public io_awaitable_base { public: - async_writev_awaitable(io_context& ctx, int fd, struct iovec* iovecs, + async_writev_awaitable(int fd, struct iovec* iovecs, size_t iovec_count) noexcept - : io_awaitable_base(ctx) + : io_awaitable_base() , fd_(fd) , iovecs_(iovecs) , iovec_count_(iovec_count) {} - void 
await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = io_op::writev; req.fd = fd_; @@ -324,16 +403,17 @@ class async_writev_awaitable : public io_awaitable_base { req.iovec_count = iovec_count_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -346,27 +426,32 @@ class async_writev_awaitable : public io_awaitable_base { /// Awaitable for poll (wait for socket readable/writable) class async_poll_awaitable : public io_awaitable_base { public: - async_poll_awaitable(io_context& ctx, int fd, bool for_read) noexcept - : io_awaitable_base(ctx) + async_poll_awaitable(int fd, bool for_read) noexcept + : io_awaitable_base() , fd_(fd) , for_read_(for_read) {} - void await_suspend(std::coroutine_handle<> awaiter) { + template + void await_suspend(std::coroutine_handle awaiter) { + bind_to_worker(awaiter); + auto& ctx = current_io_context(); + io_request req{}; req.op = for_read_ ? io_op::poll_read : io_op::poll_write; req.fd = fd_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { result_ = io_result{-EAGAIN, 0}; awaiter.resume(); return; } - ctx_.submit(); + ctx.submit(); } io_result await_resume() noexcept { result_ = io_context::get_last_result(); + restore_affinity(); return result_; } @@ -378,119 +463,82 @@ class async_poll_awaitable : public io_awaitable_base { /// Factory functions for creating awaitables /// Create an async read awaitable -/// @param ctx The I/O context /// @param fd File descriptor to read from /// @param buffer Buffer to read into /// @param length Number of bytes to read /// @param offset File offset (-1 for current position) -inline auto async_read(io_context& ctx, int fd, void* buffer, size_t length, +inline auto async_read(int fd, void* buffer, size_t length, int64_t offset = -1) { - return async_read_awaitable(ctx, fd, buffer, length, offset); + return async_read_awaitable(fd, buffer, length, offset); } /// Create an async read awaitable using span template -inline auto async_read(io_context& ctx, int fd, std::span buffer, +inline auto async_read(int fd, std::span buffer, int64_t offset = -1) { - return async_read_awaitable(ctx, fd, buffer.data(), + return async_read_awaitable(fd, buffer.data(), buffer.size_bytes(), offset); } /// Create an async write awaitable -inline auto async_write(io_context& ctx, int fd, const void* buffer, +inline auto async_write(int fd, const void* buffer, size_t length, int64_t offset = -1) { - return async_write_awaitable(ctx, fd, buffer, length, offset); + return async_write_awaitable(fd, buffer, length, offset); } /// Create an async write awaitable using span template -inline auto async_write(io_context& ctx, int fd, std::span buffer, +inline auto async_write(int fd, std::span buffer, int64_t offset = -1) { - return async_write_awaitable(ctx, fd, buffer.data(), + return async_write_awaitable(fd, buffer.data(), buffer.size_bytes(), offset); } /// Create an async recv awaitable -inline auto async_recv(io_context& ctx, int fd, void* buffer, size_t length, +inline auto async_recv(int fd, void* buffer, size_t length, int flags = 0) { - return async_recv_awaitable(ctx, fd, buffer, length, flags); + return 
async_recv_awaitable(fd, buffer, length, flags); } /// Create an async send awaitable -inline auto async_send(io_context& ctx, int fd, const void* buffer, +inline auto async_send(int fd, const void* buffer, size_t length, int flags = 0) { - return async_send_awaitable(ctx, fd, buffer, length, flags); + return async_send_awaitable(fd, buffer, length, flags); } /// Create an async writev awaitable (scatter-gather write) -inline auto async_writev(io_context& ctx, int fd, struct iovec* iovecs, +inline auto async_writev(int fd, struct iovec* iovecs, size_t iovec_count) { - return async_writev_awaitable(ctx, fd, iovecs, iovec_count); + return async_writev_awaitable(fd, iovecs, iovec_count); } /// Create an async accept awaitable -inline auto async_accept(io_context& ctx, int listen_fd, +inline auto async_accept(int listen_fd, struct sockaddr* addr = nullptr, socklen_t* addrlen = nullptr, int flags = 0) { - return async_accept_awaitable(ctx, listen_fd, addr, addrlen, flags); + return async_accept_awaitable(listen_fd, addr, addrlen, flags); } /// Create an async connect awaitable -inline auto async_connect(io_context& ctx, int fd, +inline auto async_connect(int fd, const struct sockaddr* addr, socklen_t addrlen) { - return async_connect_awaitable(ctx, fd, addr, addrlen); + return async_connect_awaitable(fd, addr, addrlen); } /// Create an async close awaitable -inline auto async_close(io_context& ctx, int fd) { - return async_close_awaitable(ctx, fd); +inline auto async_close(int fd) { + return async_close_awaitable(fd); } /// Create an async poll awaitable for reading -inline auto async_poll_read(io_context& ctx, int fd) { - return async_poll_awaitable(ctx, fd, true); +inline auto async_poll_read(int fd) { + return async_poll_awaitable(fd, true); } /// Create an async poll awaitable for writing -inline auto async_poll_write(io_context& ctx, int fd) { - return async_poll_awaitable(ctx, fd, false); -} - -// Convenience overloads using default io_context - -inline auto async_read(int fd, void* buffer, size_t length, int64_t offset = -1) { - return async_read(default_io_context(), fd, buffer, length, offset); -} - -inline auto async_write(int fd, const void* buffer, size_t length, - int64_t offset = -1) { - return async_write(default_io_context(), fd, buffer, length, offset); -} - -inline auto async_recv(int fd, void* buffer, size_t length, int flags = 0) { - return async_recv(default_io_context(), fd, buffer, length, flags); -} - -inline auto async_send(int fd, const void* buffer, size_t length, int flags = 0) { - return async_send(default_io_context(), fd, buffer, length, flags); -} - -inline auto async_writev(int fd, struct iovec* iovecs, size_t iovec_count) { - return async_writev(default_io_context(), fd, iovecs, iovec_count); -} - -inline auto async_accept(int listen_fd, struct sockaddr* addr = nullptr, - socklen_t* addrlen = nullptr, int flags = 0) { - return async_accept(default_io_context(), listen_fd, addr, addrlen, flags); -} - -inline auto async_connect(int fd, const struct sockaddr* addr, socklen_t addrlen) { - return async_connect(default_io_context(), fd, addr, addrlen); -} - -inline auto async_close(int fd) { - return async_close(default_io_context(), fd); +inline auto async_poll_write(int fd) { + return async_poll_awaitable(fd, false); } } // namespace elio::io diff --git a/include/elio/io/io_uring_backend.hpp b/include/elio/io/io_uring_backend.hpp index 52c77bd..b6803df 100644 --- a/include/elio/io/io_uring_backend.hpp +++ b/include/elio/io/io_uring_backend.hpp @@ -201,11 +201,12 @@ 
class io_uring_backend : public io_backend { int poll(std::chrono::milliseconds timeout) override { struct io_uring_cqe* cqe = nullptr; int completions = 0; + std::vector deferred_resumes; if (timeout.count() == 0) { // Non-blocking: peek for available CQEs while (io_uring_peek_cqe(&ring_, &cqe) == 0 && cqe) { - process_completion(cqe); + process_completion(cqe, &deferred_resumes); io_uring_cqe_seen(&ring_, cqe); completions++; cqe = nullptr; @@ -214,14 +215,14 @@ class io_uring_backend : public io_backend { // Blocking: wait for at least one CQE int ret = io_uring_wait_cqe(&ring_, &cqe); if (ret == 0 && cqe) { - process_completion(cqe); + process_completion(cqe, &deferred_resumes); io_uring_cqe_seen(&ring_, cqe); completions++; // Drain any additional ready CQEs cqe = nullptr; while (io_uring_peek_cqe(&ring_, &cqe) == 0 && cqe) { - process_completion(cqe); + process_completion(cqe, &deferred_resumes); io_uring_cqe_seen(&ring_, cqe); completions++; cqe = nullptr; @@ -235,14 +236,14 @@ class io_uring_backend : public io_backend { int ret = io_uring_wait_cqe_timeout(&ring_, &cqe, &ts); if (ret == 0 && cqe) { - process_completion(cqe); + process_completion(cqe, &deferred_resumes); io_uring_cqe_seen(&ring_, cqe); completions++; // Drain any additional ready CQEs cqe = nullptr; while (io_uring_peek_cqe(&ring_, &cqe) == 0 && cqe) { - process_completion(cqe); + process_completion(cqe, &deferred_resumes); io_uring_cqe_seen(&ring_, cqe); completions++; cqe = nullptr; @@ -250,6 +251,10 @@ class io_uring_backend : public io_backend { } } + // Resume coroutines after processing all completions + // Each coroutine gets its correct result via deferred_resume_entry + resume_deferred(deferred_resumes); + if (completions > 0) { ELIO_LOG_DEBUG("Processed {} completions", completions); } @@ -278,6 +283,9 @@ class io_uring_backend : public io_backend { io_uring_sqe_set_data(sqe, nullptr); // No awaiter for cancel itself pending_ops_.fetch_add(1, std::memory_order_relaxed); + + // Submit immediately - cancel must be acted upon right away + io_uring_submit(&ring_); return true; } @@ -293,7 +301,14 @@ class io_uring_backend : public io_backend { } private: - void process_completion(struct io_uring_cqe* cqe) { + /// Deferred resume entry - stores handle with its result + struct deferred_resume_entry { + std::coroutine_handle<> handle; + io_result result; + }; + + void process_completion(struct io_uring_cqe* cqe, + std::vector* deferred_resumes = nullptr) { void* user_data = io_uring_cqe_get_data(cqe); pending_ops_.fetch_sub(1, std::memory_order_relaxed); @@ -304,16 +319,28 @@ class io_uring_backend : public io_backend { // Store result in promise and resume coroutine auto handle = std::coroutine_handle<>::from_address(user_data); - - // The result is stored via the awaitable's promise - // We need a way to pass the result - use a thread-local for simplicity - last_result_ = io_result{cqe->res, cqe->flags}; + io_result result{cqe->res, cqe->flags}; ELIO_LOG_DEBUG("Completing operation: result={}, flags={}", cqe->res, cqe->flags); if (handle && !handle.done()) { - handle.resume(); + if (deferred_resumes) { + deferred_resumes->push_back({handle, result}); + } else { + last_result_ = result; + handle.resume(); + } + } + } + + /// Resume collected coroutine handles (call outside of lock) + static void resume_deferred(std::vector& entries) { + for (auto& entry : entries) { + if (entry.handle && !entry.handle.done()) { + last_result_ = entry.result; + entry.handle.resume(); + } } } diff --git a/include/elio/net/tcp.hpp 
b/include/elio/net/tcp.hpp index d127844..b4e1fec 100644 --- a/include/elio/net/tcp.hpp +++ b/include/elio/net/tcp.hpp @@ -349,51 +349,51 @@ class tcp_stream { /// Async read auto read(void* buffer, size_t length) { - return io::async_recv(*ctx_, fd_, buffer, length); + return io::async_recv(fd_, buffer, length); } /// Async read into span template auto read(std::span buffer) { - return io::async_recv(*ctx_, fd_, buffer.data(), buffer.size_bytes()); + return io::async_recv(fd_, buffer.data(), buffer.size_bytes()); } /// Async write auto write(const void* buffer, size_t length) { - return io::async_send(*ctx_, fd_, buffer, length); + return io::async_send(fd_, buffer, length); } /// Async write from span template auto write(std::span buffer) { - return io::async_send(*ctx_, fd_, buffer.data(), buffer.size_bytes()); + return io::async_send(fd_, buffer.data(), buffer.size_bytes()); } /// Async write string auto write(std::string_view str) { - return io::async_send(*ctx_, fd_, str.data(), str.size()); + return io::async_send(fd_, str.data(), str.size()); } /// Async writev (scatter-gather write) auto writev(struct iovec* iovecs, size_t count) { - return io::async_writev(*ctx_, fd_, iovecs, count); + return io::async_writev(fd_, iovecs, count); } /// Wait for socket to be readable auto poll_read() { - return io::async_poll_read(*ctx_, fd_); + return io::async_poll_read(fd_); } /// Wait for socket to be writable auto poll_write() { - return io::async_poll_write(*ctx_, fd_); + return io::async_poll_write(fd_); } /// Async close auto close() { int fd = fd_; fd_ = -1; - return io::async_close(*ctx_, fd); + return io::async_close(fd); } /// Set TCP_NODELAY option @@ -504,6 +504,8 @@ class tcp_listener { bool await_ready() const noexcept { return false; } void await_suspend(std::coroutine_handle<> awaiter) { + auto& ctx = io::current_io_context(); + io::io_request req{}; req.op = io::io_op::accept; req.fd = listener_.fd_; @@ -512,12 +514,12 @@ class tcp_listener { req.socket_flags = SOCK_NONBLOCK | SOCK_CLOEXEC; req.awaiter = awaiter; - if (!listener_.ctx_->prepare(req)) { + if (!ctx.prepare(req)) { result_ = io::io_result{-EAGAIN, 0}; awaiter.resume(); return; } - listener_.ctx_->submit(); + ctx.submit(); } std::optional await_resume() { @@ -653,9 +655,9 @@ class tcp_listener { /// Connect to a remote TCP server class tcp_connect_awaitable { public: - tcp_connect_awaitable(io::io_context& ctx, const socket_address& addr, + tcp_connect_awaitable(const socket_address& addr, const tcp_options& opts = {}) - : ctx_(ctx), addr_(addr), opts_(opts) {} + : addr_(addr), opts_(opts) {} bool await_ready() const noexcept { return false; } @@ -692,6 +694,8 @@ class tcp_connect_awaitable { } // Connection in progress, wait for socket to become writable + auto& ctx = io::current_io_context(); + io::io_request req{}; req.op = io::io_op::connect; req.fd = fd_; @@ -699,13 +703,13 @@ class tcp_connect_awaitable { req.addrlen = &sa_len_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { ::close(fd_); fd_ = -1; result_ = io::io_result{-EAGAIN, 0}; return false; // Don't suspend, resume immediately } - ctx_.submit(); + ctx.submit(); return true; // Suspend, will be resumed by epoll } @@ -726,7 +730,7 @@ class tcp_connect_awaitable { return std::nullopt; } - tcp_stream stream(fd_, ctx_); + tcp_stream stream(fd_, io::current_io_context()); fd_ = -1; // Transfer ownership stream.set_peer_address(addr_); @@ -736,7 +740,6 @@ class tcp_connect_awaitable { } private: - io::io_context& ctx_; 
socket_address addr_; tcp_options opts_; struct sockaddr_storage sa_{}; @@ -746,27 +749,27 @@ class tcp_connect_awaitable { }; /// Connect to a remote TCP server (IPv4) -inline auto tcp_connect(io::io_context& ctx, const ipv4_address& addr, +inline auto tcp_connect(const ipv4_address& addr, const tcp_options& opts = {}) { - return tcp_connect_awaitable(ctx, socket_address(addr), opts); + return tcp_connect_awaitable(socket_address(addr), opts); } /// Connect to a remote TCP server (IPv6) -inline auto tcp_connect(io::io_context& ctx, const ipv6_address& addr, +inline auto tcp_connect(const ipv6_address& addr, const tcp_options& opts = {}) { - return tcp_connect_awaitable(ctx, socket_address(addr), opts); + return tcp_connect_awaitable(socket_address(addr), opts); } /// Connect to a remote TCP server (generic address) -inline auto tcp_connect(io::io_context& ctx, const socket_address& addr, +inline auto tcp_connect(const socket_address& addr, const tcp_options& opts = {}) { - return tcp_connect_awaitable(ctx, addr, opts); + return tcp_connect_awaitable(addr, opts); } /// Connect to a remote TCP server by host and port (auto-detects IPv4/IPv6) -inline auto tcp_connect(io::io_context& ctx, std::string_view host, uint16_t port, +inline auto tcp_connect(std::string_view host, uint16_t port, const tcp_options& opts = {}) { - return tcp_connect_awaitable(ctx, socket_address(host, port), opts); + return tcp_connect_awaitable(socket_address(host, port), opts); } } // namespace elio::net diff --git a/include/elio/net/uds.hpp b/include/elio/net/uds.hpp index ed8ab47..d930020 100644 --- a/include/elio/net/uds.hpp +++ b/include/elio/net/uds.hpp @@ -177,51 +177,51 @@ class uds_stream { /// Async read auto read(void* buffer, size_t length) { - return io::async_recv(*ctx_, fd_, buffer, length); + return io::async_recv(fd_, buffer, length); } /// Async read into span template auto read(std::span buffer) { - return io::async_recv(*ctx_, fd_, buffer.data(), buffer.size_bytes()); + return io::async_recv(fd_, buffer.data(), buffer.size_bytes()); } /// Async write auto write(const void* buffer, size_t length) { - return io::async_send(*ctx_, fd_, buffer, length); + return io::async_send(fd_, buffer, length); } /// Async write from span template auto write(std::span buffer) { - return io::async_send(*ctx_, fd_, buffer.data(), buffer.size_bytes()); + return io::async_send(fd_, buffer.data(), buffer.size_bytes()); } /// Async write string auto write(std::string_view str) { - return io::async_send(*ctx_, fd_, str.data(), str.size()); + return io::async_send(fd_, str.data(), str.size()); } /// Async writev (scatter-gather write) auto writev(struct iovec* iovecs, size_t count) { - return io::async_writev(*ctx_, fd_, iovecs, count); + return io::async_writev(fd_, iovecs, count); } /// Wait for socket to be readable auto poll_read() { - return io::async_poll_read(*ctx_, fd_); + return io::async_poll_read(fd_); } /// Wait for socket to be writable auto poll_write() { - return io::async_poll_write(*ctx_, fd_); + return io::async_poll_write(fd_); } /// Async close auto close() { int fd = fd_; fd_ = -1; - return io::async_close(*ctx_, fd); + return io::async_close(fd); } /// Set SO_RCVBUF option @@ -362,6 +362,8 @@ class uds_listener { bool await_ready() const noexcept { return false; } void await_suspend(std::coroutine_handle<> awaiter) { + auto& ctx = io::current_io_context(); + io::io_request req{}; req.op = io::io_op::accept; req.fd = listener_.fd_; @@ -370,12 +372,12 @@ class uds_listener { req.socket_flags = 
SOCK_NONBLOCK | SOCK_CLOEXEC; req.awaiter = awaiter; - if (!listener_.ctx_->prepare(req)) { + if (!ctx.prepare(req)) { result_ = io::io_result{-EAGAIN, 0}; awaiter.resume(); return; } - listener_.ctx_->submit(); + ctx.submit(); } std::optional await_resume() { @@ -443,9 +445,9 @@ class uds_listener { /// Connect to a Unix Domain Socket server class uds_connect_awaitable { public: - uds_connect_awaitable(io::io_context& ctx, const unix_address& addr, + uds_connect_awaitable(const unix_address& addr, const uds_options& opts = {}) - : ctx_(ctx), addr_(addr), opts_(opts) {} + : addr_(addr), opts_(opts) {} bool await_ready() const noexcept { return false; } @@ -485,6 +487,8 @@ class uds_connect_awaitable { } // Connection in progress, wait for socket to become writable + auto& ctx = io::current_io_context(); + io::io_request req{}; req.op = io::io_op::connect; req.fd = fd_; @@ -492,13 +496,13 @@ class uds_connect_awaitable { req.addrlen = &sa_len_; req.awaiter = awaiter; - if (!ctx_.prepare(req)) { + if (!ctx.prepare(req)) { ::close(fd_); fd_ = -1; result_ = io::io_result{-EAGAIN, 0}; return false; // Don't suspend, resume immediately } - ctx_.submit(); + ctx.submit(); return true; // Suspend, will be resumed by epoll } @@ -523,7 +527,7 @@ class uds_connect_awaitable { return std::nullopt; } - uds_stream stream(fd_, ctx_); + uds_stream stream(fd_, io::current_io_context()); fd_ = -1; // Transfer ownership stream.set_peer_address(addr_); @@ -533,7 +537,6 @@ class uds_connect_awaitable { } private: - io::io_context& ctx_; unix_address addr_; uds_options opts_; struct sockaddr_un sa_{}; @@ -543,15 +546,15 @@ class uds_connect_awaitable { }; /// Connect to a Unix Domain Socket server -inline auto uds_connect(io::io_context& ctx, const unix_address& addr, +inline auto uds_connect(const unix_address& addr, const uds_options& opts = {}) { - return uds_connect_awaitable(ctx, addr, opts); + return uds_connect_awaitable(addr, opts); } /// Connect to a Unix Domain Socket server by path -inline auto uds_connect(io::io_context& ctx, std::string_view path, +inline auto uds_connect(std::string_view path, const uds_options& opts = {}) { - return uds_connect_awaitable(ctx, unix_address(path), opts); + return uds_connect_awaitable(unix_address(path), opts); } } // namespace elio::net diff --git a/include/elio/rpc/rpc.hpp b/include/elio/rpc/rpc.hpp index 05cb627..5ee385e 100644 --- a/include/elio/rpc/rpc.hpp +++ b/include/elio/rpc/rpc.hpp @@ -36,7 +36,7 @@ /// }); /// /// // Client side -/// auto client = co_await elio::rpc::tcp_rpc_client::connect(ctx, addr); +/// auto client = co_await elio::rpc::tcp_rpc_client::connect(addr); /// auto result = co_await client->call(GetUserRequest{42}); /// if (result.ok()) { /// std::cout << "Name: " << result->name << std::endl; diff --git a/include/elio/rpc/rpc_client.hpp b/include/elio/rpc/rpc_client.hpp index a037ac4..2c0ebfc 100644 --- a/include/elio/rpc/rpc_client.hpp +++ b/include/elio/rpc/rpc_client.hpp @@ -11,7 +11,7 @@ /// /// Usage: /// @code -/// auto client = co_await rpc_client::connect(ctx, addr); +/// auto client = co_await rpc_client::connect(addr); /// auto result = co_await client->call(request, 5000ms); /// if (result.ok()) { /// process(result.value()); @@ -81,12 +81,10 @@ class rpc_client : public std::enable_shared_from_this> { /// Connect to a TCP server and create client template - static coro::task> connect( - io::io_context& ctx, - Args&&... args) + static coro::task> connect(Args&&... 
args) requires std::is_same_v { - auto stream = co_await net::tcp_connect(ctx, std::forward(args)...); + auto stream = co_await net::tcp_connect(std::forward(args)...); if (!stream) { co_return std::nullopt; } @@ -97,12 +95,10 @@ class rpc_client : public std::enable_shared_from_this> { /// Connect to a UDS server and create client template - static coro::task> connect( - io::io_context& ctx, - Args&&... args) + static coro::task> connect(Args&&... args) requires std::is_same_v { - auto stream = co_await net::uds_connect(ctx, std::forward(args)...); + auto stream = co_await net::uds_connect(std::forward(args)...); if (!stream) { co_return std::nullopt; } diff --git a/include/elio/runtime/async_main.hpp b/include/elio/runtime/async_main.hpp index 0c91c35..d41b232 100644 --- a/include/elio/runtime/async_main.hpp +++ b/include/elio/runtime/async_main.hpp @@ -2,7 +2,6 @@ #include "scheduler.hpp" #include -#include #include #include #include @@ -18,9 +17,6 @@ namespace elio::runtime { struct run_config { /// Number of worker threads (0 = hardware concurrency) size_t num_threads = 0; - - /// Custom I/O context (nullptr = create default) - io::io_context* io_context = nullptr; }; namespace detail { @@ -112,14 +108,13 @@ coro::task completion_wrapper(coro::task inner, completion_signal* s /// run async code from a synchronous context (like main()). /// /// @param task The coroutine task to run -/// @param config Configuration (threads, io_context) +/// @param config Configuration (threads) /// @return The result of the task /// /// Example: /// @code /// coro::task async_main() { -/// auto& ctx = io::default_io_context(); -/// // Use ctx for async I/O +/// // Your async code here - each worker has its own io_context /// co_return 42; /// } /// @@ -137,16 +132,7 @@ T run(coro::task task, const run_config& config = {}) { if (threads == 0) threads = 1; } - // Use provided io_context or create default - io::io_context* io_ctx = config.io_context; - std::unique_ptr owned_ctx; - if (!io_ctx) { - owned_ctx = std::make_unique(); - io_ctx = owned_ctx.get(); - } - scheduler sched(threads); - sched.set_io_context(io_ctx); sched.start(); // Create wrapper that signals completion diff --git a/include/elio/runtime/scheduler.hpp b/include/elio/runtime/scheduler.hpp index df91539..a1f8b23 100644 --- a/include/elio/runtime/scheduler.hpp +++ b/include/elio/runtime/scheduler.hpp @@ -11,10 +11,6 @@ #include #include -namespace elio::io { -class io_context; -} - namespace elio::runtime { /// Work-stealing scheduler for coroutines @@ -212,16 +208,6 @@ class scheduler { return workers_[worker_id]->tasks_executed(); } - void set_io_context(io::io_context* ctx) noexcept { - io_context_ = ctx; - } - - [[nodiscard]] io::io_context* get_io_context() const noexcept { - return io_context_; - } - - bool try_poll_io(std::chrono::milliseconds timeout = std::chrono::milliseconds(0)); - private: void do_spawn(std::coroutine_handle<> handle) { // Release fence ensures all writes to the coroutine frame (including @@ -273,28 +259,14 @@ class scheduler { std::atomic paused_; std::atomic spawn_index_; mutable std::mutex workers_mutex_; - io::io_context* io_context_ = nullptr; - mutable std::mutex io_poll_mutex_; static inline thread_local scheduler* current_scheduler_ = nullptr; }; } // namespace elio::runtime -#include - namespace elio::runtime { -inline bool scheduler::try_poll_io(std::chrono::milliseconds timeout) { - if (!io_context_) return false; - - std::unique_lock lock(io_poll_mutex_, std::try_to_lock); - if (!lock.owns_lock()) 
return false; - - io_context_->poll(timeout); - return true; -} - inline scheduler* get_current_scheduler() noexcept { return scheduler::current(); } @@ -507,19 +479,19 @@ inline std::coroutine_handle<> worker_thread::try_steal() noexcept { } inline void worker_thread::poll_io_when_idle() { - // Try to poll IO - only one worker can do this at a time - // Use a non-zero timeout so the polling thread blocks on epoll/io_uring - // while other workers block on their eventfd + // Poll this worker's own io_context + // Each worker has its own io_context, so no locking needed constexpr int idle_timeout_ms = 10; - if (scheduler_->try_poll_io(std::chrono::milliseconds(idle_timeout_ms))) { - // Successfully polled IO (with blocking timeout) - // Check for new tasks immediately after IO completions + // Poll with timeout - will block on epoll/io_uring if no completions ready + int completions = io_context_->poll(std::chrono::milliseconds(idle_timeout_ms)); + + if (completions > 0) { + // Got IO completions, return immediately to check for new tasks return; } - // Couldn't acquire IO poll lock - another worker is handling IO - // Block on our eventfd until woken or timeout + // No IO completions - wait on eventfd for task submissions wait_for_work(idle_timeout_ms); } diff --git a/include/elio/runtime/worker_thread.hpp b/include/elio/runtime/worker_thread.hpp index eb67a87..a45bdaf 100644 --- a/include/elio/runtime/worker_thread.hpp +++ b/include/elio/runtime/worker_thread.hpp @@ -3,6 +3,7 @@ #include "chase_lev_deque.hpp" #include "mpsc_queue.hpp" #include +#include #include #include #include @@ -14,10 +15,6 @@ #include #include -namespace elio::io { -class io_context; -} - namespace elio::runtime { class scheduler; @@ -41,7 +38,8 @@ class worker_thread { , running_(false) , tasks_executed_(0) , wake_fd_(-1) - , wait_epoll_fd_(-1) { + , wait_epoll_fd_(-1) + , io_context_(std::make_unique()) { // Create eventfd for wake-up notifications (non-blocking, semaphore mode) wake_fd_ = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC); @@ -161,6 +159,11 @@ class worker_thread { return worker_id_; } + /// Get the io_context for this worker thread + [[nodiscard]] io::io_context& io_context() noexcept { + return *io_context_; + } + /// Get the current worker thread (if called from a worker thread) [[nodiscard]] static worker_thread* current() noexcept { return current_worker_; @@ -226,6 +229,7 @@ class worker_thread { bool needs_sync_ = false; // Whether current task needs memory synchronization int wake_fd_; // eventfd for wake-up notifications int wait_epoll_fd_; // epoll fd for waiting on wake_fd + std::unique_ptr io_context_; // Per-worker io_context static inline thread_local worker_thread* current_worker_ = nullptr; }; diff --git a/include/elio/signal/signalfd.hpp b/include/elio/signal/signalfd.hpp index bc81cba..52982f9 100644 --- a/include/elio/signal/signalfd.hpp +++ b/include/elio/signal/signalfd.hpp @@ -214,11 +214,11 @@ class signal_fd { public: /// Construct a signal_fd for the given signal set /// @param signals The set of signals to handle - /// @param ctx Optional I/O context (defaults to the global context) + /// @param ctx Optional I/O context (defaults to the current worker's context) /// @param auto_block If true (default), automatically block the signals /// @throws std::system_error if signalfd creation fails explicit signal_fd(const signal_set& signals, - io::io_context& ctx = io::default_io_context(), + io::io_context& ctx = io::current_io_context(), bool auto_block = true) : ctx_(ctx) , 
signals_(signals) @@ -387,7 +387,7 @@ class signal_block_guard { /// @param auto_block If true (default), automatically block the signals /// @return task that yields signal_info when a signal is received inline coro::task wait_signal(const signal_set& signals, - io::io_context& ctx = io::default_io_context(), + io::io_context& ctx = io::current_io_context(), bool auto_block = true) { signal_fd sigfd(signals, ctx, auto_block); auto info = co_await sigfd.wait(); @@ -402,7 +402,7 @@ inline coro::task wait_signal(const signal_set& signals, /// @param ctx Optional I/O context /// @return task that yields signal_info when the signal is received inline coro::task wait_signal(int signo, - io::io_context& ctx = io::default_io_context()) { + io::io_context& ctx = io::current_io_context()) { signal_set signals; signals.add(signo); co_return co_await wait_signal(signals, ctx); diff --git a/include/elio/time/timer.hpp b/include/elio/time/timer.hpp index 1967bab..3aeb03e 100644 --- a/include/elio/time/timer.hpp +++ b/include/elio/time/timer.hpp @@ -1,11 +1,12 @@ #pragma once -#include +#include #include #include #include #include #include +#include namespace elio::time { @@ -31,21 +32,10 @@ class sleep_awaitable { } void await_suspend(std::coroutine_handle<> awaiter) { - // Get io_context from scheduler or use provided one + // Get io_context from current worker or use provided one io::io_context* ctx = ctx_; if (!ctx) { - auto* sched = runtime::scheduler::current(); - if (sched) { - ctx = sched->get_io_context(); - } - } - - if (!ctx) { - // No io_context available, fall back to thread sleep - ELIO_LOG_DEBUG("sleep_awaitable: no io_context, using thread sleep"); - std::this_thread::sleep_for(std::chrono::nanoseconds(duration_ns_)); - awaiter.resume(); - return; + ctx = &io::current_io_context(); } // Use io_context timeout mechanism @@ -111,43 +101,44 @@ class cancellable_sleep_awaitable { return; } - // Get io_context from scheduler or use provided one + // Get io_context from current worker or use provided one io::io_context* ctx = ctx_; if (!ctx) { - auto* sched = runtime::scheduler::current(); - if (sched) { - ctx = sched->get_io_context(); - } + ctx = &io::current_io_context(); } + ctx_ = ctx; // Save for later cancel - if (!ctx) { - // No io_context available - check cancellation in a loop - ELIO_LOG_DEBUG("cancellable_sleep: no io_context, using polling sleep"); - auto end_time = std::chrono::steady_clock::now() + - std::chrono::nanoseconds(duration_ns_); - while (std::chrono::steady_clock::now() < end_time) { - if (token_.is_cancelled()) { - cancelled_ = true; - break; - } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - awaiter.resume(); - return; - } + // Save worker pointer for thread-safe cancellation + worker_ = runtime::worker_thread::current(); + + // Create a fire-and-forget cancel executor coroutine + // This will be scheduled on the worker thread to execute the actual cancel + cancel_executor_handle_ = create_cancel_executor(); // Register cancellation callback before setting up the timer - // The callback will cancel the pending timeout operation - cancel_registration_ = token_.on_cancel([this, ctx]() { + // The callback schedules the cancel executor on the worker thread + // This ensures io_uring operations are only called from the owning thread + cancel_registration_ = token_.on_cancel([this]() { cancelled_ = true; - // Cancel the pending timeout operation - ctx->cancel(awaiter_.address()); + // Schedule cancel execution on the correct worker thread + // This is 
thread-safe because schedule() uses lock-free MPSC queue + if (worker_ && cancel_executor_handle_) { + // Mark as handed off BEFORE scheduling - this prevents double-free + // since the coroutine will auto-destroy via final_suspend returning suspend_never + cancel_executor_handed_off_.store(true, std::memory_order_release); + worker_->schedule(cancel_executor_handle_); + } }); // Check again after registration (in case cancelled between check and register) if (token_.is_cancelled()) { cancel_registration_.unregister(); cancelled_ = true; + // Destroy unused cancel executor + if (cancel_executor_handle_) { + cancel_executor_handle_.destroy(); + cancel_executor_handle_ = nullptr; + } awaiter.resume(); return; } @@ -160,6 +151,11 @@ class cancellable_sleep_awaitable { if (!ctx->prepare(req)) { cancel_registration_.unregister(); + // Destroy unused cancel executor + if (cancel_executor_handle_) { + cancel_executor_handle_.destroy(); + cancel_executor_handle_ = nullptr; + } // Failed to prepare, fall back to polling sleep ELIO_LOG_WARNING("cancellable_sleep: failed to prepare timeout, using polling sleep"); auto end_time = std::chrono::steady_clock::now() + @@ -181,17 +177,57 @@ class cancellable_sleep_awaitable { cancel_result await_resume() noexcept { // Unregister callback to prevent use-after-free cancel_registration_.unregister(); + // Only destroy cancel executor if it wasn't handed off to worker thread + // If handed off, it will auto-destroy via final_suspend returning suspend_never + if (!cancel_executor_handed_off_.load(std::memory_order_acquire) && cancel_executor_handle_) { + cancel_executor_handle_.destroy(); + } + cancel_executor_handle_ = nullptr; // Check both the flag and the token state (for await_ready() early return case) return (cancelled_ || token_.is_cancelled()) ? 
cancel_result::cancelled : cancel_result::completed; } private: + /// Fire-and-forget coroutine that executes the actual cancel operation + /// This runs on the worker thread that owns the io_context + struct cancel_executor { + struct promise_type { + cancellable_sleep_awaitable* self = nullptr; + + cancel_executor get_return_object() { + return {std::coroutine_handle::from_promise(*this)}; + } + std::suspend_always initial_suspend() noexcept { return {}; } + std::suspend_never final_suspend() noexcept { return {}; } // Self-destroy + void return_void() noexcept {} + void unhandled_exception() noexcept {} + }; + std::coroutine_handle handle; + }; + + std::coroutine_handle<> create_cancel_executor() { + auto executor = [](cancellable_sleep_awaitable* self) -> cancel_executor { + // Execute the actual cancel on the worker thread + if (self->ctx_) { + self->ctx_->cancel(self->awaiter_.address()); + } + co_return; + }(this); + + // Store self pointer in promise for access during execution + executor.handle.promise().self = this; + return executor.handle; + } + io::io_context* ctx_ = nullptr; + runtime::worker_thread* worker_ = nullptr; int64_t duration_ns_; coro::cancel_token token_; coro::cancel_token::registration cancel_registration_; std::coroutine_handle<> awaiter_; + std::coroutine_handle<> cancel_executor_handle_; + std::atomic cancel_executor_handed_off_{false}; bool cancelled_ = false; }; diff --git a/include/elio/tls/tls_stream.hpp b/include/elio/tls/tls_stream.hpp index 6f4877a..7dc78ba 100644 --- a/include/elio/tls/tls_stream.hpp +++ b/include/elio/tls/tls_stream.hpp @@ -333,14 +333,13 @@ class tls_stream { /// Connect to a TLS server /// @param ctx TLS context (client mode) -/// @param io_ctx I/O context for async operations /// @param host Hostname to connect to /// @param port Port to connect to /// @return TLS stream on success, std::nullopt on error (check errno) inline coro::task> -tls_connect(tls_context& ctx, io::io_context& io_ctx, std::string_view host, uint16_t port) { +tls_connect(tls_context& ctx, std::string_view host, uint16_t port) { // First establish TCP connection - auto tcp_result = co_await net::tcp_connect(io_ctx, host, port); + auto tcp_result = co_await net::tcp_connect(host, port); if (!tcp_result) { co_return std::nullopt; } diff --git a/tests/unit/test_io.cpp b/tests/unit/test_io.cpp index 35982c2..f973b9e 100644 --- a/tests/unit/test_io.cpp +++ b/tests/unit/test_io.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -17,6 +18,7 @@ using namespace elio::io; using namespace elio::coro; +using namespace elio::runtime; TEST_CASE("io_context creation", "[io][context]") { SECTION("default constructor uses auto-detection") { @@ -83,36 +85,36 @@ TEST_CASE("Pipe read/write with epoll", "[io][epoll][pipe]") { fcntl(pipefd[0], F_SETFL, O_NONBLOCK); fcntl(pipefd[1], F_SETFL, O_NONBLOCK); - io_context ctx(io_context::backend_type::epoll); - // Write some data synchronously first const char* test_data = "Hello, Elio!"; ssize_t written = write(pipefd[1], test_data, strlen(test_data)); REQUIRE(written == static_cast(strlen(test_data))); - // Read using epoll backend + // Read using scheduler (coroutines run on worker threads with their own io_context) char buffer[64] = {0}; std::atomic completed{false}; io_result read_result{}; + scheduler sched(1); + sched.start(); + // Create a simple test coroutine auto read_coro = [&]() -> task { - auto result = co_await async_read(ctx, pipefd[0], buffer, sizeof(buffer) - 1); + auto result = co_await 
async_read(pipefd[0], buffer, sizeof(buffer) - 1); read_result = result; completed = true; }; auto t = read_coro(); - auto handle = t.handle(); - - // Start the coroutine - handle.resume(); + sched.spawn(t.release()); - // Poll for completion + // Wait for completion for (int i = 0; i < 100 && !completed; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(completed); REQUIRE(read_result.success()); REQUIRE(read_result.bytes_transferred() == static_cast(strlen(test_data))); @@ -158,32 +160,35 @@ TEST_CASE("Socket pair with epoll", "[io][epoll][socket]") { int sv[2]; REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv) == 0); - io_context ctx(io_context::backend_type::epoll); - const char* msg = "Socket test message"; // Send on one end ssize_t sent = send(sv[0], msg, strlen(msg), 0); REQUIRE(sent == static_cast(strlen(msg))); - // Receive on the other end using async + // Receive on the other end using scheduler char buffer[64] = {0}; std::atomic completed{false}; io_result recv_result{}; + scheduler sched(1); + sched.start(); + auto recv_coro = [&]() -> task { - auto result = co_await async_recv(ctx, sv[1], buffer, sizeof(buffer) - 1); + auto result = co_await async_recv(sv[1], buffer, sizeof(buffer) - 1); recv_result = result; completed = true; }; auto t = recv_coro(); - t.handle().resume(); + sched.spawn(t.release()); for (int i = 0; i < 100 && !completed; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(completed); REQUIRE(recv_result.success()); REQUIRE(recv_result.bytes_transferred() == static_cast(strlen(msg))); @@ -217,39 +222,37 @@ TEST_CASE("Cancel operation", "[io][epoll][cancel]") { int sv[2]; REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv) == 0); - io_context ctx(io_context::backend_type::epoll); - // Start a read that won't complete (no data sent) char buffer[64]; + std::atomic started{false}; std::atomic completed{false}; io_result recv_result{}; - void* cancel_key = nullptr; + + scheduler sched(1); + sched.start(); auto recv_coro = [&]() -> task { - auto result = co_await async_recv(ctx, sv[1], buffer, sizeof(buffer)); + started = true; + auto result = co_await async_recv(sv[1], buffer, sizeof(buffer)); recv_result = result; completed = true; }; auto t = recv_coro(); - auto handle = t.handle(); - cancel_key = handle.address(); - - // Start the coroutine - handle.resume(); - - // Poll once to register - ctx.poll(std::chrono::milliseconds(1)); + sched.spawn(t.release()); - // Cancel the operation - bool cancelled = ctx.cancel(cancel_key); - (void)cancelled; // May or may not succeed depending on timing + // Wait for coroutine to start + for (int i = 0; i < 100 && !started; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } - // Poll to process cancellation - ctx.poll(std::chrono::milliseconds(10)); + // Give it a bit more time + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // Note: cancel behavior depends on backend implementation - // Just verify we don't crash + // Just verify we don't crash on shutdown with pending operation + + sched.shutdown(); close(sv[0]); close(sv[1]); @@ -260,8 +263,6 @@ TEST_CASE("Multiple concurrent operations", "[io][epoll][concurrent]") { REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv1) == 0); REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv2) == 0); - io_context 
ctx(io_context::backend_type::epoll); - const char* msg1 = "Message 1"; const char* msg2 = "Message 2"; @@ -273,27 +274,32 @@ TEST_CASE("Multiple concurrent operations", "[io][epoll][concurrent]") { char buffer2[64] = {0}; std::atomic completed{0}; + scheduler sched(2); // 2 workers for concurrent operations + sched.start(); + auto recv_coro1 = [&]() -> task { - co_await async_recv(ctx, sv1[1], buffer1, sizeof(buffer1) - 1); + co_await async_recv(sv1[1], buffer1, sizeof(buffer1) - 1); completed++; }; auto recv_coro2 = [&]() -> task { - co_await async_recv(ctx, sv2[1], buffer2, sizeof(buffer2) - 1); + co_await async_recv(sv2[1], buffer2, sizeof(buffer2) - 1); completed++; }; auto t1 = recv_coro1(); auto t2 = recv_coro2(); - t1.handle().resume(); - t2.handle().resume(); + sched.spawn(t1.release()); + sched.spawn(t2.release()); - // Poll until both complete + // Wait until both complete for (int i = 0; i < 100 && completed < 2; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(completed == 2); REQUIRE(std::string(buffer1) == msg1); REQUIRE(std::string(buffer2) == msg2); @@ -313,45 +319,51 @@ TEST_CASE("Default io_context singleton", "[io][context][singleton]") { } TEST_CASE("epoll_backend registers fd before data available", "[io][epoll][registration]") { - // This test verifies that async operations are properly registered with epoll - // even when no data is immediately available. This catches use-after-move bugs - // in the prepare() function. + // This test verifies that async operations work correctly when data + // is not immediately available. int sv[2]; REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv) == 0); - io_context ctx(io_context::backend_type::epoll); - char buffer[64] = {0}; + std::atomic started{false}; std::atomic completed{false}; io_result recv_result{}; + scheduler sched(1); + sched.start(); + auto recv_coro = [&]() -> task { - auto result = co_await async_recv(ctx, sv[1], buffer, sizeof(buffer) - 1); + started = true; + auto result = co_await async_recv(sv[1], buffer, sizeof(buffer) - 1); recv_result = result; completed = true; }; auto t = recv_coro(); - t.handle().resume(); + sched.spawn(t.release()); - // Poll once to ensure the operation is registered - ctx.poll(std::chrono::milliseconds(1)); + // Wait for coroutine to start and register the operation + for (int i = 0; i < 100 && !started; ++i) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + REQUIRE(started); - // Verify the operation is pending (fd should be registered with epoll) - REQUIRE(ctx.has_pending()); - REQUIRE(ctx.pending_count() >= 1); + // Give the I/O operation time to be registered + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // Now send data - this should trigger the read to complete const char* msg = "delayed message"; ssize_t sent = send(sv[0], msg, strlen(msg), 0); REQUIRE(sent == static_cast(strlen(msg))); - // Poll until completion + // Wait for completion for (int i = 0; i < 100 && !completed; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + // The read should have completed successfully REQUIRE(completed); REQUIRE(recv_result.success()); @@ -368,51 +380,52 @@ TEST_CASE("epoll_backend handles multiple pending ops on same fd", "[io][epoll][ int sv[2]; REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv) == 0); - io_context 
ctx(io_context::backend_type::epoll); - char buffer1[32] = {0}; char buffer2[32] = {0}; std::atomic completed{0}; + scheduler sched(1); + sched.start(); + // Start two recv operations on the same fd auto recv_coro1 = [&]() -> task { - co_await async_recv(ctx, sv[1], buffer1, sizeof(buffer1) - 1); + co_await async_recv(sv[1], buffer1, sizeof(buffer1) - 1); completed++; }; auto recv_coro2 = [&]() -> task { - co_await async_recv(ctx, sv[1], buffer2, sizeof(buffer2) - 1); + co_await async_recv(sv[1], buffer2, sizeof(buffer2) - 1); completed++; }; auto t1 = recv_coro1(); auto t2 = recv_coro2(); - t1.handle().resume(); - t2.handle().resume(); + sched.spawn(t1.release()); + sched.spawn(t2.release()); - // Register both - ctx.poll(std::chrono::milliseconds(1)); - - REQUIRE(ctx.pending_count() >= 2); + // Give operations time to be registered + std::this_thread::sleep_for(std::chrono::milliseconds(50)); // Send enough data for both reads const char* msg1 = "first"; const char* msg2 = "second"; send(sv[0], msg1, strlen(msg1), 0); - // Poll to complete first read + // Wait for first read for (int i = 0; i < 100 && completed < 1; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } send(sv[0], msg2, strlen(msg2), 0); - // Poll to complete second read + // Wait for second read for (int i = 0; i < 100 && completed < 2; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(completed == 2); close(sv[0]); @@ -420,31 +433,34 @@ TEST_CASE("epoll_backend handles multiple pending ops on same fd", "[io][epoll][ } TEST_CASE("epoll_backend write operation registration", "[io][epoll][write]") { - // Verify write operations are properly registered + // Verify write operations work correctly int sv[2]; REQUIRE(socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0, sv) == 0); - io_context ctx(io_context::backend_type::epoll); - const char* msg = "write test data"; std::atomic completed{false}; io_result send_result{}; + scheduler sched(1); + sched.start(); + auto send_coro = [&]() -> task { - auto result = co_await async_send(ctx, sv[0], msg, strlen(msg)); + auto result = co_await async_send(sv[0], msg, strlen(msg)); send_result = result; completed = true; }; auto t = send_coro(); - t.handle().resume(); + sched.spawn(t.release()); - // Poll for completion + // Wait for completion for (int i = 0; i < 100 && !completed; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(completed); REQUIRE(send_result.success()); REQUIRE(send_result.bytes_transferred() == static_cast(strlen(msg))); @@ -512,13 +528,11 @@ TEST_CASE("unix_address basic operations", "[uds][address]") { } TEST_CASE("UDS listener bind and accept", "[uds][listener]") { - io_context ctx(io_context::backend_type::epoll); - // Use abstract socket to avoid filesystem cleanup issues auto addr = unix_address::abstract("elio_test_listener_" + std::to_string(getpid())); SECTION("bind creates listener") { - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, default_io_context()); REQUIRE(listener.has_value()); REQUIRE(listener->is_valid()); REQUIRE(listener->fd() >= 0); @@ -526,14 +540,14 @@ TEST_CASE("UDS listener bind and accept", "[uds][listener]") { } SECTION("accept returns connection") { - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, 
default_io_context()); REQUIRE(listener.has_value()); // Create a client connection in a separate thread std::atomic client_connected{false}; std::thread client_thread([&]() { // Wait a bit for the accept to be registered - std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); int client_fd = socket(AF_UNIX, SOCK_STREAM, 0); REQUIRE(client_fd >= 0); @@ -552,6 +566,9 @@ TEST_CASE("UDS listener bind and accept", "[uds][listener]") { std::atomic accepted{false}; std::optional accepted_stream; + scheduler sched(1); + sched.start(); + auto accept_coro = [&]() -> task { auto stream = co_await listener->accept(); accepted_stream = std::move(stream); @@ -559,13 +576,15 @@ TEST_CASE("UDS listener bind and accept", "[uds][listener]") { }; auto t = accept_coro(); - t.handle().resume(); + sched.spawn(t.release()); - // Poll for completion + // Wait for completion for (int i = 0; i < 200 && !accepted; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(accepted); REQUIRE(accepted_stream.has_value()); REQUIRE(accepted_stream->is_valid()); @@ -575,45 +594,47 @@ TEST_CASE("UDS listener bind and accept", "[uds][listener]") { } TEST_CASE("UDS connect", "[uds][connect]") { - io_context ctx(io_context::backend_type::epoll); - auto addr = unix_address::abstract("elio_test_connect_" + std::to_string(getpid())); // Create server listener - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, default_io_context()); REQUIRE(listener.has_value()); // Start accept in background std::atomic server_accepted{false}; std::optional server_stream; + std::atomic client_connected{false}; + std::optional client_stream; + + scheduler sched(2); + sched.start(); + auto accept_coro = [&]() -> task { auto stream = co_await listener->accept(); server_stream = std::move(stream); server_accepted = true; }; - auto accept_task = accept_coro(); - accept_task.handle().resume(); - - // Connect client - std::atomic client_connected{false}; - std::optional client_stream; - auto connect_coro = [&]() -> task { - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); client_stream = std::move(stream); client_connected = true; }; + auto accept_task = accept_coro(); auto connect_task = connect_coro(); - connect_task.handle().resume(); - // Poll until both complete + sched.spawn(accept_task.release()); + sched.spawn(connect_task.release()); + + // Wait until both complete for (int i = 0; i < 200 && (!server_accepted || !client_connected); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(server_accepted); REQUIRE(client_connected); REQUIRE(server_stream.has_value()); @@ -623,18 +644,19 @@ TEST_CASE("UDS connect", "[uds][connect]") { } TEST_CASE("UDS stream read/write", "[uds][stream]") { - io_context ctx(io_context::backend_type::epoll); - auto addr = unix_address::abstract("elio_test_rw_" + std::to_string(getpid())); // Create server and client - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, default_io_context()); REQUIRE(listener.has_value()); std::optional server_stream; std::optional client_stream; std::atomic setup_complete{0}; + scheduler sched(2); + sched.start(); + auto accept_coro = [&]() -> task { auto stream = co_await listener->accept(); server_stream = 
std::move(stream); @@ -642,18 +664,18 @@ TEST_CASE("UDS stream read/write", "[uds][stream]") { }; auto connect_coro = [&]() -> task { - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); client_stream = std::move(stream); setup_complete++; }; auto accept_task = accept_coro(); auto connect_task = connect_coro(); - accept_task.handle().resume(); - connect_task.handle().resume(); + sched.spawn(accept_task.release()); + sched.spawn(connect_task.release()); for (int i = 0; i < 200 && setup_complete < 2; ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } REQUIRE(setup_complete == 2); @@ -680,11 +702,11 @@ TEST_CASE("UDS stream read/write", "[uds][stream]") { auto write_task = write_coro(); auto read_task = read_coro(); - write_task.handle().resume(); - read_task.handle().resume(); + sched.spawn(write_task.release()); + sched.spawn(read_task.release()); for (int i = 0; i < 200 && (!write_done || !read_done); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } REQUIRE(write_done); @@ -716,11 +738,11 @@ TEST_CASE("UDS stream read/write", "[uds][stream]") { auto write_task = write_coro(); auto read_task = read_coro(); - write_task.handle().resume(); - read_task.handle().resume(); + sched.spawn(write_task.release()); + sched.spawn(read_task.release()); for (int i = 0; i < 200 && (!write_done || !read_done); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } REQUIRE(write_done); @@ -731,14 +753,14 @@ TEST_CASE("UDS stream read/write", "[uds][stream]") { REQUIRE(read_result.bytes_transferred() == static_cast(strlen(msg))); REQUIRE(std::string(buffer) == msg); } + + sched.shutdown(); } TEST_CASE("UDS multiple concurrent connections", "[uds][concurrent]") { - io_context ctx(io_context::backend_type::epoll); - auto addr = unix_address::abstract("elio_test_concurrent_" + std::to_string(getpid())); - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, default_io_context()); REQUIRE(listener.has_value()); constexpr int NUM_CLIENTS = 3; @@ -747,7 +769,10 @@ TEST_CASE("UDS multiple concurrent connections", "[uds][concurrent]") { std::atomic accepts_done{0}; std::atomic connects_done{0}; - // Accept coroutines - use array to avoid vector reallocation issues + scheduler sched(4); + sched.start(); + + // Accept coroutines auto accept0 = [&]() -> task { auto stream = co_await listener->accept(); server_streams[0] = std::move(stream); @@ -766,17 +791,17 @@ TEST_CASE("UDS multiple concurrent connections", "[uds][concurrent]") { // Connect coroutines auto connect0 = [&]() -> task { - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); client_streams[0] = std::move(stream); connects_done++; }; auto connect1 = [&]() -> task { - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); client_streams[1] = std::move(stream); connects_done++; }; auto connect2 = [&]() -> task { - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); client_streams[2] = std::move(stream); connects_done++; }; @@ -785,18 +810,20 @@ TEST_CASE("UDS multiple concurrent connections", "[uds][concurrent]") { auto c0 = connect0(); auto c1 = connect1(); auto c2 = connect2(); // Start all coroutines - a0.handle().resume(); - a1.handle().resume(); - a2.handle().resume(); - 
c0.handle().resume(); - c1.handle().resume(); - c2.handle().resume(); - - // Poll until all connections are made + sched.spawn(a0.release()); + sched.spawn(a1.release()); + sched.spawn(a2.release()); + sched.spawn(c0.release()); + sched.spawn(c1.release()); + sched.spawn(c2.release()); + + // Wait until all connections are made for (int i = 0; i < 500 && (accepts_done < NUM_CLIENTS || connects_done < NUM_CLIENTS); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(accepts_done == NUM_CLIENTS); REQUIRE(connects_done == NUM_CLIENTS); @@ -809,8 +836,6 @@ TEST_CASE("UDS multiple concurrent connections", "[uds][concurrent]") { } TEST_CASE("UDS filesystem socket", "[uds][filesystem]") { - io_context ctx(io_context::backend_type::epoll); - // Use filesystem socket std::string path = "/tmp/elio_test_fs_" + std::to_string(getpid()) + ".sock"; unix_address addr(path); @@ -818,7 +843,7 @@ TEST_CASE("UDS filesystem socket", "[uds][filesystem]") { // Ensure socket file doesn't exist ::unlink(path.c_str()); - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, default_io_context()); REQUIRE(listener.has_value()); // Socket file should exist @@ -830,16 +855,18 @@ TEST_CASE("UDS filesystem socket", "[uds][filesystem]") { std::atomic connected{false}; std::optional client_stream; + std::atomic accepted{false}; + std::optional server_stream; + + scheduler sched(2); + sched.start(); + auto connect_coro = [&]() -> task { - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); client_stream = std::move(stream); connected = true; }; - // Accept on server - std::atomic accepted{false}; - std::optional server_stream; - auto accept_coro = [&]() -> task { auto stream = co_await listener->accept(); server_stream = std::move(stream); @@ -848,13 +875,15 @@ TEST_CASE("UDS filesystem socket", "[uds][filesystem]") { auto accept_task = accept_coro(); auto connect_task = connect_coro(); - accept_task.handle().resume(); - connect_task.handle().resume(); + sched.spawn(accept_task.release()); + sched.spawn(connect_task.release()); for (int i = 0; i < 200 && (!connected || !accepted); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(connected); REQUIRE(accepted); @@ -864,11 +893,9 @@ TEST_CASE("UDS filesystem socket", "[uds][filesystem]") { } TEST_CASE("UDS echo test", "[uds][echo]") { - io_context ctx(io_context::backend_type::epoll); - auto addr = unix_address::abstract("elio_test_echo_" + std::to_string(getpid())); - auto listener = uds_listener::bind(addr, ctx); + auto listener = uds_listener::bind(addr, default_io_context()); REQUIRE(listener.has_value()); // Use a simpler pattern: thread for client, coroutine for server @@ -879,6 +906,9 @@ TEST_CASE("UDS echo test", "[uds][echo]") { int server_bytes = 0; int client_bytes = 0; + scheduler sched(1); + sched.start(); + // Server coroutine auto server_coro = [&]() -> task { auto stream = co_await listener->accept(); @@ -898,12 +928,12 @@ TEST_CASE("UDS echo test", "[uds][echo]") { }; auto server_task = server_coro(); - server_task.handle().resume(); + sched.spawn(server_task.release()); // Client in a thread (to avoid coroutine complexity) std::thread client_thread([&]() { // Wait briefly for server to be ready - std::this_thread::sleep_for(std::chrono::milliseconds(10)); + 
std::this_thread::sleep_for(std::chrono::milliseconds(50)); int fd = socket(AF_UNIX, SOCK_STREAM, 0); if (fd < 0) { @@ -930,12 +960,13 @@ TEST_CASE("UDS echo test", "[uds][echo]") { client_done = true; }); - // Poll until both complete + // Wait until both complete for (int i = 0; i < 500 && (!server_done || !client_done); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } client_thread.join(); + sched.shutdown(); REQUIRE(server_done); REQUIRE(client_done); @@ -1097,11 +1128,9 @@ TEST_CASE("socket_address variant operations", "[tcp][address][socket_address]") } TEST_CASE("TCP IPv6 listener and connect", "[tcp][ipv6][integration]") { - io_context ctx(io_context::backend_type::epoll); - SECTION("IPv6 listener binds successfully") { // Use IPv6 loopback to avoid network issues - auto listener = tcp_listener::bind(ipv6_address("::1", 0), ctx); + auto listener = tcp_listener::bind(ipv6_address("::1", 0), default_io_context()); REQUIRE(listener.has_value()); REQUIRE(listener->is_valid()); REQUIRE(listener->local_address().family() == AF_INET6); @@ -1110,7 +1139,7 @@ TEST_CASE("TCP IPv6 listener and connect", "[tcp][ipv6][integration]") { SECTION("IPv6 accept and connect") { // Create listener on IPv6 loopback - auto listener = tcp_listener::bind(ipv6_address("::1", 0), ctx); + auto listener = tcp_listener::bind(ipv6_address("::1", 0), default_io_context()); REQUIRE(listener.has_value()); // Get the assigned port @@ -1122,6 +1151,9 @@ TEST_CASE("TCP IPv6 listener and connect", "[tcp][ipv6][integration]") { std::optional server_stream; std::optional client_stream; + scheduler sched(2); + sched.start(); + auto accept_coro = [&]() -> task { auto stream = co_await listener->accept(); server_stream = std::move(stream); @@ -1129,20 +1161,22 @@ TEST_CASE("TCP IPv6 listener and connect", "[tcp][ipv6][integration]") { }; auto connect_coro = [&]() -> task { - auto stream = co_await tcp_connect(ctx, ipv6_address("::1", port)); + auto stream = co_await tcp_connect(ipv6_address("::1", port)); client_stream = std::move(stream); connected = true; }; auto accept_task = accept_coro(); auto connect_task = connect_coro(); - accept_task.handle().resume(); - connect_task.handle().resume(); + sched.spawn(accept_task.release()); + sched.spawn(connect_task.release()); for (int i = 0; i < 200 && (!accepted || !connected); ++i) { - ctx.poll(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + sched.shutdown(); + REQUIRE(accepted); REQUIRE(connected); REQUIRE(server_stream.has_value()); diff --git a/tests/unit/test_signalfd.cpp b/tests/unit/test_signalfd.cpp index a8f13a8..070e857 100644 --- a/tests/unit/test_signalfd.cpp +++ b/tests/unit/test_signalfd.cpp @@ -11,6 +11,7 @@ using namespace elio::signal; using namespace elio::coro; using namespace elio::runtime; +using namespace elio::io; using namespace std::chrono_literals; TEST_CASE("signal_set basic operations", "[signal][signal_set]") { @@ -197,10 +198,8 @@ TEST_CASE("signal_fd async wait", "[signal][signal_fd]") { sigset_t old_mask; sigs.block(&old_mask); - elio::io::io_context ctx; - auto wait_task = [&]() -> task { - signal_fd sigfd(sigs, ctx, false); // Don't re-block, already blocked + signal_fd sigfd(sigs, current_io_context(), false); // Don't re-block, already blocked auto info = co_await sigfd.wait(); if (info) { @@ -210,7 +209,6 @@ TEST_CASE("signal_fd async wait", "[signal][signal_fd]") { }; scheduler sched(1); - sched.set_io_context(&ctx); 
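For reference, a minimal sketch of the pattern these signal tests now follow: no io_context is created or attached to the scheduler, and `signal_fd`/`wait_signal` fall back to the current worker's context. The `waiter` lambda, the `coro::task<void>` spelling, the elio header set, and the bounded wait loop are illustrative assumptions, not part of the patch:

```cpp
// Sketch only: signal handling with per-worker io_contexts (no set_io_context()).
// Assumes the signal_set / wait_signal / scheduler APIs shown elsewhere in this patch,
// plus the relevant elio headers (signal, runtime, coro).
#include <atomic>
#include <chrono>
#include <signal.h>
#include <thread>
#include <unistd.h>

int main() {
    using namespace elio;

    signal::signal_set sigs;
    sigs.add(SIGUSR1);
    sigs.block_all_threads();        // block before worker threads exist

    runtime::scheduler sched(1);
    sched.start();                   // no sched.set_io_context(...) any more

    std::atomic<bool> received{false};
    auto waiter = [&]() -> coro::task<void> {
        // wait_signal() defaults to io::current_io_context(), i.e. this worker's context
        auto info = co_await signal::wait_signal(sigs);
        received = (info.signo == SIGUSR1);
    };
    auto t = waiter();
    sched.spawn(t.release());

    ::kill(::getpid(), SIGUSR1);     // process-directed signal, visible to the worker's signalfd

    for (int i = 0; i < 100 && !received; ++i)
        std::this_thread::sleep_for(std::chrono::milliseconds(10));

    sched.shutdown();
    return received ? 0 : 1;
}
```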
sched.start(); { @@ -249,10 +247,8 @@ TEST_CASE("signal_fd multiple signals", "[signal][signal_fd]") { sigset_t old_mask; sigs.block(&old_mask); - elio::io::io_context ctx; - auto wait_task = [&]() -> task { - signal_fd sigfd(sigs, ctx, false); // Don't re-block, already blocked + signal_fd sigfd(sigs, current_io_context(), false); // Don't re-block, already blocked for (int i = 0; i < 2; ++i) { auto info = co_await sigfd.wait(); @@ -265,7 +261,6 @@ TEST_CASE("signal_fd multiple signals", "[signal][signal_fd]") { }; scheduler sched(1); - sched.set_io_context(&ctx); sched.start(); { @@ -378,16 +373,13 @@ TEST_CASE("wait_signal convenience function", "[signal][wait_signal]") { sigset_t old_mask; sigs.block(&old_mask); - elio::io::io_context ctx; - auto wait_task = [&]() -> task { - auto info = co_await wait_signal(sigs, ctx, false); // Don't re-block, already blocked + auto info = co_await wait_signal(sigs, current_io_context(), false); // Don't re-block, already blocked REQUIRE(info.signo == SIGUSR1); received = true; }; scheduler sched(1); - sched.set_io_context(&ctx); sched.start(); { diff --git a/wiki/API-Reference.md b/wiki/API-Reference.md index ce58440..e2cac8b 100644 --- a/wiki/API-Reference.md +++ b/wiki/API-Reference.md @@ -234,9 +234,6 @@ public: template void spawn(Task&& t); // Accepts any type with release() method - // Set the I/O context for workers to poll - void set_io_context(io::io_context* ctx); - // Get number of worker threads size_t worker_count() const noexcept; @@ -250,11 +247,7 @@ public: runtime::scheduler sched(4); sched.start(); -// Old API (still works) -auto t = my_coroutine(); -sched.spawn(t.release()); - -// New simplified API +// Spawn tasks directly sched.spawn(my_coroutine()); // Accepts task directly sched.shutdown(); @@ -303,7 +296,6 @@ Configuration for running async tasks. ```cpp struct run_config { size_t num_threads = 0; // 0 = hardware concurrency - io::io_context* io_context = nullptr; // nullptr = create default }; ``` @@ -751,7 +743,7 @@ public: }; // Connect to address (awaitable, returns std::optional) -/* awaitable */ tcp_connect(const ipv4_address& addr, io_context& ctx); +/* awaitable */ tcp_connect(const ipv4_address& addr); ``` --- diff --git a/wiki/Core-Concepts.md b/wiki/Core-Concepts.md index 0cbcc17..ecfdfb6 100644 --- a/wiki/Core-Concepts.md +++ b/wiki/Core-Concepts.md @@ -215,18 +215,15 @@ This design ensures: ## I/O Context -The I/O context manages async I/O operations. +The I/O context manages async I/O operations. Each worker thread has its own I/O context for lock-free operation. ### Using the Default Context ```cpp #include -// Get the global I/O context +// Get the global I/O context (for TCP listeners, etc.) 
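// (Inside a coroutine running on a worker thread, io::current_io_context() returns that
//  worker's own per-worker context; the async I/O helpers in this patch use it by default,
//  so most coroutine code never needs to pass a context explicitly.)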
auto& ctx = io::default_io_context(); - -// Associate with scheduler -sched.set_io_context(&ctx); ``` ### I/O Backends diff --git a/wiki/Examples.md b/wiki/Examples.md index beff0fa..b7dfc54 100644 --- a/wiki/Examples.md +++ b/wiki/Examples.md @@ -146,7 +146,6 @@ int main() { sigs.block_all_threads(); runtime::scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); // Spawn signal handler coroutine @@ -249,7 +248,6 @@ int main() { // auto addr = net::unix_address::abstract("echo_server"); runtime::scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); // Spawn signal handler @@ -278,11 +276,9 @@ A Unix Domain Socket client that connects to a UDS server: using namespace elio; coro::task client_main(const net::unix_address& addr) { - auto& ctx = io::default_io_context(); - ELIO_LOG_INFO("Connecting to {}...", addr.to_string()); - auto stream = co_await net::uds_connect(ctx, addr); + auto stream = co_await net::uds_connect(addr); if (!stream) { ELIO_LOG_ERROR("Connect failed: {}", strerror(errno)); co_return; @@ -312,7 +308,6 @@ int main() { // auto addr = net::unix_address::abstract("echo_server"); runtime::scheduler sched(2); - sched.set_io_context(&io::default_io_context()); sched.start(); std::atomic done{false}; @@ -477,7 +472,6 @@ coro::task router(request& req, response& resp) { int main() { runtime::scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); server_config config; diff --git a/wiki/Getting-Started.md b/wiki/Getting-Started.md index 3fb544b..aef6a34 100644 --- a/wiki/Getting-Started.md +++ b/wiki/Getting-Started.md @@ -32,8 +32,11 @@ cmake --build build cd build ctest --output-on-failure -# With AddressSanitizer +# With AddressSanitizer (memory safety) ./tests/elio_tests_asan + +# With ThreadSanitizer (thread safety) +./tests/elio_tests_tsan ``` ### CMake Integration diff --git a/wiki/Home.md b/wiki/Home.md index 76119a5..90ecde3 100644 --- a/wiki/Home.md +++ b/wiki/Home.md @@ -7,18 +7,18 @@ ## Features - **C++20 Coroutines**: First-class coroutine support with `co_await` and `co_return` -- **Multi-threaded Scheduler**: Work-stealing scheduler with configurable worker threads +- **Multi-threaded Scheduler**: Work-stealing scheduler with per-worker I/O contexts - **Async I/O Backends**: io_uring (preferred) and epoll fallback - **Signal Handling**: Coroutine-friendly signal handling via signalfd - **Synchronization Primitives**: mutex, shared_mutex, semaphore, event, channel -- **Timers**: sleep_for, sleep_until, yield +- **Timers**: sleep_for, sleep_until, yield with cancellation support - **TCP Networking**: Async TCP client/server with connection management - **HTTP/1.1**: Full HTTP client and server implementation - **TLS/HTTPS**: OpenSSL-based TLS with ALPN and certificate verification - **RPC Framework**: High-performance RPC with zero-copy serialization, checksums, and cleanup callbacks - **Hash Functions**: CRC32, SHA-1, and SHA-256 with incremental hashing support - **Header-only**: Easy integration - just include and go -- **CI/CD**: Automated testing with GitHub Actions +- **CI/CD**: Automated testing with ASAN and TSAN ## Requirements diff --git a/wiki/Networking.md b/wiki/Networking.md index bf707f3..fc3b647 100644 --- a/wiki/Networking.md +++ b/wiki/Networking.md @@ -47,10 +47,8 @@ coro::task server(uint16_t port, runtime::scheduler& sched) { ```cpp coro::task client(const std::string& host, uint16_t port) { - auto& ctx = io::default_io_context(); - // Connect to server 
(hostname is resolved automatically) - auto stream = co_await tcp_connect(ipv4_address(host, port), ctx); + auto stream = co_await tcp_connect(ipv4_address(host, port)); if (!stream) { ELIO_LOG_ERROR("Connect failed: {}", strerror(errno)); co_return; @@ -165,10 +163,8 @@ coro::task server(const unix_address& addr, runtime::scheduler& sched) { ```cpp coro::task client(const unix_address& addr) { - auto& ctx = io::default_io_context(); - // Connect to server - auto stream = co_await uds_connect(ctx, addr); + auto stream = co_await uds_connect(addr); if (!stream) { ELIO_LOG_ERROR("Connect failed: {}", strerror(errno)); co_return; @@ -410,14 +406,16 @@ coro::task advanced_h2_client(io::io_context& ctx) { using namespace elio::tls; -coro::task secure_connection(io::io_context& ctx) { +coro::task secure_connection() { + auto& ctx = io::default_io_context(); + // Create TLS context tls_context tls_ctx(tls_method::client); tls_ctx.use_default_verify_paths(); tls_ctx.set_verify_mode(true); // Connect TCP - auto tcp = co_await tcp_connect(ipv4_address("example.com", 443), ctx); + auto tcp = co_await tcp_connect(ipv4_address("example.com", 443)); if (!tcp) co_return; // Wrap with TLS diff --git a/wiki/WebSocket-SSE.md b/wiki/WebSocket-SSE.md index e5cba65..e0d59a9 100644 --- a/wiki/WebSocket-SSE.md +++ b/wiki/WebSocket-SSE.md @@ -49,7 +49,6 @@ int main() { ws_server srv(std::move(router)); runtime::scheduler sched(4); - sched.set_io_context(&io::default_io_context()); sched.start(); auto task = srv.listen(net::ipv4_address(8080),