From 9ff232fa8a181c4e87214a87f8ddc2ef8997d74e Mon Sep 17 00:00:00 2001 From: Migo Date: Sun, 7 Dec 2025 08:37:31 +0400 Subject: [PATCH 1/5] framing added --- duva-client/src/broker/write_stream.rs | 2 + duva/src/adapters/io/tokio_stream.rs | 144 +++++++++++++++---------- 2 files changed, 87 insertions(+), 59 deletions(-) diff --git a/duva-client/src/broker/write_stream.rs b/duva-client/src/broker/write_stream.rs index 9015c0fc..9c0bacaa 100644 --- a/duva-client/src/broker/write_stream.rs +++ b/duva-client/src/broker/write_stream.rs @@ -7,6 +7,8 @@ pub struct ServerStreamWriter(pub(crate) OwnedWriteHalf); impl ServerStreamWriter { pub async fn write_all(&mut self, buf: &[u8]) -> anyhow::Result<()> { + let len = buf.len() as u32; + self.0.write_u32(len).await.context("Failed to write length prefix")?; self.0.write_all(buf).await.context("Failed to send command")?; self.0.flush().await.context("Failed to flush stream")?; Ok(()) diff --git a/duva/src/adapters/io/tokio_stream.rs b/duva/src/adapters/io/tokio_stream.rs index 9debe2fd..0272b3aa 100644 --- a/duva/src/adapters/io/tokio_stream.rs +++ b/duva/src/adapters/io/tokio_stream.rs @@ -12,40 +12,42 @@ use std::io::ErrorKind; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -const BUFFER_SIZE: usize = 512; const INITIAL_CAPACITY: usize = 1024; +// Arbitrary limit to prevent memory exhaustion. +const MAX_MSG_SIZE: usize = 4 * 1024 * 1024; // 4MB #[async_trait::async_trait] impl TReadBytes for T { - // TCP doesn't inherently delimit messages. - // The data arrives in a continuous stream of bytes. And - // we might not receive all the data in one go. - // So, we need to read the data in chunks until we have received all the data for the message. + // Reads a length-prefixed message from the stream. + // The protocol is: + // - 4 bytes (u32, big-endian) for the message length. + // - N bytes for the message body, where N is the length read. 
async fn read_bytes(&mut self, buffer: &mut BytesMut) -> Result<(), IoError> { - let mut temp_buffer = [0u8; BUFFER_SIZE]; - loop { - let bytes_read = - self.read(&mut temp_buffer).await.map_err(|err| io_error_from_kind(err.kind()))?; - - if bytes_read == 0 { - // read 0 bytes AND buffer is empty - connection closed - if buffer.is_empty() { - return Err(IoError::ConnectionAborted); - } - // read 0 bytes but buffer is not empty - end of message - return Ok(()); + let len = self.read_u32().await.map_err(|e| { + if e.kind() == ErrorKind::UnexpectedEof { + IoError::ConnectionAborted + } else { + io_error_from_kind(e.kind()) } + })? as usize; - // Extend the buffer with the newly read data - buffer.extend_from_slice(&temp_buffer[..bytes_read]); + if len > MAX_MSG_SIZE { + return Err(IoError::Custom(format!( + "Incoming message too large: {len} bytes, max is {MAX_MSG_SIZE}" + ))); + } - // If fewer bytes than the buffer size are read, it suggests that - // - The sender has sent all the data currently available for this message. - // - You have reached the end of the message. 
- if bytes_read < temp_buffer.len() { - break; - } + buffer.reserve(len); + let mut body_buffer = vec![0u8; len]; + if let Err(e) = self.read_exact(&mut body_buffer).await { + return if e.kind() == ErrorKind::UnexpectedEof { + Err(IoError::ConnectionAborted) + } else { + Err(io_error_from_kind(e.kind())) + }; } + + buffer.extend_from_slice(&body_buffer); Ok(()) } } @@ -73,7 +75,12 @@ impl TSer async fn serialized_write(&mut self, buf: impl bincode::Encode + Send) -> Result<(), IoError> { let encoded = bincode::encode_to_vec(buf, SERDE_CONFIG) .map_err(|e| IoError::Custom(e.to_string()))?; - self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind())) + + let len = encoded.len() as u32; + self.write_u32(len).await.map_err(|e| io_error_from_kind(e.kind()))?; + + self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind()))?; + self.flush().await.map_err(|e| io_error_from_kind(e.kind())) } } @@ -84,7 +91,12 @@ impl TSer async fn send(&mut self, msg: PeerMessage) -> Result<(), IoError> { let encoded = bincode::encode_to_vec(msg, SERDE_CONFIG) .map_err(|e| IoError::Custom(e.to_string()))?; - self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind())) + + let len = encoded.len() as u32; + self.write_u32(len).await.map_err(|e| io_error_from_kind(e.kind()))?; + + self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind()))?; + self.flush().await.map_err(|e| io_error_from_kind(e.kind())) } async fn send_connection_msg(&mut self, arg: &str) -> Result<(), IoError> { @@ -164,17 +176,17 @@ pub mod test_tokio_stream_impl { data: String, } - /// A mock that implements AsyncRead for testing + /// A mock that implements AsyncRead for testing by simulating a byte stream. #[derive(Debug)] struct MockAsyncStream { - chunks: Vec>, - current_chunk: usize, + data: Vec, + pos: usize, } + impl MockAsyncStream { - /// Creates a new mock stream from a vector of byte chunks. - /// Each inner Vec represents a single return from the `read` method. 
+ /// Creates a new mock stream from a vector of byte chunks, which are flattened. fn new(chunks: Vec>) -> Self { - MockAsyncStream { chunks, current_chunk: 0 } + MockAsyncStream { data: chunks.into_iter().flatten().collect(), pos: 0 } } } @@ -187,21 +199,16 @@ pub mod test_tokio_stream_impl { ) -> std::task::Poll> { let self_mut = self.get_mut(); - if self_mut.current_chunk >= self_mut.chunks.len() { - // All chunks have been read, simulate EOF (read 0 bytes) - return std::task::Poll::Ready(Ok(())); + if self_mut.pos >= self_mut.data.len() { + return std::task::Poll::Ready(Ok(())); // EOF } - let chunk = &self_mut.chunks[self_mut.current_chunk]; - let bytes_to_copy = std::cmp::min(buf.remaining(), chunk.len()); - - // Copy data into the ReadBuf - buf.put_slice(&chunk[..bytes_to_copy]); + let remaining_data = &self_mut.data[self_mut.pos..]; + let bytes_to_copy = std::cmp::min(buf.remaining(), remaining_data.len()); - // Note: Real world scenarios would handle `Poll::Pending` here, - // but for unit tests, we usually return `Ready` to keep them synchronous. + buf.put_slice(&remaining_data[..bytes_to_copy]); + self_mut.pos += bytes_to_copy; - self_mut.current_chunk += 1; std::task::Poll::Ready(Ok(())) } } @@ -222,8 +229,11 @@ pub mod test_tokio_stream_impl { let msg = TestMessage { id: 1, data: "quick".to_string() }; let encoded = bincode::encode_to_vec(&msg, SERDE_CONFIG).unwrap(); - let encoded_msg = BytesMut::from(encoded.as_slice()); - let mut mock = MockAsyncStream::new(vec![encoded_msg.into()]); + let len = encoded.len() as u32; + let mut framed_msg = len.to_be_bytes().to_vec(); + framed_msg.extend_from_slice(&encoded); + + let mut mock = MockAsyncStream::new(vec![framed_msg]); // 2. Act let result: Result, IoError> = mock.deserialized_reads().await; @@ -236,23 +246,39 @@ pub mod test_tokio_stream_impl { #[tokio::test] async fn test_deserialize_reads_vec() { - // 1. Arrange: Single message in one chunk - + // 1. 
Arrange: two messages sent sequentially let message_one = TestMessage { id: 1, data: "quick".to_string() }; let message_two = TestMessage { id: 2, data: "silver".to_string() }; - let mut raw_data = vec![]; - raw_data.extend_from_slice(&bincode::encode_to_vec(message_one, SERDE_CONFIG).unwrap()); - raw_data.extend_from_slice(&bincode::encode_to_vec(message_two, SERDE_CONFIG).unwrap()); - let mut mock = MockAsyncStream::new(vec![raw_data]); + let encoded1 = bincode::encode_to_vec(&message_one, SERDE_CONFIG).unwrap(); + let len1 = encoded1.len() as u32; + let mut framed_msg1 = len1.to_be_bytes().to_vec(); + framed_msg1.extend_from_slice(&encoded1); - // 2. Act - let result: Result, IoError> = mock.deserialized_reads().await; + let encoded2 = bincode::encode_to_vec(&message_two, SERDE_CONFIG).unwrap(); + let len2 = encoded2.len() as u32; + let mut framed_msg2 = len2.to_be_bytes().to_vec(); + framed_msg2.extend_from_slice(&encoded2); - // 3. Assert - let deserialized = result.unwrap(); - assert_eq!(deserialized.len(), 2); - assert_eq!(deserialized[0], TestMessage { id: 1, data: "quick".to_string() }); - assert_eq!(deserialized[1], TestMessage { id: 2, data: "silver".to_string() }); + let mut combined_data = framed_msg1; + combined_data.extend_from_slice(&framed_msg2); + + let mut mock = MockAsyncStream::new(vec![combined_data]); + + // 2. Act: read first message + let result1: Result, IoError> = mock.deserialized_reads().await; + + // 3. Assert: first message is correct + let deserialized1 = result1.unwrap(); + assert_eq!(deserialized1.len(), 1); + assert_eq!(deserialized1[0], message_one); + + // 4. Act: read second message + let result2: Result, IoError> = mock.deserialized_reads().await; + + // 5. 
Assert: second message is correct + let deserialized2 = result2.unwrap(); + assert_eq!(deserialized2.len(), 1); + assert_eq!(deserialized2[0], message_two); } } From c5a257ce25cf61da68c0edb474de47e78f69e529 Mon Sep 17 00:00:00 2001 From: Migo Date: Sun, 7 Dec 2025 08:51:16 +0400 Subject: [PATCH 2/5] remove vector reads --- duva-client/src/broker/read_stream.rs | 19 +++++---- duva/src/adapters/io/tokio_stream.rs | 36 +++++------------ duva/src/domains/interface.rs | 4 -- duva/src/presentation/clients/stream.rs | 54 +++++++++---------------- 4 files changed, 36 insertions(+), 77 deletions(-) diff --git a/duva-client/src/broker/read_stream.rs b/duva-client/src/broker/read_stream.rs index 0e15af24..ca3970e2 100644 --- a/duva-client/src/broker/read_stream.rs +++ b/duva-client/src/broker/read_stream.rs @@ -3,6 +3,7 @@ use duva::domains::TSerdeRead; use duva::prelude::ReplicationId; use duva::prelude::tokio::{self, net::tcp::OwnedReadHalf, sync::oneshot}; +use duva::presentation::clients::request::ServerResponse; pub struct ServerStreamReader(pub(crate) OwnedReadHalf); impl ServerStreamReader { @@ -17,16 +18,14 @@ impl ServerStreamReader { let controller_sender = controller_sender.clone(); loop { - match self.0.deserialized_reads().await { - Ok(server_responses) => { - for res in server_responses { - if controller_sender - .send(BrokerMessage::FromServer(replication_id.clone(), res)) - .await - .is_err() - { - break; - } + match self.0.deserialized_read::().await { + Ok(res) => { + if controller_sender + .send(BrokerMessage::FromServer(replication_id.clone(), res)) + .await + .is_err() + { + break; } }, Err(e) => { diff --git a/duva/src/adapters/io/tokio_stream.rs b/duva/src/adapters/io/tokio_stream.rs index 0272b3aa..c1393a68 100644 --- a/duva/src/adapters/io/tokio_stream.rs +++ b/duva/src/adapters/io/tokio_stream.rs @@ -117,24 +117,6 @@ impl TSerd Ok(request) } - - async fn deserialized_reads(&mut self) -> Result, IoError> - where - U: bincode::Decode<()>, - { - let mut 
buffer = BytesMut::with_capacity(INITIAL_CAPACITY); - self.read_bytes(&mut buffer).await?; - - let mut parsed_values = Vec::new(); - - while !buffer.is_empty() { - let (request, size) = bincode::decode_from_slice(&buffer, SERDE_CONFIG) - .map_err(|e| IoError::Custom(e.to_string()))?; - parsed_values.push(request); - buffer = buffer.split_off(size); - } - Ok(parsed_values) - } } impl TAsyncReadWrite for TcpStream { @@ -236,12 +218,12 @@ pub mod test_tokio_stream_impl { let mut mock = MockAsyncStream::new(vec![framed_msg]); // 2. Act - let result: Result, IoError> = mock.deserialized_reads().await; + let result: Result = mock.deserialized_read().await; // 3. Assert let deserialized = result.unwrap(); - assert_eq!(deserialized.len(), 1); - assert_eq!(deserialized[0], msg); + + assert_eq!(deserialized, msg); } #[tokio::test] @@ -266,19 +248,19 @@ pub mod test_tokio_stream_impl { let mut mock = MockAsyncStream::new(vec![combined_data]); // 2. Act: read first message - let result1: Result, IoError> = mock.deserialized_reads().await; + let result1: Result = mock.deserialized_read().await; // 3. Assert: first message is correct let deserialized1 = result1.unwrap(); - assert_eq!(deserialized1.len(), 1); - assert_eq!(deserialized1[0], message_one); + + assert_eq!(deserialized1, message_one); // 4. Act: read second message - let result2: Result, IoError> = mock.deserialized_reads().await; + let result2: Result = mock.deserialized_read().await; // 5. 
Assert: second message is correct let deserialized2 = result2.unwrap(); - assert_eq!(deserialized2.len(), 1); - assert_eq!(deserialized2[0], message_two); + + assert_eq!(deserialized2, message_two); } } diff --git a/duva/src/domains/interface.rs b/duva/src/domains/interface.rs index 4f2b8bd6..b5ae89b5 100644 --- a/duva/src/domains/interface.rs +++ b/duva/src/domains/interface.rs @@ -37,10 +37,6 @@ pub trait TSerdeRead { fn deserialized_read>( &mut self, ) -> impl std::future::Future> + Send; - - fn deserialized_reads>( - &mut self, - ) -> impl std::future::Future, IoError>> + Send; } pub(crate) trait TAsyncReadWrite { diff --git a/duva/src/presentation/clients/stream.rs b/duva/src/presentation/clients/stream.rs index 1ef46040..81850401 100644 --- a/duva/src/presentation/clients/stream.rs +++ b/duva/src/presentation/clients/stream.rs @@ -31,8 +31,8 @@ impl ClientStreamReader { ) { loop { // * extract queries - let query_ios = self.r.deserialized_reads::().await; - if let Err(err) = query_ios { + let query_io = self.r.deserialized_read::().await; + if let Err(err) = query_io { info!("{}", err); if err.should_break() { return; @@ -44,42 +44,24 @@ impl ClientStreamReader { } // * map client request - let requests = query_ios.unwrap().into_iter().map(|query_io| { - Ok(ClientRequest { - action: query_io.action, - conn_offset: query_io.conn_offset, - conn_id: self.conn_id.clone(), - }) - }); + let req = query_io.unwrap(); - for req in requests { - match req { - Err(err) => { - let _ = stream_writer_sender - .send(ServerResponse::Err { reason: err, conn_offset: 0 }) - .await; - break; - }, - Ok(ClientRequest { action, conn_offset, conn_id }) => { - // * processing part - let result = match action { - ClientAction::NonMutating(non_mutating_action) => { - handler.handle_non_mutating(non_mutating_action, conn_offset).await - }, - ClientAction::Mutating(log_entry) => { - handler.handle_mutating(conn_offset, conn_id, log_entry).await - }, - }; + // * processing part + let 
result = match req.action { + ClientAction::NonMutating(non_mutating_action) => { + handler.handle_non_mutating(non_mutating_action, req.conn_offset).await + }, + ClientAction::Mutating(log_entry) => { + handler.handle_mutating(req.conn_offset, self.conn_id.clone(), log_entry).await + }, + }; - let response = result.unwrap_or_else(|e| { - error!("failure on state change / query {e}"); - ServerResponse::Err { reason: e.to_string(), conn_offset } - }); - if stream_writer_sender.send(response).await.is_err() { - return; - } - }, - } + let response = result.unwrap_or_else(|e| { + error!("failure on state change / query {e}"); + ServerResponse::Err { reason: e.to_string(), conn_offset: req.conn_offset } + }); + if stream_writer_sender.send(response).await.is_err() { + return; } } } From 7ce123e6a383f71e98677e63c9cf1880fd590fa7 Mon Sep 17 00:00:00 2001 From: Migo Date: Sun, 7 Dec 2025 08:59:33 +0400 Subject: [PATCH 3/5] simplify peer msg recv --- duva/src/adapters/io/tokio_stream.rs | 15 +++----- .../domains/cluster_actors/actor/tests/mod.rs | 6 +-- duva/src/domains/cluster_actors/service.rs | 38 ++++++++----------- duva/src/domains/interface.rs | 2 +- duva/src/domains/peers/command.rs | 2 +- duva/src/domains/peers/service.rs | 14 ++++--- duva/src/presentation/clients/request.rs | 7 ---- duva/src/presentation/clients/stream.rs | 2 +- 8 files changed, 36 insertions(+), 50 deletions(-) diff --git a/duva/src/adapters/io/tokio_stream.rs b/duva/src/adapters/io/tokio_stream.rs index c1393a68..82b26796 100644 --- a/duva/src/adapters/io/tokio_stream.rs +++ b/duva/src/adapters/io/tokio_stream.rs @@ -54,17 +54,14 @@ impl TRead #[async_trait::async_trait] impl TSerdeDynamicRead for T { - async fn receive_peer_msgs(&mut self) -> Result, IoError> { + async fn receive_peer_msgs(&mut self) -> Result { let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY); self.read_bytes(&mut buffer).await?; - let mut parsed_values = Vec::new(); - while !buffer.is_empty() { - let (request, size) = 
bincode::decode_from_slice(&buffer, SERDE_CONFIG) - .map_err(|e| IoError::Custom(e.to_string()))?; - parsed_values.push(request); - buffer = buffer.split_off(size); - } - Ok(parsed_values) + + let (peer_message, _) = bincode::decode_from_slice(&buffer, SERDE_CONFIG) + .map_err(|e| IoError::Custom(e.to_string()))?; + + Ok(peer_message) } async fn receive_connection_msgs(&mut self) -> Result { self.deserialized_read().await diff --git a/duva/src/domains/cluster_actors/actor/tests/mod.rs b/duva/src/domains/cluster_actors/actor/tests/mod.rs index eafaa9fa..9a9bcb7f 100644 --- a/duva/src/domains/cluster_actors/actor/tests/mod.rs +++ b/duva/src/domains/cluster_actors/actor/tests/mod.rs @@ -47,10 +47,10 @@ impl FakeReadWrite { #[async_trait::async_trait] impl TSerdeDynamicRead for FakeReadWrite { - async fn receive_peer_msgs(&mut self) -> Result, IoError> { + async fn receive_peer_msgs(&mut self) -> Result { let guard = self.0.lock().await; - let values = guard.clone().drain(..).collect(); - Ok(values) + let values: Vec<_> = guard.clone().drain(..).collect(); + Ok(values[values.len() - 1].clone()) } async fn receive_connection_msgs(&mut self) -> Result { diff --git a/duva/src/domains/cluster_actors/service.rs b/duva/src/domains/cluster_actors/service.rs index 32c3b39b..0c710d76 100644 --- a/duva/src/domains/cluster_actors/service.rs +++ b/duva/src/domains/cluster_actors/service.rs @@ -130,29 +130,23 @@ impl ClusterActor { }; } - async fn process_peer_message( - &mut self, - peer_messages: Vec, - from: PeerIdentifier, - ) { + async fn process_peer_message(&mut self, peer_message: PeerMessage, from: PeerIdentifier) { use PeerMessage::*; - for peer_message in peer_messages { - match peer_message { - ClusterHeartBeat(heartbeat) => self.receive_cluster_heartbeat(heartbeat).await, - RequestVote(request_vote) => self.vote_election(request_vote).await, - AckReplication(repl_res) => self.ack_replication(&from, repl_res).await, - AppendEntriesRPC(heartbeat) => 
self.append_entries_rpc(heartbeat).await, - ElectionVote(request_vote_reply) => { - self.receive_election_vote(&from, request_vote_reply).await - }, - StartRebalance => self.start_rebalance().await, - BatchEntries(migrate_batch) => self.receive_batch(migrate_batch, &from).await, - MigrationBatchAck(migration_batch_ack) => { - self.handle_migration_ack(migration_batch_ack).await - }, - CloseConnection => self.close_connection(&from).await, - }; - } + match peer_message { + ClusterHeartBeat(heartbeat) => self.receive_cluster_heartbeat(heartbeat).await, + RequestVote(request_vote) => self.vote_election(request_vote).await, + AckReplication(repl_res) => self.ack_replication(&from, repl_res).await, + AppendEntriesRPC(heartbeat) => self.append_entries_rpc(heartbeat).await, + ElectionVote(request_vote_reply) => { + self.receive_election_vote(&from, request_vote_reply).await + }, + StartRebalance => self.start_rebalance().await, + BatchEntries(migrate_batch) => self.receive_batch(migrate_batch, &from).await, + MigrationBatchAck(migration_batch_ack) => { + self.handle_migration_ack(migration_batch_ack).await + }, + CloseConnection => self.close_connection(&from).await, + }; } #[instrument(level = tracing::Level::DEBUG, skip(self, conn_msg))] diff --git a/duva/src/domains/interface.rs b/duva/src/domains/interface.rs index b5ae89b5..4d1b8b44 100644 --- a/duva/src/domains/interface.rs +++ b/duva/src/domains/interface.rs @@ -16,7 +16,7 @@ pub trait TReadBytes: Send + Sync + Debug + 'static { #[async_trait::async_trait] pub(crate) trait TSerdeDynamicRead: Send + Sync + Debug + 'static { - async fn receive_peer_msgs(&mut self) -> Result, IoError>; + async fn receive_peer_msgs(&mut self) -> Result; async fn receive_connection_msgs(&mut self) -> Result; } diff --git a/duva/src/domains/peers/command.rs b/duva/src/domains/peers/command.rs index bcf2694d..00fc72f4 100644 --- a/duva/src/domains/peers/command.rs +++ b/duva/src/domains/peers/command.rs @@ -12,7 +12,7 @@ use 
std::collections::{HashMap, VecDeque}; #[derive(Debug, PartialEq, Eq)] pub(crate) struct PeerCommand { pub(crate) from: PeerIdentifier, - pub(crate) msg: Vec, + pub(crate) msg: PeerMessage, } impl From for ClusterCommand { diff --git a/duva/src/domains/peers/service.rs b/duva/src/domains/peers/service.rs index 32480190..7831124f 100644 --- a/duva/src/domains/peers/service.rs +++ b/duva/src/domains/peers/service.rs @@ -39,15 +39,17 @@ impl PeerListener { } async fn start(&mut self) { - while let Ok(cmds) = self.read_command().await { - let _ = self - .cluster_handler - .send(PeerCommand { from: self.peer_id.clone(), msg: cmds }) - .await; + loop { + if let Ok(msg) = self.read_command().await { + let _ = self + .cluster_handler + .send(PeerCommand { from: self.peer_id.clone(), msg }) + .await; + } } } - async fn read_command(&mut self) -> anyhow::Result> { + async fn read_command(&mut self) -> anyhow::Result { Ok(self.read_connected.receive_peer_msgs().await?) } } diff --git a/duva/src/presentation/clients/request.rs b/duva/src/presentation/clients/request.rs index 6ca02df4..c4befd55 100644 --- a/duva/src/presentation/clients/request.rs +++ b/duva/src/presentation/clients/request.rs @@ -299,13 +299,6 @@ pub fn extract_expiry(expiry: &str) -> anyhow::Result { Ok((Utc::now() + chrono::Duration::milliseconds(expiry.parse::()?)).timestamp_millis()) } -#[derive(Clone, Debug)] -pub struct ClientRequest { - pub(crate) action: ClientAction, - pub(crate) conn_offset: u64, - pub(crate) conn_id: String, -} - #[derive(Clone, Debug, bincode::Decode, bincode::Encode)] pub enum ServerResponse { WriteRes { res: QueryIO, log_index: u64, conn_offset: u64 }, diff --git a/duva/src/presentation/clients/stream.rs b/duva/src/presentation/clients/stream.rs index 81850401..0183ac7b 100644 --- a/duva/src/presentation/clients/stream.rs +++ b/duva/src/presentation/clients/stream.rs @@ -1,4 +1,4 @@ -use super::{ClientController, request::ClientRequest}; +use super::ClientController; use 
crate::domains::TSerdeRead; use crate::domains::cluster_actors::queue::ClusterActorSender; use crate::domains::cluster_actors::topology::Topology; From 1ebdccf86c468b2a8cf84327bf99c9e149cfb876 Mon Sep 17 00:00:00 2001 From: Migo Date: Sun, 7 Dec 2025 09:02:50 +0400 Subject: [PATCH 4/5] remove dependency on async_trait for TReadBytes --- duva/src/adapters/io/tokio_stream.rs | 1 - duva/src/domains/interface.rs | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/duva/src/adapters/io/tokio_stream.rs b/duva/src/adapters/io/tokio_stream.rs index 82b26796..574710f2 100644 --- a/duva/src/adapters/io/tokio_stream.rs +++ b/duva/src/adapters/io/tokio_stream.rs @@ -16,7 +16,6 @@ const INITIAL_CAPACITY: usize = 1024; // Arbitrary limit to prevent memory exhaustion. const MAX_MSG_SIZE: usize = 4 * 1024 * 1024; // 4MB -#[async_trait::async_trait] impl TReadBytes for T { // Reads a length-prefixed message from the stream. // The protocol is: diff --git a/duva/src/domains/interface.rs b/duva/src/domains/interface.rs index 4d1b8b44..cce8bff3 100644 --- a/duva/src/domains/interface.rs +++ b/duva/src/domains/interface.rs @@ -9,9 +9,11 @@ use crate::domains::{ }; use bytes::BytesMut; -#[async_trait::async_trait] pub trait TReadBytes: Send + Sync + Debug + 'static { - async fn read_bytes(&mut self, buf: &mut BytesMut) -> Result<(), IoError>; + fn read_bytes( + &mut self, + buf: &mut BytesMut, + ) -> impl std::future::Future> + Send; } #[async_trait::async_trait] From 6f979f0715bd0985c7d3f0975b5645e290e1f406 Mon Sep 17 00:00:00 2001 From: Migo Date: Sun, 7 Dec 2025 09:34:23 +0400 Subject: [PATCH 5/5] zero initialization --- duva/src/adapters/io/tokio_stream.rs | 37 ++++++++++++++-------------- duva/src/domains/interface.rs | 8 +++--- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/duva/src/adapters/io/tokio_stream.rs b/duva/src/adapters/io/tokio_stream.rs index 574710f2..3d6c3633 100644 --- a/duva/src/adapters/io/tokio_stream.rs +++ 
b/duva/src/adapters/io/tokio_stream.rs @@ -6,13 +6,12 @@ use crate::domains::{ TSerdeWrite, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use std::fmt::Debug; use std::io::ErrorKind; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -const INITIAL_CAPACITY: usize = 1024; // Arbitrary limit to prevent memory exhaustion. const MAX_MSG_SIZE: usize = 4 * 1024 * 1024; // 4MB @@ -21,7 +20,7 @@ impl TRead // The protocol is: // - 4 bytes (u32, big-endian) for the message length. // - N bytes for the message body, where N is the length read. - async fn read_bytes(&mut self, buffer: &mut BytesMut) -> Result<(), IoError> { + async fn read_bytes(&mut self) -> Result { let len = self.read_u32().await.map_err(|e| { if e.kind() == ErrorKind::UnexpectedEof { IoError::ConnectionAborted @@ -36,28 +35,29 @@ impl TRead ))); } - buffer.reserve(len); - let mut body_buffer = vec![0u8; len]; - if let Err(e) = self.read_exact(&mut body_buffer).await { - return if e.kind() == ErrorKind::UnexpectedEof { - Err(IoError::ConnectionAborted) - } else { - Err(io_error_from_kind(e.kind())) - }; + // Reserve space in the buffer (Allocates, but doesn't write zeros yet) + let mut buffer = BytesMut::with_capacity(len); + + // Unsafe-ish trick made safe by Tokio + // Tokio's read_buf can read directly into uninitialized memory + // preventing the "Double Write". 
+ while buffer.len() < len { + let n = self.read_buf(&mut buffer).await.map_err(|e| io_error_from_kind(e.kind()))?; + if n == 0 { + return Err(IoError::ConnectionAborted); + } } - buffer.extend_from_slice(&body_buffer); - Ok(()) + Ok(buffer) } } #[async_trait::async_trait] impl TSerdeDynamicRead for T { async fn receive_peer_msgs(&mut self) -> Result { - let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY); - self.read_bytes(&mut buffer).await?; + let body = self.read_bytes().await?; - let (peer_message, _) = bincode::decode_from_slice(&buffer, SERDE_CONFIG) + let (peer_message, _) = bincode::decode_from_slice(&body, SERDE_CONFIG) .map_err(|e| IoError::Custom(e.to_string()))?; Ok(peer_message) @@ -105,10 +105,9 @@ impl TSerd where U: bincode::Decode<()>, { - let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY); - self.read_bytes(&mut buffer).await?; + let body = self.read_bytes().await?; - let (request, _) = bincode::decode_from_slice(&buffer, SERDE_CONFIG) + let (request, _) = bincode::decode_from_slice(&body, SERDE_CONFIG) .map_err(|e| IoError::Custom(e.to_string()))?; Ok(request) diff --git a/duva/src/domains/interface.rs b/duva/src/domains/interface.rs index cce8bff3..e5c280f5 100644 --- a/duva/src/domains/interface.rs +++ b/duva/src/domains/interface.rs @@ -7,13 +7,11 @@ use crate::domains::{ connections::connection_types::{ReadConnected, WriteConnected}, }, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; pub trait TReadBytes: Send + Sync + Debug + 'static { - fn read_bytes( - &mut self, - buf: &mut BytesMut, - ) -> impl std::future::Future> + Send; + fn read_bytes(&mut self) + -> impl std::future::Future> + Send; } #[async_trait::async_trait]