19 changes: 9 additions & 10 deletions duva-client/src/broker/read_stream.rs
@@ -3,6 +3,7 @@ use duva::domains::TSerdeRead;
 
 use duva::prelude::ReplicationId;
 use duva::prelude::tokio::{self, net::tcp::OwnedReadHalf, sync::oneshot};
+use duva::presentation::clients::request::ServerResponse;
 
 pub struct ServerStreamReader(pub(crate) OwnedReadHalf);
 impl ServerStreamReader {
@@ -17,16 +18,14 @@ impl ServerStreamReader {
         let controller_sender = controller_sender.clone();
 
         loop {
-            match self.0.deserialized_reads().await {
-                Ok(server_responses) => {
-                    for res in server_responses {
-                        if controller_sender
-                            .send(BrokerMessage::FromServer(replication_id.clone(), res))
-                            .await
-                            .is_err()
-                        {
-                            break;
-                        }
+            match self.0.deserialized_read::<ServerResponse>().await {
+                Ok(res) => {
+                    if controller_sender
+                        .send(BrokerMessage::FromServer(replication_id.clone(), res))
+                        .await
+                        .is_err()
+                    {
+                        break;
                     }
                 },
                 Err(e) => {
2 changes: 2 additions & 0 deletions duva-client/src/broker/write_stream.rs
@@ -7,6 +7,8 @@ pub struct ServerStreamWriter(pub(crate) OwnedWriteHalf);
 
 impl ServerStreamWriter {
     pub async fn write_all(&mut self, buf: &[u8]) -> anyhow::Result<()> {
+        let len = buf.len() as u32;
+        self.0.write_u32(len).await.context("Failed to write length prefix")?;
        self.0.write_all(buf).await.context("Failed to send command")?;
        self.0.flush().await.context("Failed to flush stream")?;
        Ok(())
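Together with the read side in duva/src/adapters/io/tokio_stream.rs below, this establishes the wire format: a 4-byte big-endian length prefix followed by exactly that many body bytes. A minimal sketch of that contract using plain tokio primitives — `write_frame` and `read_frame` are hypothetical helper names, not duva APIs:

```rust
use tokio::io::{AsyncReadExt, AsyncWriteExt};

// One frame = 4-byte big-endian length prefix + body.
async fn write_frame<W: AsyncWriteExt + Unpin>(w: &mut W, body: &[u8]) -> std::io::Result<()> {
    w.write_u32(body.len() as u32).await?; // tokio's write_u32 is big-endian
    w.write_all(body).await?;
    w.flush().await
}

// Reads the prefix, then exactly `len` body bytes.
async fn read_frame<R: AsyncReadExt + Unpin>(r: &mut R) -> std::io::Result<Vec<u8>> {
    let len = r.read_u32().await? as usize;
    let mut body = vec![0u8; len];
    r.read_exact(&mut body).await?;
    Ok(body)
}
```

Because the reader learns the body size up front, message boundaries no longer depend on read timing — the old `read_bytes` treated a short read as end-of-message, which breaks when a message spans multiple TCP segments.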
199 changes: 101 additions & 98 deletions duva/src/adapters/io/tokio_stream.rs
@@ -6,63 +6,61 @@ use crate::domains::{
     TSerdeWrite,
 };
 
-use bytes::BytesMut;
+use bytes::{Bytes, BytesMut};
 use std::fmt::Debug;
+use std::io::ErrorKind;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::TcpStream;
 
-const BUFFER_SIZE: usize = 512;
-const INITIAL_CAPACITY: usize = 1024;
+// Arbitrary limit to prevent memory exhaustion.
+const MAX_MSG_SIZE: usize = 4 * 1024 * 1024; // 4MB
 
 #[async_trait::async_trait]
 impl<T: AsyncReadExt + std::marker::Unpin + Sync + Send + Debug + 'static> TReadBytes for T {
-    // TCP doesn't inherently delimit messages.
-    // The data arrives in a continuous stream of bytes. And
-    // we might not receive all the data in one go.
-    // So, we need to read the data in chunks until we have received all the data for the message.
-    async fn read_bytes(&mut self, buffer: &mut BytesMut) -> Result<(), IoError> {
-        let mut temp_buffer = [0u8; BUFFER_SIZE];
-        loop {
-            let bytes_read =
-                self.read(&mut temp_buffer).await.map_err(|err| io_error_from_kind(err.kind()))?;
-
-            if bytes_read == 0 {
-                // read 0 bytes AND buffer is empty - connection closed
-                if buffer.is_empty() {
-                    return Err(IoError::ConnectionAborted);
-                }
-                // read 0 bytes but buffer is not empty - end of message
-                return Ok(());
-            }
-
-            // Extend the buffer with the newly read data
-            buffer.extend_from_slice(&temp_buffer[..bytes_read]);
-
-            // If fewer bytes than the buffer size are read, it suggests that
-            // - The sender has sent all the data currently available for this message.
-            // - You have reached the end of the message.
-            if bytes_read < temp_buffer.len() {
-                break;
-            }
-        }
-        Ok(())
+    // Reads a length-prefixed message from the stream.
+    // The protocol is:
+    // - 4 bytes (u32, big-endian) for the message length.
+    // - N bytes for the message body, where N is the length read.
+    async fn read_bytes(&mut self) -> Result<BytesMut, IoError> {
+        let len = self.read_u32().await.map_err(|e| {
+            if e.kind() == ErrorKind::UnexpectedEof {
+                IoError::ConnectionAborted
+            } else {
+                io_error_from_kind(e.kind())
+            }
+        })? as usize;
+
+        if len > MAX_MSG_SIZE {
+            return Err(IoError::Custom(format!(
+                "Incoming message too large: {len} bytes, max is {MAX_MSG_SIZE}"
+            )));
+        }
+
+        // Reserve space in the buffer (allocates, but doesn't write zeros yet).
+        let mut buffer = BytesMut::with_capacity(len);
+
+        // Tokio's read_buf can read directly into the buffer's uninitialized
+        // capacity, avoiding the "double write" of zero-filling it first.
+        while buffer.len() < len {
+            let n = self.read_buf(&mut buffer).await.map_err(|e| io_error_from_kind(e.kind()))?;
+            if n == 0 {
+                return Err(IoError::ConnectionAborted);
+            }
+        }
+
+        Ok(buffer)
     }
 }
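The "doesn't write zeros" comment above contrasts with the simpler alternative: allocate a zero-filled buffer and `read_exact` into it, which writes every byte twice — once with zeros, once with the payload. A hedged sketch of that variant for comparison (`read_body_zeroed` is a hypothetical name, not duva code):

```rust
use tokio::io::AsyncReadExt;

// Correct but pays a "double write": `vec![0u8; len]` zero-fills the
// allocation, then the socket overwrites every byte with real data.
async fn read_body_zeroed<R: AsyncReadExt + Unpin>(
    r: &mut R,
    len: usize,
) -> std::io::Result<Vec<u8>> {
    let mut body = vec![0u8; len];
    r.read_exact(&mut body).await?;
    Ok(body)
}
```

`BytesMut::with_capacity` plus `read_buf` avoids the zero-fill because `read_buf` appends into the buffer's uninitialized spare capacity; the `while` loop remains necessary because a single `read_buf` call may return fewer bytes than the frame still needs.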

 #[async_trait::async_trait]
 impl<T: AsyncReadExt + std::marker::Unpin + Sync + Send + Debug + 'static> TSerdeDynamicRead for T {
-    async fn receive_peer_msgs(&mut self) -> Result<Vec<PeerMessage>, IoError> {
-        let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
-        self.read_bytes(&mut buffer).await?;
-        let mut parsed_values = Vec::new();
-        while !buffer.is_empty() {
-            let (request, size) = bincode::decode_from_slice(&buffer, SERDE_CONFIG)
-                .map_err(|e| IoError::Custom(e.to_string()))?;
-            parsed_values.push(request);
-            buffer = buffer.split_off(size);
-        }
-        Ok(parsed_values)
+    async fn receive_peer_msgs(&mut self) -> Result<PeerMessage, IoError> {
+        let body = self.read_bytes().await?;
+
+        let (peer_message, _) = bincode::decode_from_slice(&body, SERDE_CONFIG)
+            .map_err(|e| IoError::Custom(e.to_string()))?;
+
+        Ok(peer_message)
     }
     async fn receive_connection_msgs(&mut self) -> Result<String, IoError> {
         self.deserialized_read().await
@@ -73,7 +71,12 @@ impl<T: AsyncWriteExt + std::marker::Unpin + Sync + Send + Debug + 'static> TSer
     async fn serialized_write(&mut self, buf: impl bincode::Encode + Send) -> Result<(), IoError> {
         let encoded = bincode::encode_to_vec(buf, SERDE_CONFIG)
             .map_err(|e| IoError::Custom(e.to_string()))?;
-        self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind()))
+
+        let len = encoded.len() as u32;
+        self.write_u32(len).await.map_err(|e| io_error_from_kind(e.kind()))?;
+
+        self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind()))?;
+        self.flush().await.map_err(|e| io_error_from_kind(e.kind()))
     }
 }
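Concretely, the frame for a 5-byte payload is nine bytes on the wire. A tiny self-contained illustration (`frame` is a hypothetical helper mirroring the length-prefix logic above):

```rust
fn frame(body: &[u8]) -> Vec<u8> {
    let mut out = (body.len() as u32).to_be_bytes().to_vec(); // length prefix
    out.extend_from_slice(body); // then the body
    out
}

fn main() {
    assert_eq!(frame(b"hello"), [0, 0, 0, 5, b'h', b'e', b'l', b'l', b'o']);
}
```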

@@ -84,7 +87,12 @@ impl<T: AsyncWriteExt + std::marker::Unpin + Sync + Send + Debug + 'static> TSer
     async fn send(&mut self, msg: PeerMessage) -> Result<(), IoError> {
         let encoded = bincode::encode_to_vec(msg, SERDE_CONFIG)
             .map_err(|e| IoError::Custom(e.to_string()))?;
-        self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind()))
+
+        let len = encoded.len() as u32;
+        self.write_u32(len).await.map_err(|e| io_error_from_kind(e.kind()))?;
+
+        self.write_all(&encoded).await.map_err(|e| io_error_from_kind(e.kind()))?;
+        self.flush().await.map_err(|e| io_error_from_kind(e.kind()))
     }
 
     async fn send_connection_msg(&mut self, arg: &str) -> Result<(), IoError> {
@@ -97,32 +105,13 @@ impl<T: AsyncReadExt + std::marker::Unpin + Sync + Send + Debug + 'static> TSerd
     where
         U: bincode::Decode<()>,
     {
-        let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
-        self.read_bytes(&mut buffer).await?;
+        let body = self.read_bytes().await?;
 
-        let (request, _) = bincode::decode_from_slice(&buffer, SERDE_CONFIG)
+        let (request, _) = bincode::decode_from_slice(&body, SERDE_CONFIG)
             .map_err(|e| IoError::Custom(e.to_string()))?;
 
         Ok(request)
     }
-
-    async fn deserialized_reads<U>(&mut self) -> Result<Vec<U>, IoError>
-    where
-        U: bincode::Decode<()>,
-    {
-        let mut buffer = BytesMut::with_capacity(INITIAL_CAPACITY);
-        self.read_bytes(&mut buffer).await?;
-
-        let mut parsed_values = Vec::new();
-
-        while !buffer.is_empty() {
-            let (request, size) = bincode::decode_from_slice(&buffer, SERDE_CONFIG)
-                .map_err(|e| IoError::Custom(e.to_string()))?;
-            parsed_values.push(request);
-            buffer = buffer.split_off(size);
-        }
-        Ok(parsed_values)
-    }
 }
 
 impl TAsyncReadWrite for TcpStream {
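Since each frame now carries exactly one bincode value, a single `decode_from_slice` consumes the whole body, and the returned byte count should equal the frame length. A standalone sketch using bincode 2's derive macros and `standard()` config — duva's actual `SERDE_CONFIG` may differ:

```rust
use bincode::{Decode, Encode};

#[derive(Encode, Decode, Debug, PartialEq)]
struct Ping {
    seq: u64,
}

fn main() {
    let config = bincode::config::standard();
    let body = bincode::encode_to_vec(Ping { seq: 7 }, config).unwrap();

    // One frame, one message: the decoder consumes every byte of the body.
    let (decoded, consumed): (Ping, usize) =
        bincode::decode_from_slice(&body, config).unwrap();
    assert_eq!(decoded, Ping { seq: 7 });
    assert_eq!(consumed, body.len());
}
```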
@@ -164,17 +153,17 @@ pub mod test_tokio_stream_impl {
         data: String,
     }
 
-    /// A mock that implements AsyncRead for testing
+    /// A mock that implements AsyncRead for testing by simulating a byte stream.
     #[derive(Debug)]
     struct MockAsyncStream {
-        chunks: Vec<Vec<u8>>,
-        current_chunk: usize,
+        data: Vec<u8>,
+        pos: usize,
     }
 
     impl MockAsyncStream {
-        /// Creates a new mock stream from a vector of byte chunks.
-        /// Each inner Vec<u8> represents a single return from the `read` method.
+        /// Creates a new mock stream from a vector of byte chunks, which are flattened.
         fn new(chunks: Vec<Vec<u8>>) -> Self {
-            MockAsyncStream { chunks, current_chunk: 0 }
+            MockAsyncStream { data: chunks.into_iter().flatten().collect(), pos: 0 }
         }
     }

@@ -187,21 +176,16 @@ pub mod test_tokio_stream_impl {
         ) -> std::task::Poll<std::io::Result<()>> {
             let self_mut = self.get_mut();
 
-            if self_mut.current_chunk >= self_mut.chunks.len() {
-                // All chunks have been read, simulate EOF (read 0 bytes)
-                return std::task::Poll::Ready(Ok(()));
+            if self_mut.pos >= self_mut.data.len() {
+                return std::task::Poll::Ready(Ok(())); // EOF
             }
 
-            let chunk = &self_mut.chunks[self_mut.current_chunk];
-            let bytes_to_copy = std::cmp::min(buf.remaining(), chunk.len());
+            let remaining_data = &self_mut.data[self_mut.pos..];
+            let bytes_to_copy = std::cmp::min(buf.remaining(), remaining_data.len());
 
-            // Copy data into the ReadBuf
-            buf.put_slice(&chunk[..bytes_to_copy]);
+            buf.put_slice(&remaining_data[..bytes_to_copy]);
+            self_mut.pos += bytes_to_copy;
 
-            // Note: Real world scenarios would handle `Poll::Pending` here,
-            // but for unit tests, we usually return `Ready` to keep them synchronous.
-
-            self_mut.current_chunk += 1;
             std::task::Poll::Ready(Ok(()))
         }
     }
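The removed note about `Poll::Pending` points at what this mock still doesn't exercise: the waker path. If that coverage is ever wanted, a common pattern is a reader that yields once before reporting EOF — a hypothetical variant, not part of this PR:

```rust
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, ReadBuf};

#[derive(Debug, Default)]
struct YieldOnce {
    yielded: bool,
}

impl AsyncRead for YieldOnce {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        _buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        if !self.yielded {
            self.yielded = true;
            cx.waker().wake_by_ref(); // reschedule so the task is polled again
            return Poll::Pending;
        }
        Poll::Ready(Ok(())) // second poll: report EOF
    }
}
```

Waking before returning `Pending` keeps the task scheduled, so a test using this mock still runs to completion without a real I/O resource.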
@@ -222,37 +206,56 @@ pub mod test_tokio_stream_impl {
         let msg = TestMessage { id: 1, data: "quick".to_string() };
 
         let encoded = bincode::encode_to_vec(&msg, SERDE_CONFIG).unwrap();
-        let encoded_msg = BytesMut::from(encoded.as_slice());
-        let mut mock = MockAsyncStream::new(vec![encoded_msg.into()]);
+        let len = encoded.len() as u32;
+        let mut framed_msg = len.to_be_bytes().to_vec();
+        framed_msg.extend_from_slice(&encoded);
+
+        let mut mock = MockAsyncStream::new(vec![framed_msg]);
 
         // 2. Act
-        let result: Result<Vec<TestMessage>, IoError> = mock.deserialized_reads().await;
+        let result: Result<TestMessage, IoError> = mock.deserialized_read().await;
 
         // 3. Assert
         let deserialized = result.unwrap();
-        assert_eq!(deserialized.len(), 1);
-        assert_eq!(deserialized[0], msg);
+
+        assert_eq!(deserialized, msg);
     }
 
     #[tokio::test]
     async fn test_deserialize_reads_vec() {
-        // 1. Arrange: Single message in one chunk
+        // 1. Arrange: two messages sent sequentially
         let message_one = TestMessage { id: 1, data: "quick".to_string() };
         let message_two = TestMessage { id: 2, data: "silver".to_string() };
-        let mut raw_data = vec![];
-        raw_data.extend_from_slice(&bincode::encode_to_vec(message_one, SERDE_CONFIG).unwrap());
-        raw_data.extend_from_slice(&bincode::encode_to_vec(message_two, SERDE_CONFIG).unwrap());
-
-        let mut mock = MockAsyncStream::new(vec![raw_data]);
-
-        // 2. Act
-        let result: Result<Vec<TestMessage>, IoError> = mock.deserialized_reads().await;
-
-        // 3. Assert
-        let deserialized = result.unwrap();
-        assert_eq!(deserialized.len(), 2);
-        assert_eq!(deserialized[0], TestMessage { id: 1, data: "quick".to_string() });
-        assert_eq!(deserialized[1], TestMessage { id: 2, data: "silver".to_string() });
+        let encoded1 = bincode::encode_to_vec(&message_one, SERDE_CONFIG).unwrap();
+        let len1 = encoded1.len() as u32;
+        let mut framed_msg1 = len1.to_be_bytes().to_vec();
+        framed_msg1.extend_from_slice(&encoded1);
+
+        let encoded2 = bincode::encode_to_vec(&message_two, SERDE_CONFIG).unwrap();
+        let len2 = encoded2.len() as u32;
+        let mut framed_msg2 = len2.to_be_bytes().to_vec();
+        framed_msg2.extend_from_slice(&encoded2);
+
+        let mut combined_data = framed_msg1;
+        combined_data.extend_from_slice(&framed_msg2);
+
+        let mut mock = MockAsyncStream::new(vec![combined_data]);
+
+        // 2. Act: read first message
+        let result1: Result<TestMessage, IoError> = mock.deserialized_read().await;
+
+        // 3. Assert: first message is correct
+        let deserialized1 = result1.unwrap();
+
+        assert_eq!(deserialized1, message_one);
+
+        // 4. Act: read second message
+        let result2: Result<TestMessage, IoError> = mock.deserialized_read().await;
+
+        // 5. Assert: second message is correct
+        let deserialized2 = result2.unwrap();
+
+        assert_eq!(deserialized2, message_two);
     }
 }
6 changes: 3 additions & 3 deletions duva/src/domains/cluster_actors/actor/tests/mod.rs
@@ -47,10 +47,10 @@ impl FakeReadWrite {
 
 #[async_trait::async_trait]
 impl TSerdeDynamicRead for FakeReadWrite {
-    async fn receive_peer_msgs(&mut self) -> Result<Vec<PeerMessage>, IoError> {
+    async fn receive_peer_msgs(&mut self) -> Result<PeerMessage, IoError> {
         let guard = self.0.lock().await;
-        let values = guard.clone().drain(..).collect();
-        Ok(values)
+        let values: Vec<_> = guard.clone().drain(..).collect();
+        Ok(values[values.len() - 1].clone())
     }
 
     async fn receive_connection_msgs(&mut self) -> Result<String, IoError> {
38 changes: 16 additions & 22 deletions duva/src/domains/cluster_actors/service.rs
@@ -130,29 +130,23 @@ impl ClusterActor {
         };
     }
 
-    async fn process_peer_message(
-        &mut self,
-        peer_messages: Vec<PeerMessage>,
-        from: PeerIdentifier,
-    ) {
+    async fn process_peer_message(&mut self, peer_message: PeerMessage, from: PeerIdentifier) {
         use PeerMessage::*;
-        for peer_message in peer_messages {
-            match peer_message {
-                ClusterHeartBeat(heartbeat) => self.receive_cluster_heartbeat(heartbeat).await,
-                RequestVote(request_vote) => self.vote_election(request_vote).await,
-                AckReplication(repl_res) => self.ack_replication(&from, repl_res).await,
-                AppendEntriesRPC(heartbeat) => self.append_entries_rpc(heartbeat).await,
-                ElectionVote(request_vote_reply) => {
-                    self.receive_election_vote(&from, request_vote_reply).await
-                },
-                StartRebalance => self.start_rebalance().await,
-                BatchEntries(migrate_batch) => self.receive_batch(migrate_batch, &from).await,
-                MigrationBatchAck(migration_batch_ack) => {
-                    self.handle_migration_ack(migration_batch_ack).await
-                },
-                CloseConnection => self.close_connection(&from).await,
-            };
-        }
+        match peer_message {
+            ClusterHeartBeat(heartbeat) => self.receive_cluster_heartbeat(heartbeat).await,
+            RequestVote(request_vote) => self.vote_election(request_vote).await,
+            AckReplication(repl_res) => self.ack_replication(&from, repl_res).await,
+            AppendEntriesRPC(heartbeat) => self.append_entries_rpc(heartbeat).await,
+            ElectionVote(request_vote_reply) => {
+                self.receive_election_vote(&from, request_vote_reply).await
+            },
+            StartRebalance => self.start_rebalance().await,
+            BatchEntries(migrate_batch) => self.receive_batch(migrate_batch, &from).await,
+            MigrationBatchAck(migration_batch_ack) => {
+                self.handle_migration_ack(migration_batch_ack).await
+            },
+            CloseConnection => self.close_connection(&from).await,
+        };
     }
 
     #[instrument(level = tracing::Level::DEBUG, skip(self, conn_msg))]