diff --git a/src/error.rs b/src/error.rs index b8200621..ddbb2558 100644 --- a/src/error.rs +++ b/src/error.rs @@ -103,6 +103,8 @@ impl Display for DecodingErrorKind { pub enum PacketErrorKind { /// The maximal allowed size of the packet was exceeded ExceededMaxPacketSize, + /// Only `PacketType::Packet` can be fragmented + PacketCannotBeFragmented, } impl Display for PacketErrorKind { @@ -111,6 +113,9 @@ impl Display for PacketErrorKind { PacketErrorKind::ExceededMaxPacketSize => { write!(fmt, "The packet size was bigger than the max allowed size.") } + PacketErrorKind::PacketCannotBeFragmented => { + write!(fmt, "The packet type cannot be fragmented.") + } } } } diff --git a/src/infrastructure/acknowledgment.rs b/src/infrastructure/acknowledgment.rs index 749e165c..506f9ba0 100644 --- a/src/infrastructure/acknowledgment.rs +++ b/src/infrastructure/acknowledgment.rs @@ -1,5 +1,4 @@ -use crate::packet::OrderingGuarantee; -use crate::packet::SequenceNumber; +use crate::packet::{OrderingGuarantee, PacketType, SequenceNumber}; use crate::sequence_buffer::{sequence_greater_than, sequence_less_than, SequenceBuffer}; use std::collections::HashMap; @@ -101,6 +100,7 @@ impl AcknowledgmentHandler { /// Enqueue the outgoing packet for acknowledgment. pub fn process_outgoing( &mut self, + packet_type: PacketType, payload: &[u8], ordering_guarantee: OrderingGuarantee, item_identifier: Option, @@ -108,6 +108,7 @@ impl AcknowledgmentHandler { self.sent_packets.insert( self.sequence_number, SentPacket { + packet_type, payload: Box::from(payload), ordering_guarantee, item_identifier, @@ -138,8 +139,9 @@ impl AcknowledgmentHandler { } } -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub struct SentPacket { + pub packet_type: PacketType, pub payload: Box<[u8]>, pub ordering_guarantee: OrderingGuarantee, pub item_identifier: Option, @@ -154,7 +156,7 @@ pub struct ReceivedPacket; mod test { use crate::infrastructure::acknowledgment::ReceivedPacket; use crate::infrastructure::{AcknowledgmentHandler, SentPacket}; - use crate::packet::OrderingGuarantee; + use crate::packet::{OrderingGuarantee, PacketType}; use log::debug; #[test] @@ -162,7 +164,12 @@ mod test { let mut handler = AcknowledgmentHandler::new(); assert_eq!(handler.local_sequence_num(), 0); for i in 0..10 { - handler.process_outgoing(vec![].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![].as_slice(), + OrderingGuarantee::None, + None, + ); assert_eq!(handler.local_sequence_num(), i + 1); } } @@ -171,7 +178,12 @@ mod test { fn local_seq_num_wraps_on_overflow() { let mut handler = AcknowledgmentHandler::new(); handler.sequence_number = u16::max_value(); - handler.process_outgoing(vec![].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![].as_slice(), + OrderingGuarantee::None, + None, + ); assert_eq!(handler.local_sequence_num(), 0); } @@ -202,9 +214,19 @@ mod test { let mut handler = AcknowledgmentHandler::new(); handler.sequence_number = 0; - handler.process_outgoing(vec![1, 2, 3].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![1, 2, 3].as_slice(), + OrderingGuarantee::None, + None, + ); handler.sequence_number = 40; - handler.process_outgoing(vec![1, 2, 4].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![1, 2, 4].as_slice(), + OrderingGuarantee::None, + None, + ); static ARBITRARY: u16 = 23; handler.process_incoming(ARBITRARY, 40, 0); @@ -212,9 +234,10 @@ mod test { assert_eq!( handler.dropped_packets(), vec![SentPacket { + packet_type: PacketType::Packet, payload: vec![1, 2, 3].into_boxed_slice(), ordering_guarantee: OrderingGuarantee::None, - item_identifier: None, + item_identifier: None }] ); } @@ -226,7 +249,12 @@ mod test { for i in 0..500 { handler.sequence_number = i; - handler.process_outgoing(vec![1, 2, 3].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![1, 2, 3].as_slice(), + OrderingGuarantee::None, + None, + ); other.process_incoming(i, handler.remote_sequence_num(), handler.ack_bitfield()); handler.process_incoming(i, other.remote_sequence_num(), other.ack_bitfield()); @@ -243,7 +271,12 @@ mod test { let mut drop_count = 0; for i in 0..100 { - handler.process_outgoing(vec![1, 2, 3].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![1, 2, 3].as_slice(), + OrderingGuarantee::None, + None, + ); handler.sequence_number = i; // dropping every 4th with modulo's @@ -293,7 +326,12 @@ mod test { #[test] fn test_process_outgoing() { let mut handler = AcknowledgmentHandler::new(); - handler.process_outgoing(vec![1, 2, 3].as_slice(), OrderingGuarantee::None, None); + handler.process_outgoing( + PacketType::Packet, + vec![1, 2, 3].as_slice(), + OrderingGuarantee::None, + None, + ); assert_eq!(handler.sent_packets.len(), 1); assert_eq!(handler.local_sequence_num(), 1); } diff --git a/src/net/socket.rs b/src/net/socket.rs index 98ab6564..dd425fe4 100644 --- a/src/net/socket.rs +++ b/src/net/socket.rs @@ -3,7 +3,7 @@ use crate::{ config::Config, error::{ErrorKind, Result}, net::{connection::ActiveConnections, events::SocketEvent, link_conditioner::LinkConditioner}, - packet::{DeliveryGuarantee, Outgoing, Packet}, + packet::{DeliveryGuarantee, Packet, PacketInfo}, }; use crossbeam_channel::{self, unbounded, Receiver, SendError, Sender, TryRecvError}; use log::error; @@ -14,14 +14,56 @@ use std::{ time::{Duration, Instant}, }; +// Wrap `LinkConditioner` and `UdpSocket` together +#[derive(Debug)] +struct SocketWithConditioner { + socket: UdpSocket, + link_conditioner: Option, +} + +impl SocketWithConditioner { + /// Creates an instance of `SocketWithConditioner` + pub fn new(socket: UdpSocket, link_conditioner: Option) -> Self { + Self { + socket, + link_conditioner, + } + } + + // In the presence of a link conditioner, we would like it to determine whether or not we should + // send a single packet over the UDP socket. + pub fn send_packet(&mut self, addr: &SocketAddr, payload: &[u8]) -> Result { + if let Some(ref mut link) = self.link_conditioner { + if !link.should_send() { + return Ok(0); + } + } + Ok(self.socket.send_to(payload, addr)?) + } + + /// Returns mutable reference of `UdpSocket` + pub fn socket(&mut self) -> &mut UdpSocket { + &mut self.socket + } + + /// Returns the local socket address + pub fn local_addr(&self) -> Result { + Ok(self.socket.local_addr()?) + } + + /// Set the link conditioner for this socket. See [LinkConditioner] for further details. + pub fn set_link_conditioner(&mut self, conditioner: Option) { + self.link_conditioner = conditioner; + } +} + /// A reliable UDP socket implementation with configurable reliability and ordering guarantees. #[derive(Debug)] pub struct Socket { - socket: UdpSocket, + socket_wrapper: SocketWithConditioner, config: Config, connections: ActiveConnections, recv_buffer: Vec, - link_conditioner: Option, event_sender: Sender, packet_receiver: Receiver, @@ -71,10 +113,9 @@ impl Socket { let (packet_sender, packet_receiver) = unbounded(); Ok(Socket { recv_buffer: vec![0; config.receive_buffer_max_size], - socket, + socket_wrapper: SocketWithConditioner::new(socket, None), config, connections: ActiveConnections::new(), - link_conditioner: None, event_sender, packet_receiver, @@ -177,12 +218,12 @@ impl Socket { /// Set the link conditioner for this socket. See [LinkConditioner] for further details. pub fn set_link_conditioner(&mut self, link_conditioner: Option) { - self.link_conditioner = link_conditioner; + self.socket_wrapper.set_link_conditioner(link_conditioner); } - /// Get the local socket address + /// Returns the local socket address pub fn local_addr(&self) -> Result { - Ok(self.socket.local_addr()?) + self.socket_wrapper.local_addr() } /// Iterate through the dead connections and disconnect them by removing them from the @@ -224,7 +265,7 @@ impl Socket { .heartbeat_required_connections(heartbeat_interval, time) .map(|connection| { ( - connection.create_and_process_heartbeat(time), + connection.process_outgoing(PacketInfo::heartbeat_packet(&[]), None, time), connection.remote_address, ) }) @@ -233,9 +274,13 @@ impl Socket { let mut bytes_sent = 0; for (heartbeat_packet, address) in heartbeat_packets_and_addrs { - if self.should_send_packet() { - bytes_sent += self.send_packet(&address, &heartbeat_packet.contents())?; - } + let packet = heartbeat_packet? + .into_iter() + .next() + .expect("Heartbeat packet must exists"); + bytes_sent += self + .socket_wrapper + .send_packet(&address, &packet.contents())?; } Ok(bytes_sent) @@ -247,54 +292,55 @@ impl Socket { self.connections .get_or_insert_connection(packet.addr(), &self.config, time); - let dropped = connection.gather_dropped_packets(); - let mut processed_packets: Vec = dropped - .iter() - .flat_map(|waiting_packet| { - connection.process_outgoing( - &waiting_packet.payload, + let mut bytes_sent = 0; + + // TODO maybe dropped packets shouldn't depend on how often a user sends a packet? + let dropped_packets = connection.gather_dropped_packets(); + for dropped in dropped_packets { + let packets = connection.process_outgoing( + PacketInfo { + packet_type: dropped.packet_type, + payload: &dropped.payload, // Because a delivery guarantee is only sent with reliable packets - DeliveryGuarantee::Reliable, + delivery: DeliveryGuarantee::Reliable, // This is stored with the dropped packet because they could be mixed - waiting_packet.ordering_guarantee, - waiting_packet.item_identifier, - time, - ) - }) - .collect(); + ordering: dropped.ordering_guarantee, + }, + dropped.item_identifier, + time, + )?; + + for outgoing in packets { + bytes_sent += self + .socket_wrapper + .send_packet(&packet.addr(), &outgoing.contents())?; + } + } - let processed_packet = connection.process_outgoing( - packet.payload(), - packet.delivery_guarantee(), - packet.order_guarantee(), + let packets = connection.process_outgoing( + PacketInfo::user_packet( + packet.payload(), + packet.delivery_guarantee(), + packet.order_guarantee(), + ), None, time, )?; - - processed_packets.push(processed_packet); - - let mut bytes_sent = 0; - - for processed_packet in processed_packets { - if self.should_send_packet() { - match processed_packet { - Outgoing::Packet(outgoing) => { - bytes_sent += self.send_packet(&packet.addr(), &outgoing.contents())?; - } - Outgoing::Fragments(packets) => { - for outgoing in packets { - bytes_sent += self.send_packet(&packet.addr(), &outgoing.contents())?; - } - } - } - } + for outgoing in packets { + bytes_sent += self + .socket_wrapper + .send_packet(&packet.addr(), &outgoing.contents())?; } Ok(bytes_sent) } // On success the packet will be sent on the `event_sender` fn recv_from(&mut self, time: Instant) -> Result { - match self.socket.recv_from(&mut self.recv_buffer) { + match self + .socket_wrapper + .socket() + .recv_from(&mut self.recv_buffer) + { Ok((recv_len, address)) => { if recv_len == 0 { return Err(ErrorKind::ReceivedDataToShort); @@ -309,13 +355,14 @@ impl Socket { self.connections .get_or_create_connection(address, &self.config, time); - match connection { - Left(existing) => { - existing.process_incoming(received_payload, &self.event_sender, time)?; - } - Right(mut anonymous) => { - anonymous.process_incoming(received_payload, &self.event_sender, time)?; - } + let packets = match connection { + Left(existing) => existing.process_incoming(received_payload, time)?, + Right(mut anonymous) => anonymous.process_incoming(received_payload, time)?, + }; + for incoming in packets { + self.event_sender + .send(SocketEvent::Packet(incoming.0)) + .unwrap(); } } Err(e) => { @@ -335,22 +382,6 @@ impl Socket { } } - // Send a single packet over the UDP socket. - fn send_packet(&self, addr: &SocketAddr, payload: &[u8]) -> Result { - let bytes_sent = self.socket.send_to(payload, addr)?; - Ok(bytes_sent) - } - - // In the presence of a link conditioner, we would like it to determine whether or not we should - // send a packet. - fn should_send_packet(&mut self) -> bool { - if let Some(link_conditioner) = &mut self.link_conditioner { - link_conditioner.should_send() - } else { - true - } - } - #[cfg(test)] fn connection_count(&self) -> usize { self.connections.count() @@ -359,9 +390,13 @@ impl Socket { #[cfg(test)] fn forget_all_incoming_packets(&mut self) { std::thread::sleep(std::time::Duration::from_millis(100)); - self.socket.set_nonblocking(true).unwrap(); + self.socket_wrapper.socket().set_nonblocking(true).unwrap(); loop { - match self.socket.recv_from(&mut self.recv_buffer) { + match self + .socket_wrapper + .socket() + .recv_from(&mut self.recv_buffer) + { Ok((recv_len, _address)) => { if recv_len == 0 { panic!("Received data too short"); @@ -371,7 +406,8 @@ impl Socket { if e.kind() != io::ErrorKind::WouldBlock { panic!("Encountered an error receiving data: {:?}", e); } else { - self.socket + self.socket_wrapper + .socket() .set_nonblocking(!self.config.blocking_mode) .unwrap(); return; @@ -705,7 +741,7 @@ mod tests { while let Some(message) = server.recv() { match message { SocketEvent::Connect(_) => {} - SocketEvent::Packet(_packet) => { + SocketEvent::Packet(_) => { cnt += 1; } SocketEvent::Timeout(_) => { diff --git a/src/net/virtual_connection.rs b/src/net/virtual_connection.rs index c283d1b5..ceb913c1 100644 --- a/src/net/virtual_connection.rs +++ b/src/net/virtual_connection.rs @@ -10,13 +10,11 @@ use crate::{ STANDARD_HEADER_SIZE, }, packet::{ - DeliveryGuarantee, OrderingGuarantee, Outgoing, OutgoingPacket, OutgoingPacketBuilder, - Packet, PacketReader, PacketType, SequenceNumber, + DeliveryGuarantee, IncomingPackets, OrderingGuarantee, OutgoingPacketBuilder, + OutgoingPackets, Packet, PacketInfo, PacketReader, PacketType, SequenceNumber, }, - SocketEvent, }; -use crossbeam_channel::{self, Sender}; use std::fmt; use std::net::SocketAddr; use std::time::{Duration, Instant}; @@ -31,7 +29,7 @@ pub struct VirtualConnection { /// The address of the remote endpoint pub remote_address: SocketAddr, - ordering_system: OrderingSystem>, + ordering_system: OrderingSystem<(Box<[u8]>, PacketType)>, sequencing_system: SequencingSystem>, acknowledge_handler: AcknowledgmentHandler, congestion_handler: CongestionHandler, @@ -75,40 +73,27 @@ impl VirtualConnection { time.duration_since(self.last_sent) } - /// This will create a heartbeat packet that is expected to be sent over the network - pub fn create_and_process_heartbeat(&mut self, time: Instant) -> OutgoingPacket<'static> { - self.last_sent = time; - self.congestion_handler - .process_outgoing(self.acknowledge_handler.local_sequence_num(), time); - - OutgoingPacketBuilder::new(&[]) - .with_default_header( - PacketType::Heartbeat, - DeliveryGuarantee::Unreliable, - OrderingGuarantee::None, - ) - .build() - } - /// This will pre-process the given buffer to be sent over the network. pub fn process_outgoing<'a>( &mut self, - payload: &'a [u8], - delivery_guarantee: DeliveryGuarantee, - ordering_guarantee: OrderingGuarantee, + packet: PacketInfo<'a>, last_item_identifier: Option, time: Instant, - ) -> Result> { - match delivery_guarantee { + ) -> Result> { + self.last_sent = time; + match packet.delivery { DeliveryGuarantee::Unreliable => { - if payload.len() <= self.config.receive_buffer_max_size { - let mut builder = OutgoingPacketBuilder::new(payload).with_default_header( - PacketType::Packet, - delivery_guarantee, - ordering_guarantee, - ); + if packet.payload.len() <= self.config.receive_buffer_max_size { + if packet.packet_type == PacketType::Heartbeat { + // TODO (bug?) is this really required here? + self.congestion_handler + .process_outgoing(self.acknowledge_handler.local_sequence_num(), time); + } + + let mut builder = OutgoingPacketBuilder::new(packet.payload) + .with_default_header(packet.packet_type, packet.delivery, packet.ordering); - if let OrderingGuarantee::Sequenced(stream_id) = ordering_guarantee { + if let OrderingGuarantee::Sequenced(stream_id) = packet.ordering { let item_identifier = self .sequencing_system .get_or_create_stream(stream_id.unwrap_or(DEFAULT_SEQUENCING_STREAM)) @@ -117,25 +102,24 @@ impl VirtualConnection { builder = builder.with_sequencing_header(item_identifier as u16, stream_id); }; - Ok(Outgoing::Packet(builder.build())) + Ok(OutgoingPackets::one(builder.build())) } else { - Err(ErrorKind::PacketError( - PacketErrorKind::ExceededMaxPacketSize, - )) + Err(PacketErrorKind::ExceededMaxPacketSize.into()) } } DeliveryGuarantee::Reliable => { - let payload_length = payload.len() as u16; + let payload_length = packet.payload.len() as u16; let mut item_identifier_value = None; let outgoing = { // spit the packet if the payload length is greater than the allowed fragment size. if payload_length <= self.config.fragment_size { - let mut builder = OutgoingPacketBuilder::new(payload).with_default_header( - PacketType::Packet, - delivery_guarantee, - ordering_guarantee, - ); + let mut builder = OutgoingPacketBuilder::new(packet.payload) + .with_default_header( + packet.packet_type, + packet.delivery, + packet.ordering, + ); builder = builder.with_acknowledgment_header( self.acknowledge_handler.local_sequence_num(), @@ -143,7 +127,7 @@ impl VirtualConnection { self.acknowledge_handler.ack_bitfield(), ); - if let OrderingGuarantee::Ordered(stream_id) = ordering_guarantee { + if let OrderingGuarantee::Ordered(stream_id) = packet.ordering { let item_identifier = if let Some(item_identifier) = last_item_identifier { item_identifier @@ -160,7 +144,7 @@ impl VirtualConnection { builder = builder.with_ordering_header(item_identifier, stream_id); }; - if let OrderingGuarantee::Sequenced(stream_id) = ordering_guarantee { + if let OrderingGuarantee::Sequenced(stream_id) = packet.ordering { let item_identifier = if let Some(item_identifier) = last_item_identifier { item_identifier @@ -177,10 +161,13 @@ impl VirtualConnection { builder = builder.with_sequencing_header(item_identifier, stream_id); }; - Outgoing::Packet(builder.build()) + OutgoingPackets::one(builder.build()) } else { - Outgoing::Fragments( - Fragmentation::spit_into_fragments(payload, &self.config)? + if packet.packet_type != PacketType::Packet { + return Err(PacketErrorKind::PacketCannotBeFragmented.into()); + } + OutgoingPackets::many( + Fragmentation::spit_into_fragments(packet.payload, &self.config)? .into_iter() .enumerate() .map(|(fragment_id, fragment)| { @@ -192,9 +179,9 @@ impl VirtualConnection { let mut builder = OutgoingPacketBuilder::new(fragment) .with_default_header( - PacketType::Fragment, - delivery_guarantee, - ordering_guarantee, + PacketType::Fragment, // change from Packet to Fragment type, it only matters when assembling/dissasembling packet header. + packet.delivery, + packet.ordering, ); builder = builder.with_fragment_header( @@ -218,12 +205,12 @@ impl VirtualConnection { } }; - self.last_sent = time; self.congestion_handler .process_outgoing(self.acknowledge_handler.local_sequence_num(), time); self.acknowledge_handler.process_outgoing( - payload, - ordering_guarantee, + packet.packet_type, + packet.payload, + packet.ordering, item_identifier_value, ); @@ -236,9 +223,8 @@ impl VirtualConnection { pub fn process_incoming( &mut self, received_data: &[u8], - sender: &Sender, time: Instant, - ) -> crate::Result<()> { + ) -> Result { self.last_heard = time; let mut packet_reader = PacketReader::new(received_data); @@ -252,7 +238,7 @@ impl VirtualConnection { if header.is_heartbeat() { // Heartbeat packets are unreliable, unordered and empty packets. // We already updated our `self.last_heard` time, nothing else to be done. - return Ok(()); + return Ok(IncomingPackets::zero()); } match header.delivery_guarantee() { @@ -268,25 +254,29 @@ impl VirtualConnection { .get_or_create_stream(arranging_header.stream_id()); if let Some(packet) = stream.arrange(arranging_header.arranging_id(), payload) { - Self::queue_packet( - sender, - packet, - self.remote_address, - header.delivery_guarantee(), - OrderingGuarantee::Sequenced(Some(arranging_header.stream_id())), - )?; + return Ok(IncomingPackets::one( + Packet::new( + self.remote_address, + packet, + header.delivery_guarantee(), + OrderingGuarantee::Sequenced(Some(arranging_header.stream_id())), + ), + header.packet_type(), + )); } - return Ok(()); + return Ok(IncomingPackets::zero()); } - Self::queue_packet( - sender, - packet_reader.read_payload(), - self.remote_address, - header.delivery_guarantee(), - header.ordering_guarantee(), - )?; + return Ok(IncomingPackets::one( + Packet::new( + self.remote_address, + packet_reader.read_payload(), + header.delivery_guarantee(), + header.ordering_guarantee(), + ), + header.packet_type(), + )); } DeliveryGuarantee::Reliable => { if header.is_fragment() { @@ -299,14 +289,6 @@ impl VirtualConnection { acked_header, ) { Ok(Some((payload, acked_header))) => { - Self::queue_packet( - sender, - payload.into_boxed_slice(), - self.remote_address, - header.delivery_guarantee(), - header.ordering_guarantee(), - )?; - self.congestion_handler .process_incoming(acked_header.sequence()); self.acknowledge_handler.process_incoming( @@ -314,14 +296,32 @@ impl VirtualConnection { acked_header.ack_seq(), acked_header.ack_field(), ); + + return Ok(IncomingPackets::one( + Packet::new( + self.remote_address, + payload.into_boxed_slice(), + header.delivery_guarantee(), + header.ordering_guarantee(), + ), + PacketType::Packet, // change from Fragment to Packet type, it only matters when assembling/dissasembling packet header. + )); } - Ok(None) => return Ok(()), + Ok(None) => return Ok(IncomingPackets::zero()), Err(e) => return Err(e), }; } } else { let acked_header = packet_reader.read_acknowledge_header()?; + self.congestion_handler + .process_incoming(acked_header.sequence()); + self.acknowledge_handler.process_incoming( + acked_header.sequence(), + acked_header.ack_seq(), + acked_header.ack_field(), + ); + if let OrderingGuarantee::Sequenced(_) = header.ordering_guarantee() { let arranging_header = packet_reader.read_arranging_header(u16::from( STANDARD_HEADER_SIZE + ACKED_PACKET_HEADER, @@ -336,13 +336,17 @@ impl VirtualConnection { if let Some(packet) = stream.arrange(arranging_header.arranging_id(), payload) { - Self::queue_packet( - sender, - packet, - self.remote_address, - header.delivery_guarantee(), - OrderingGuarantee::Sequenced(Some(arranging_header.stream_id())), - )?; + return Ok(IncomingPackets::one( + Packet::new( + self.remote_address, + packet, + header.delivery_guarantee(), + OrderingGuarantee::Sequenced(Some( + arranging_header.stream_id(), + )), + ), + header.packet_type(), + )); } } else if let OrderingGuarantee::Ordered(_id) = header.ordering_guarantee() { let arranging_header = packet_reader.read_arranging_header(u16::from( @@ -354,68 +358,46 @@ impl VirtualConnection { let stream = self .ordering_system .get_or_create_stream(arranging_header.stream_id()); - - if let Some(packet) = - stream.arrange(arranging_header.arranging_id(), payload) - { - Self::queue_packet( - sender, - packet, - self.remote_address, - header.delivery_guarantee(), - OrderingGuarantee::Ordered(Some(arranging_header.stream_id())), - )?; - - while let Some(packet) = stream.iter_mut().next() { - Self::queue_packet( - sender, - packet, - self.remote_address, - header.delivery_guarantee(), - OrderingGuarantee::Ordered(Some(arranging_header.stream_id())), - )?; - } - } + let address = self.remote_address; + return Ok(IncomingPackets::many( + stream + .arrange( + arranging_header.arranging_id(), + (payload, header.packet_type()), + ) + .into_iter() + .chain(stream.iter_mut()) + .map(|(packet, packet_type)| { + ( + Packet::new( + address, + packet, + header.delivery_guarantee(), + OrderingGuarantee::Ordered(Some( + arranging_header.stream_id(), + )), + ), + packet_type, + ) + }) + .collect(), + )); } else { let payload = packet_reader.read_payload(); - - Self::queue_packet( - sender, - payload, - self.remote_address, - header.delivery_guarantee(), - header.ordering_guarantee(), - )?; + return Ok(IncomingPackets::one( + Packet::new( + self.remote_address, + payload, + header.delivery_guarantee(), + header.ordering_guarantee(), + ), + header.packet_type(), + )); } - - self.congestion_handler - .process_incoming(acked_header.sequence()); - self.acknowledge_handler.process_incoming( - acked_header.sequence(), - acked_header.ack_seq(), - acked_header.ack_field(), - ); } } } - - Ok(()) - } - - fn queue_packet( - tx: &Sender, - payload: Box<[u8]>, - remote_addr: SocketAddr, - delivery: DeliveryGuarantee, - ordering: OrderingGuarantee, - ) -> Result<()> { - tx.send(SocketEvent::Packet(Packet::new( - remote_addr, - payload, - delivery, - ordering, - )))?; - Ok(()) + Ok(IncomingPackets::zero()) } /// This will gather dropped packets from the acknowledgment handler. @@ -443,16 +425,47 @@ mod tests { use crate::config::Config; use crate::net::constants; use crate::packet::header::{AckedPacketHeader, ArrangingHeader, HeaderWriter, StandardHeader}; - use crate::packet::{DeliveryGuarantee, OrderingGuarantee, Outgoing, Packet, PacketType}; + use crate::packet::{DeliveryGuarantee, OrderingGuarantee, Packet, PacketInfo, PacketType}; use crate::protocol_version::ProtocolVersion; - use crate::SocketEvent; use byteorder::{BigEndian, WriteBytesExt}; - use crossbeam_channel::{unbounded, TryRecvError}; use std::io::Write; - use std::time::Instant; + use std::time::{Duration, Instant}; const PAYLOAD: [u8; 4] = [1, 2, 3, 4]; + #[test] + fn set_last_sent_and_last_heard_when_processing() { + let mut connection = create_virtual_connection(); + let curr_sent = connection.last_sent; + let curr_heard = connection.last_heard; + + let out_packet = connection + .process_outgoing( + PacketInfo::heartbeat_packet(&[]), + None, + curr_sent + Duration::from_secs(1), + ) + .unwrap() + .into_iter() + .next() + .unwrap(); + let in_packet = connection + .process_incoming(&out_packet.contents(), curr_heard + Duration::from_secs(2)) + .unwrap() + .into_iter() + .next(); + + assert_eq!( + connection.last_sent.duration_since(curr_sent), + Duration::from_secs(1) + ); + assert_eq!( + connection.last_heard.duration_since(curr_heard), + Duration::from_secs(2) + ); + assert_eq!(in_packet.is_none(), true); + } + #[test] fn assure_right_fragmentation() { let mut protocol_version = Vec::new(); @@ -467,20 +480,19 @@ mod tests { let second_fragment = vec![0, 0, 2, 4]; let third_fragment = vec![0, 0, 3, 4]; - let (tx, rx) = unbounded::(); - let mut connection = create_virtual_connection(); - connection + let packet = connection .process_incoming( [standard_header.as_slice(), acked_header.as_slice()] .concat() .as_slice(), - &tx, Instant::now(), ) - .unwrap(); - assert!(rx.try_recv().is_err()); - connection + .unwrap() + .into_iter() + .next(); + assert!(packet.is_none()); + let packet = connection .process_incoming( [ standard_header.as_slice(), @@ -489,12 +501,13 @@ mod tests { ] .concat() .as_slice(), - &tx, Instant::now(), ) - .unwrap(); - assert!(rx.try_recv().is_err()); - connection + .unwrap() + .into_iter() + .next(); + assert!(packet.is_none()); + let packet = connection .process_incoming( [ standard_header.as_slice(), @@ -503,12 +516,13 @@ mod tests { ] .concat() .as_slice(), - &tx, Instant::now(), ) - .unwrap(); - assert!(rx.try_recv().is_err()); - connection + .unwrap() + .into_iter() + .next(); + assert!(packet.is_none()); + let (packets, _) = connection .process_incoming( [ standard_header.as_slice(), @@ -517,22 +531,16 @@ mod tests { ] .concat() .as_slice(), - &tx, Instant::now(), ) + .unwrap() + .into_iter() + .next() .unwrap(); - - let complete_fragment = rx.try_recv().unwrap(); - - match complete_fragment { - SocketEvent::Packet(fragment) => assert_eq!( - fragment.payload(), - &*[PAYLOAD, PAYLOAD, PAYLOAD].concat().into_boxed_slice() - ), - _ => { - panic!("Expected fragment other result."); - } - } + assert_eq!( + packets.payload(), + &*[PAYLOAD, PAYLOAD, PAYLOAD].concat().into_boxed_slice() + ); } #[test] @@ -541,22 +549,20 @@ mod tests { let buffer = vec![1; 4000]; - let outgoing = connection + let packets: Vec<_> = connection .process_outgoing( - &buffer, - DeliveryGuarantee::Reliable, - OrderingGuarantee::Ordered(None), + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Reliable, + OrderingGuarantee::Ordered(None), + ), None, Instant::now(), ) - .unwrap(); - - match outgoing { - Outgoing::Packet(_) => panic!("Expected fragment got packet"), - Outgoing::Fragments(fragments) => { - assert_eq!(fragments.len(), 4); - } - } + .unwrap() + .into_iter() + .collect(); + assert_eq!(packets.len(), 4); } #[test] @@ -567,9 +573,11 @@ mod tests { connection .process_outgoing( - &buffer, - DeliveryGuarantee::Unreliable, - OrderingGuarantee::None, + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Unreliable, + OrderingGuarantee::None, + ), None, Instant::now(), ) @@ -577,9 +585,11 @@ mod tests { connection .process_outgoing( - &buffer, - DeliveryGuarantee::Unreliable, - OrderingGuarantee::Sequenced(None), + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Unreliable, + OrderingGuarantee::Sequenced(None), + ), None, Instant::now(), ) @@ -587,9 +597,11 @@ mod tests { connection .process_outgoing( - &buffer, - DeliveryGuarantee::Reliable, - OrderingGuarantee::Ordered(None), + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Reliable, + OrderingGuarantee::Ordered(None), + ), None, Instant::now(), ) @@ -597,9 +609,11 @@ mod tests { connection .process_outgoing( - &buffer, - DeliveryGuarantee::Reliable, - OrderingGuarantee::Sequenced(None), + PacketInfo::user_packet( + &buffer, + DeliveryGuarantee::Reliable, + OrderingGuarantee::Sequenced(None), + ), None, Instant::now(), ) @@ -614,11 +628,11 @@ mod tests { DeliveryGuarantee::Unreliable, OrderingGuarantee::Sequenced(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::unreliable_sequenced( + Some(Packet::unreliable_sequenced( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 1, ); @@ -626,11 +640,11 @@ mod tests { DeliveryGuarantee::Unreliable, OrderingGuarantee::Sequenced(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::unreliable_sequenced( + Some(Packet::unreliable_sequenced( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 3, ); @@ -638,7 +652,7 @@ mod tests { DeliveryGuarantee::Unreliable, OrderingGuarantee::Sequenced(Some(1)), &mut connection, - Err(TryRecvError::Empty), + None, 2, ); @@ -646,11 +660,11 @@ mod tests { DeliveryGuarantee::Unreliable, OrderingGuarantee::Sequenced(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::unreliable_sequenced( + Some(Packet::unreliable_sequenced( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 4, ); @@ -658,11 +672,11 @@ mod tests { DeliveryGuarantee::Reliable, OrderingGuarantee::Sequenced(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::reliable_sequenced( + Some(Packet::reliable_sequenced( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 5, ); } @@ -675,11 +689,11 @@ mod tests { DeliveryGuarantee::Reliable, OrderingGuarantee::Ordered(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::reliable_ordered( + Some(Packet::reliable_ordered( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 0, ); @@ -687,7 +701,7 @@ mod tests { DeliveryGuarantee::Reliable, OrderingGuarantee::Ordered(Some(1)), &mut connection, - Err(TryRecvError::Empty), + None, 2, ); @@ -695,7 +709,7 @@ mod tests { DeliveryGuarantee::Reliable, OrderingGuarantee::Ordered(Some(1)), &mut connection, - Err(TryRecvError::Empty), + None, 3, ); @@ -703,11 +717,11 @@ mod tests { DeliveryGuarantee::Reliable, OrderingGuarantee::Ordered(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::reliable_ordered( + Some(Packet::reliable_ordered( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 1, ); } @@ -719,27 +733,24 @@ mod tests { assert_incoming_without_order( DeliveryGuarantee::Unreliable, &mut connection, - SocketEvent::Packet(Packet::unreliable(get_fake_addr(), PAYLOAD.to_vec())), + Packet::unreliable(get_fake_addr(), PAYLOAD.to_vec()), ); assert_incoming_without_order( DeliveryGuarantee::Reliable, &mut connection, - SocketEvent::Packet(Packet::reliable_unordered( - get_fake_addr(), - PAYLOAD.to_vec(), - )), + Packet::reliable_unordered(get_fake_addr(), PAYLOAD.to_vec()), ); assert_incoming_with_order( DeliveryGuarantee::Unreliable, OrderingGuarantee::Sequenced(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::unreliable_sequenced( + Some(Packet::unreliable_sequenced( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 1, ); @@ -747,11 +758,11 @@ mod tests { DeliveryGuarantee::Reliable, OrderingGuarantee::Ordered(Some(1)), &mut connection, - Ok(SocketEvent::Packet(Packet::reliable_ordered( + Some(Packet::reliable_ordered( get_fake_addr(), PAYLOAD.to_vec(), Some(1), - ))), + )), 0, ); } @@ -793,8 +804,6 @@ mod tests { let acked_header = vec![0, 0, 255, 4, 0, 0, 255, 255, 0, 0, 0, 0]; - let (tx, _rx) = unbounded::(); - use crate::error::{ErrorKind, FragmentErrorKind}; let mut connection = create_virtual_connection(); @@ -802,7 +811,6 @@ mod tests { [standard_header.as_slice(), acked_header.as_slice()] .concat() .as_slice(), - &tx, Instant::now(), ); @@ -830,7 +838,7 @@ mod tests { delivery: DeliveryGuarantee, ordering: OrderingGuarantee, connection: &mut VirtualConnection, - result_event: Result, + result_packet: Option, order_id: u16, ) { let mut packet = Vec::new(); @@ -867,25 +875,20 @@ mod tests { packet.write_all(&PAYLOAD).unwrap(); - let (tx, rx) = unbounded::(); - - connection - .process_incoming(packet.as_slice(), &tx, Instant::now()) - .unwrap(); - - let event = rx.try_recv(); - - match event { - Ok(val) => assert_eq!(val, result_event.unwrap()), - Err(e) => assert_eq!(e, result_event.err().unwrap()), - } + let packets = connection + .process_incoming(packet.as_slice(), Instant::now()) + .unwrap() + .into_iter() + .next() + .map(|(packet, _)| packet); + assert_eq!(packets, result_packet); } - // assert that the given `DeliveryGuarantee` results into the given `SocketEvent` after processing. + // assert that the given `DeliveryGuarantee` results into the given `Packet` after processing. fn assert_incoming_without_order( delivery: DeliveryGuarantee, connection: &mut VirtualConnection, - result_event: SocketEvent, + result_packet: Packet, ) { let mut packet = Vec::new(); @@ -900,15 +903,14 @@ mod tests { packet.write_all(&PAYLOAD).unwrap(); - let (tx, rx) = unbounded::(); - - connection - .process_incoming(packet.as_slice(), &tx, Instant::now()) + let (packet, _) = connection + .process_incoming(packet.as_slice(), Instant::now()) + .unwrap() + .into_iter() + .next() .unwrap(); - let event = rx.try_recv(); - - assert_eq!(event, Ok(result_event)); + assert_eq!(packet, result_packet); } // assert that the size of the processed header is the same as the given one. @@ -922,14 +924,19 @@ mod tests { let buffer = vec![1; 500]; let outgoing = connection - .process_outgoing(&buffer, delivery, ordering, None, Instant::now()) + .process_outgoing( + PacketInfo::user_packet(&buffer, delivery, ordering), + None, + Instant::now(), + ) .unwrap(); - - match outgoing { - Outgoing::Packet(packet) => { - assert_eq!(packet.contents().len() - buffer.len(), expected_header_size); - } - Outgoing::Fragments(_) => panic!("Expected packet got fragment"), + let mut iter = outgoing.into_iter(); + assert_eq!( + iter.next().unwrap().contents().len() - buffer.len(), + expected_header_size + ); + if iter.next().is_some() { + panic!("Expected not fragmented packet") } } } diff --git a/src/packet.rs b/src/packet.rs index 9c5ad403..e3935ee1 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -6,11 +6,13 @@ mod enums; mod outgoing; mod packet_reader; mod packet_structure; +mod process_result; pub use self::enums::{DeliveryGuarantee, OrderingGuarantee, PacketType}; -pub use self::outgoing::{Outgoing, OutgoingPacket, OutgoingPacketBuilder}; +pub use self::outgoing::{OutgoingPacket, OutgoingPacketBuilder}; pub use self::packet_reader::PacketReader; -pub use self::packet_structure::Packet; +pub use self::packet_structure::{Packet, PacketInfo}; +pub use self::process_result::{IncomingPackets, OutgoingPackets}; pub type SequenceNumber = u16; diff --git a/src/packet/header/standard_header.rs b/src/packet/header/standard_header.rs index 4eccaafe..c685a4f0 100644 --- a/src/packet/header/standard_header.rs +++ b/src/packet/header/standard_header.rs @@ -48,7 +48,6 @@ impl StandardHeader { } /// Returns the PacketType - #[cfg(test)] pub fn packet_type(&self) -> PacketType { self.packet_type } diff --git a/src/packet/outgoing.rs b/src/packet/outgoing.rs index 2ee33ecc..e622d513 100644 --- a/src/packet/outgoing.rs +++ b/src/packet/outgoing.rs @@ -122,14 +122,6 @@ impl<'p> OutgoingPacket<'p> { } } -/// Enum for storing different kinds of outgoing types with data. -pub enum Outgoing<'a> { - /// Represents a single packet. - Packet(OutgoingPacket<'a>), - /// Represents a packet that is fragmented and thus contains more than one `OutgoingPacket`. - Fragments(Vec>), -} - #[cfg(test)] mod tests { use crate::packet::PacketType; diff --git a/src/packet/packet_structure.rs b/src/packet/packet_structure.rs index 95ba89ee..084f59ba 100644 --- a/src/packet/packet_structure.rs +++ b/src/packet/packet_structure.rs @@ -1,4 +1,4 @@ -use crate::packet::{DeliveryGuarantee, OrderingGuarantee}; +use crate::packet::{DeliveryGuarantee, OrderingGuarantee, PacketType}; use std::net::SocketAddr; #[derive(Clone, PartialEq, Eq, Debug)] @@ -176,6 +176,45 @@ impl Packet { } } +/// This packet type has similar properties to `Packet` except that it doesn't own anything, and additionally has `PacketType`. +#[derive(Debug)] +pub struct PacketInfo<'a> { + /// defines a type of the packet + pub(crate) packet_type: PacketType, + /// the raw payload of the packet + pub(crate) payload: &'a [u8], + /// defines how the packet will be delivered. + pub(crate) delivery: DeliveryGuarantee, + /// defines how the packet will be ordered. + pub(crate) ordering: OrderingGuarantee, +} + +impl<'a> PacketInfo<'a> { + /// Creates a user packet that can be received by the user. + pub fn user_packet( + payload: &'a [u8], + delivery: DeliveryGuarantee, + ordering: OrderingGuarantee, + ) -> Self { + PacketInfo { + packet_type: PacketType::Packet, + payload, + delivery, + ordering, + } + } + + /// Creates a heartbeat packet that is expected to be sent over the network. + pub fn heartbeat_packet(payload: &'a [u8]) -> Self { + PacketInfo { + packet_type: PacketType::Heartbeat, + payload, + delivery: DeliveryGuarantee::Unreliable, + ordering: OrderingGuarantee::None, + } + } +} + #[cfg(test)] mod tests { use crate::packet::{DeliveryGuarantee, OrderingGuarantee, Packet}; diff --git a/src/packet/process_result.rs b/src/packet/process_result.rs new file mode 100644 index 00000000..1313c465 --- /dev/null +++ b/src/packet/process_result.rs @@ -0,0 +1,108 @@ +use crate::either::Either; +use crate::packet::{OutgoingPacket, Packet, PacketType}; + +use std::collections::VecDeque; + +/// Struct that implements `Iterator`, and is used to return incoming (from bytes to packets) or outgoing (from packet to bytes) packets. +/// It is used as optimization in cases, where most of the time there is only one element to iterate, and we don't want to create a vector for it. +pub struct ZeroOrMore { + data: Either, VecDeque>, +} + +impl ZeroOrMore { + fn zero() -> Self { + Self { + data: Either::Left(None), + } + } + + fn one(data: T) -> Self { + Self { + data: Either::Left(Some(data)), + } + } + + fn many(vec: VecDeque) -> Self { + Self { + data: Either::Right(vec), + } + } +} + +impl Iterator for ZeroOrMore { + type Item = T; + + fn next(&mut self) -> Option { + match &mut self.data { + Either::Left(option) => option.take(), + Either::Right(vec) => vec.pop_front(), + } + } +} + +/// Stores packets with headers that will be sent to the network, implements `IntoIterator` for convenience. +pub struct OutgoingPackets<'a> { + data: ZeroOrMore>, +} + +impl<'a> OutgoingPackets<'a> { + /// Stores only one packet, without allocating on the heap. + pub fn one(packet: OutgoingPacket<'a>) -> Self { + Self { + data: ZeroOrMore::one(packet), + } + } + + /// Stores multiple packets, allocated on the heap. + pub fn many(packets: VecDeque>) -> Self { + Self { + data: ZeroOrMore::many(packets), + } + } +} + +impl<'a> IntoIterator for OutgoingPackets<'a> { + type Item = OutgoingPacket<'a>; + type IntoIter = ZeroOrMore; + + fn into_iter(self) -> Self::IntoIter { + self.data + } +} + +/// Stores parsed packets with their types, that was received from network, implements `IntoIterator` for convenience. +pub struct IncomingPackets { + data: ZeroOrMore<(Packet, PacketType)>, +} + +impl IncomingPackets { + /// No packets are stored + pub fn zero() -> Self { + Self { + data: ZeroOrMore::zero(), + } + } + + /// Stores only one packet, without allocating on the heap. + pub fn one(packet: Packet, packet_type: PacketType) -> Self { + Self { + data: ZeroOrMore::one((packet, packet_type)), + } + } + + /// Stores multiple packets, allocated on the heap. + pub fn many(vec: VecDeque<(Packet, PacketType)>) -> Self { + Self { + data: ZeroOrMore::many(vec), + } + } +} + +impl IntoIterator for IncomingPackets { + type Item = (Packet, PacketType); + type IntoIter = ZeroOrMore; + + fn into_iter(self) -> Self::IntoIter { + self.data + } +}