diff --git a/helix-container/src/main.rs b/helix-container/src/main.rs index c9e57eba6..ebe7ea2fe 100644 --- a/helix-container/src/main.rs +++ b/helix-container/src/main.rs @@ -4,6 +4,7 @@ use helixdb::helix_gateway::{ gateway::{GatewayOpts, HelixGateway}, router::router::{HandlerFn, HandlerSubmission}, }; +use helixdb::helix_transport::tokio_transport::TokioTransport; use helixdb::helix_runtime::tokio_runtime::TokioRuntime; use inventory; use std::{collections::HashMap, sync::Arc}; @@ -72,7 +73,7 @@ async fn main() { println!("Routes: {:?}", routes.keys()); // create gateway - let gateway = HelixGateway::new( + let gateway = HelixGateway::::new( &format!("0.0.0.0:{}", port), graph, GatewayOpts::DEFAULT_POOL_SIZE, @@ -81,8 +82,8 @@ async fn main() { ).await; // start server println!("Starting server..."); - let a = gateway.connection_handler.accept_conns().await.unwrap(); - let b = a.await.unwrap(); + let handle = gateway.connection_handler.accept_conns().await.unwrap(); + handle.await; } diff --git a/helixdb/src/helix_gateway/connection/connection.rs b/helixdb/src/helix_gateway/connection/connection.rs index b2e0c1e6e..a9e8b5410 100644 --- a/helixdb/src/helix_gateway/connection/connection.rs +++ b/helixdb/src/helix_gateway/connection/connection.rs @@ -7,15 +7,15 @@ use std::{ collections::HashMap, sync::{Arc, Mutex}, }; -use tokio::net::TcpListener; +use crate::helix_transport::Transport; use crate::helix_runtime::AsyncRuntime; use crate::helix_gateway::{router::router::HelixRouter, thread_pool::thread_pool::ThreadPool}; -pub struct ConnectionHandler { +pub struct ConnectionHandler { pub address: String, pub active_connections: Arc>>, - pub thread_pool: ThreadPool, + pub thread_pool: ThreadPool, pub runtime: R, } @@ -25,7 +25,7 @@ pub struct ClientConnection { pub addr: SocketAddr, } -impl ConnectionHandler { +impl ConnectionHandler { pub fn new( address: &str, graph: Arc, @@ -36,37 +36,33 @@ impl ConnectionHandler { Ok(Self { address: address.to_string(), active_connections: Arc::new(Mutex::new(HashMap::new())), - thread_pool: ThreadPool::new(size, graph, Arc::new(router), runtime.clone())?, + thread_pool: ThreadPool::::new(size, graph, Arc::new(router), runtime.clone())?, runtime, }) } pub async fn accept_conns(&self) -> Result<::JoinHandle<()>, GraphError> { - // Create a new TcpListener for each accept_conns call - let listener = TcpListener::bind(&self.address).await.map_err(|e| { - eprintln!("Failed to bind to address {}: {}", self.address, e); - GraphError::GraphConnectionError("Failed to bind to address".to_string(), e) - })?; + // Bind transport listener + let listener = T::bind(&self.address) + .await + .map_err(|e| { + eprintln!("Failed to bind to address {}: {}", self.address, e); + GraphError::GraphConnectionError("Failed to bind to address".to_string(), e) + })?; // Log binding success to stderr since stdout might be buffered let active_connections = Arc::clone(&self.active_connections); let thread_pool_sender = self.thread_pool.sender.clone(); - let address = self.address.clone(); let runtime = self.runtime.clone(); let handle = runtime.spawn(async move { loop { - match listener.accept().await { + match T::accept(&listener).await { Ok((stream, addr)) => { - // Configure TCP stream - if let Err(e) = stream.set_nodelay(true) { - eprintln!("Failed to set TCP_NODELAY: {}", e); - } - // Create a client connection record let client_id = Uuid::new_v4().to_string(); let client = ClientConnection { diff --git a/helixdb/src/helix_gateway/gateway.rs b/helixdb/src/helix_gateway/gateway.rs index f33abb89c..afd243933 100644 --- a/helixdb/src/helix_gateway/gateway.rs +++ b/helixdb/src/helix_gateway/gateway.rs @@ -1,6 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use super::connection::connection::ConnectionHandler; +use crate::helix_transport::Transport; use crate::helix_runtime::AsyncRuntime; use crate::helix_engine::graph_core::graph_core::HelixGraphEngine; use super::router::router::{HandlerFn, HelixRouter}; @@ -11,21 +12,21 @@ impl GatewayOpts { pub const DEFAULT_POOL_SIZE: usize = 1024; } -pub struct HelixGateway { - pub connection_handler: ConnectionHandler, +pub struct HelixGateway { + pub connection_handler: ConnectionHandler, pub runtime: R, } -impl HelixGateway { +impl HelixGateway { pub async fn new( address: &str, graph: Arc, size: usize, routes: Option>, runtime: R, - ) -> HelixGateway { + ) -> HelixGateway { let router = HelixRouter::new(routes); - let connection_handler = ConnectionHandler::new(address, graph, size, router, runtime.clone()).unwrap(); + let connection_handler = ConnectionHandler::::new(address, graph, size, router, runtime.clone()).unwrap(); println!("Gateway created"); HelixGateway { connection_handler, runtime } } diff --git a/helixdb/src/helix_gateway/thread_pool/thread_pool.rs b/helixdb/src/helix_gateway/thread_pool/thread_pool.rs index ec7b419da..6ef858329 100644 --- a/helixdb/src/helix_gateway/thread_pool/thread_pool.rs +++ b/helixdb/src/helix_gateway/thread_pool/thread_pool.rs @@ -10,22 +10,23 @@ use crate::protocol::response::Response; extern crate tokio; -use tokio::net::TcpStream; +use crate::helix_transport::Transport; -pub struct Worker { +pub struct Worker { pub id: usize, pub handle: ::JoinHandle<()>, pub runtime: R, + _marker: std::marker::PhantomData, } -impl Worker { +impl Worker { fn new( id: usize, graph_access: Arc, router: Arc, - rx: Receiver, + rx: Receiver, runtime: R, - ) -> Worker { + ) -> Worker { let handle = runtime.spawn(async move { loop { let mut conn = match rx.recv_async().await { @@ -68,31 +69,31 @@ impl Worker { } }); - Worker { id, handle, runtime } + Worker { id, handle, runtime, _marker: std::marker::PhantomData } } } -pub struct ThreadPool { - pub sender: Sender, +pub struct ThreadPool { + pub sender: Sender, pub num_unused_workers: Mutex, pub num_used_workers: Mutex, - pub workers: Vec>, + pub workers: Vec>, pub runtime: R, } -impl ThreadPool { +impl ThreadPool { pub fn new( size: usize, graph: Arc, router: Arc, runtime: R, - ) -> Result, RouterError> { + ) -> Result, RouterError> { assert!( size > 0, "Expected number of threads in thread pool to be more than 0, got {}", size ); - let (tx, rx) = flume::unbounded::(); + let (tx, rx) = flume::unbounded::(); let mut workers = Vec::with_capacity(size); for id in 0..size { workers.push(Worker::new(id, Arc::clone(&graph), Arc::clone(&router), rx.clone(), runtime.clone())); diff --git a/helixdb/src/helix_runtime/mod.rs b/helixdb/src/helix_runtime/mod.rs index 8481fea41..6885b9be6 100644 --- a/helixdb/src/helix_runtime/mod.rs +++ b/helixdb/src/helix_runtime/mod.rs @@ -8,7 +8,9 @@ use std::pin::Pin; /// Production code uses a Tokio-backed implementation while tests can /// provide deterministic schedulers by implementing this trait. pub trait AsyncRuntime { - type JoinHandle: Future + Send + 'static; + type JoinHandle: Future + Send + 'static + where + T: Send + 'static; /// Spawn a future onto the runtime. fn spawn(&self, fut: F) -> Self::JoinHandle diff --git a/helixdb/src/helix_runtime/tokio_runtime.rs b/helixdb/src/helix_runtime/tokio_runtime.rs index 8e2a3175d..b993040f0 100644 --- a/helixdb/src/helix_runtime/tokio_runtime.rs +++ b/helixdb/src/helix_runtime/tokio_runtime.rs @@ -1,6 +1,7 @@ use super::AsyncRuntime; use std::future::Future; use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; /// Tokio based implementation of [`AsyncRuntime`]. @@ -8,17 +9,35 @@ use std::time::Duration; pub struct TokioRuntime; impl AsyncRuntime for TokioRuntime { - type JoinHandle = tokio::task::JoinHandle; + type JoinHandle = TokioJoinHandle + where + T: Send + 'static; fn spawn(&self, fut: F) -> Self::JoinHandle where F: Future + Send + 'static, T: Send + 'static, { - tokio::spawn(fut) + TokioJoinHandle(tokio::spawn(fut)) } fn sleep(&self, dur: Duration) -> Pin + Send>> { Box::pin(tokio::time::sleep(dur)) } } + +/// Wrapper around Tokio's [`JoinHandle`] that unwraps the result. +pub struct TokioJoinHandle(tokio::task::JoinHandle); + +impl Future for TokioJoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let inner = unsafe { self.map_unchecked_mut(|s| &mut s.0) }; + match inner.poll(cx) { + Poll::Ready(Ok(val)) => Poll::Ready(val), + Poll::Ready(Err(err)) => panic!("Join error: {}", err), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/helixdb/src/helix_transport/mod.rs b/helixdb/src/helix_transport/mod.rs new file mode 100644 index 000000000..f8dee31ec --- /dev/null +++ b/helixdb/src/helix_transport/mod.rs @@ -0,0 +1,25 @@ +pub mod tokio_transport; + +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Abstraction over network transport for HelixDB. +/// +/// The transport trait allows the gateway to be decoupled from a +/// concrete networking stack so that simulation tests can provide a +/// deterministic in-memory transport. +pub trait Transport { + type Listener: Send + Sync + 'static; + type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static; + + /// Bind a listener to the provided address. + fn bind<'a>(addr: &'a str) -> Pin> + Send + 'a>>; + + /// Accept the next incoming connection from a listener. + fn accept<'a>(listener: &'a Self::Listener) -> Pin> + Send + 'a>>; + + /// Connect to a remote address returning a stream. + fn connect<'a>(addr: &'a str) -> Pin> + Send + 'a>>; +} diff --git a/helixdb/src/helix_transport/tokio_transport.rs b/helixdb/src/helix_transport/tokio_transport.rs new file mode 100644 index 000000000..248229692 --- /dev/null +++ b/helixdb/src/helix_transport/tokio_transport.rs @@ -0,0 +1,27 @@ +use super::Transport; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, TcpStream}; + +/// Tokio based transport implementation using TCP sockets. +#[derive(Clone, Default)] +pub struct TokioTransport; + +impl Transport for TokioTransport { + type Listener = TcpListener; + type Stream = TcpStream; + + fn bind<'a>(addr: &'a str) -> Pin> + Send + 'a>> { + Box::pin(async move { TcpListener::bind(addr).await }) + } + + fn accept<'a>(listener: &'a Self::Listener) -> Pin> + Send + 'a>> { + Box::pin(async move { listener.accept().await }) + } + + fn connect<'a>(addr: &'a str) -> Pin> + Send + 'a>> { + Box::pin(async move { TcpStream::connect(addr).await }) + } +} diff --git a/helixdb/src/lib.rs b/helixdb/src/lib.rs index 728101eec..eea321429 100644 --- a/helixdb/src/lib.rs +++ b/helixdb/src/lib.rs @@ -6,3 +6,4 @@ pub mod helixc; pub mod ingestion_engine; pub mod protocol; pub mod helix_runtime; +pub mod helix_transport;