diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml new file mode 100644 index 00000000..30c7c47f --- /dev/null +++ b/.github/workflows/build_and_test.yml @@ -0,0 +1,20 @@ +name: Rust + +on: + push: + branches: + - master + +jobs: + test: + name: Running on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macOS-latest] + steps: + - uses: actions/checkout@v1 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose diff --git a/examples/server_client.rs b/examples/server_client.rs index 3d47793b..b2db7207 100644 --- a/examples/server_client.rs +++ b/examples/server_client.rs @@ -12,9 +12,9 @@ const SERVER: &str = "127.0.0.1:12351"; fn server() -> Result<(), ErrorKind> { let mut socket = Socket::bind(SERVER)?; let (sender, receiver) = (socket.get_packet_sender(), socket.get_event_receiver()); - let _thread = thread::spawn(move || socket.start_polling()); loop { + socket.manual_poll(Instant::now()); if let Ok(event) = receiver.recv() { match event { SocketEvent::Packet(packet) => { diff --git a/src/config.rs b/src/config.rs index 26a85b10..d7908cdd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,6 @@ +use crate::log; use crate::net::constants::{DEFAULT_MTU, FRAGMENT_SIZE_DEFAULT, MAX_FRAGMENTS_DEFAULT}; -use std::{default::Default, time::Duration}; +use std::{default::Default, rc::Rc, time::Duration}; #[derive(Clone, Debug)] /// Contains the configuration options to configure laminar for special use-cases. @@ -8,6 +9,9 @@ pub struct Config { pub blocking_mode: bool, /// Value which can specify the amount of time that can pass without hearing from a client before considering them disconnected pub idle_connection_timeout: Duration, + /// Value which specifies at which interval (if at all) a heartbeat should be sent, if no other packet was sent in the meantime. + /// If None, no heartbeats will be sent (the default). + pub heartbeat_interval: Option, /// Value which can specify the maximum size a packet can be in bytes. This value is inclusive of fragmenting; if a packet is fragmented, the total size of the fragments cannot exceed this value. /// /// Recommended value: 16384 @@ -45,6 +49,14 @@ pub struct Config { /// /// Value that specifies how long we should block polling for socket events, in milliseconds. Defaults to `1ms`. pub socket_polling_timeout: Option, + /// The maximum amount of reliable packets in flight on this connection before we drop the + /// connection. + /// + /// When we send a reliable packet, it is stored locally until an acknowledgement comes back to + /// us, if that store grows to a size + pub max_packets_in_flight: u16, + /// Logger used for this instance of laminar. See [log::LaminarLogger] for more details. + pub logger: Rc, } impl Default for Config { @@ -52,6 +64,7 @@ impl Default for Config { Self { blocking_mode: false, idle_connection_timeout: Duration::from_secs(5), + heartbeat_interval: None, max_packet_size: (MAX_FRAGMENTS_DEFAULT * FRAGMENT_SIZE_DEFAULT) as usize, max_fragments: MAX_FRAGMENTS_DEFAULT as u8, fragment_size: FRAGMENT_SIZE_DEFAULT, @@ -61,6 +74,8 @@ impl Default for Config { rtt_max_value: 250, socket_event_buffer_size: 1024, socket_polling_timeout: Some(Duration::from_millis(1)), + max_packets_in_flight: 512, + logger: Rc::new(log::DefaultLogger), } } } diff --git a/src/error.rs b/src/error.rs index 755306b8..c6203219 100644 --- a/src/error.rs +++ b/src/error.rs @@ -185,6 +185,6 @@ mod tests { #[test] fn able_to_box_errors() { - let _: Box = Box::new(ErrorKind::CouldNotReadHeader("".into())); + let _: Box = Box::new(ErrorKind::CouldNotReadHeader("".into())); } } diff --git a/src/infrastructure/acknowledgment.rs b/src/infrastructure/acknowledgment.rs index 381fe0cf..b268d58b 100644 --- a/src/infrastructure/acknowledgment.rs +++ b/src/infrastructure/acknowledgment.rs @@ -1,6 +1,6 @@ use crate::packet::OrderingGuarantee; use crate::packet::SequenceNumber; -use crate::sequence_buffer::{sequence_less_than, SequenceBuffer}; +use crate::sequence_buffer::{sequence_greater_than, sequence_less_than, SequenceBuffer}; use std::collections::HashMap; const REDUNDANT_PACKET_ACKS_SIZE: u16 = 32; @@ -31,6 +31,11 @@ impl AcknowledgmentHandler { } } + /// Get the current number of not yet acknowledged packets + pub fn packets_in_flight(&self) -> u16 { + self.sent_packets.len() as u16 + } + /// Returns the next sequence number to send. pub fn local_sequence_num(&self) -> SequenceNumber { self.sequence_number @@ -71,7 +76,11 @@ impl AcknowledgmentHandler { remote_ack_seq: u16, mut remote_ack_field: u32, ) { - self.remote_ack_sequence_num = remote_ack_seq; + // We must ensure that self.remote_ack_sequence_num is always increasing (with wrapping) + if sequence_greater_than(remote_ack_seq, self.remote_ack_sequence_num) { + self.remote_ack_sequence_num = remote_ack_seq; + } + self.received_packets .insert(remote_seq_num, ReceivedPacket {}); @@ -285,4 +294,26 @@ mod test { assert_eq!(handler.sent_packets.len(), 1); assert_eq!(handler.local_sequence_num(), 1); } + + #[test] + fn remote_ack_seq_must_never_be_less_than_prior() { + let mut handler = AcknowledgmentHandler::new(); + // Second packet received before first + handler.process_incoming(1, 1, 1); + assert_eq!(handler.remote_ack_sequence_num, 1); + // First packet received + handler.process_incoming(0, 0, 0); + assert_eq!(handler.remote_ack_sequence_num, 1); + } + + #[test] + fn remote_ack_seq_must_never_be_less_than_prior_wrap_boundary() { + let mut handler = AcknowledgmentHandler::new(); + // newer packet received before first + handler.process_incoming(1, 0, 1); + assert_eq!(handler.remote_ack_sequence_num, 0); + // earlier packet received + handler.process_incoming(0, u16::max_value(), 0); + assert_eq!(handler.remote_ack_sequence_num, 0); + } } diff --git a/src/infrastructure/arranging.rs b/src/infrastructure/arranging.rs index 806844bd..ca7af85e 100644 --- a/src/infrastructure/arranging.rs +++ b/src/infrastructure/arranging.rs @@ -54,7 +54,7 @@ pub trait Arranging { /// If the `incoming_offset` satisfies the arranging algorithm it returns `Some` with the passed item. fn arrange( &mut self, - incoming_index: usize, + incoming_index: u16, item: Self::ArrangingItem, ) -> Option; } diff --git a/src/infrastructure/arranging/ordering.rs b/src/infrastructure/arranging/ordering.rs index 7f7d1e6b..6aa5193c 100644 --- a/src/infrastructure/arranging/ordering.rs +++ b/src/infrastructure/arranging/ordering.rs @@ -132,9 +132,9 @@ pub struct OrderingStream { _stream_id: u8, // the storage for items that are waiting for older items to arrive. // the items will be stored by key and value where the key is the incoming index and the value is the item value. - storage: HashMap, + storage: HashMap, // the next expected item index. - expected_index: usize, + expected_index: u16, // unique identifier which should be used for ordering on a different stream e.g. the remote endpoint. unique_item_identifier: u16, } @@ -159,7 +159,7 @@ impl OrderingStream { pub fn with_capacity(size: usize, stream_id: u8) -> OrderingStream { OrderingStream { storage: HashMap::with_capacity(size), - expected_index: 1, + expected_index: 0, _stream_id: stream_id, unique_item_identifier: 0, } @@ -173,14 +173,15 @@ impl OrderingStream { /// Returns the next expected index. #[cfg(test)] - pub fn expected_index(&self) -> usize { + pub fn expected_index(&self) -> u16 { self.expected_index } /// Returns the unique identifier which should be used for ordering on the other stream e.g. the remote endpoint. pub fn new_item_identifier(&mut self) -> SequenceNumber { + let id = self.unique_item_identifier; self.unique_item_identifier = self.unique_item_identifier.wrapping_add(1); - self.unique_item_identifier + id } /// Returns an iterator of stored items. @@ -216,6 +217,12 @@ impl OrderingStream { } } +fn is_u16_within_half_window_from_start(start: u16, incoming: u16) -> bool { + // Check (with wrapping) if the incoming value lies within the next u16::max_value()/2 from + // start. + incoming.wrapping_sub(start) <= u16::max_value() / 2 + 1 +} + impl Arranging for OrderingStream { type ArrangingItem = T; @@ -234,18 +241,18 @@ impl Arranging for OrderingStream { /// This can only happen in cases where we have a duplicated package. Again we don't give anything back. /// /// # Remark - /// - When we receive an item there is a possibility that a gab is filled and one or more items will could be returned. + /// - When we receive an item there is a possibility that a gap is filled and one or more items will could be returned. /// You should use the `iter_mut` instead for reading the items in order. /// However the item given to `arrange` will be returned directly when it matches the `expected_index`. fn arrange( &mut self, - incoming_offset: usize, + incoming_offset: u16, item: Self::ArrangingItem, ) -> Option { if incoming_offset == self.expected_index { - self.expected_index += 1; + self.expected_index = self.expected_index.wrapping_add(1); Some(item) - } else if incoming_offset > self.expected_index { + } else if is_u16_within_half_window_from_start(self.expected_index, incoming_offset) { self.storage.insert(incoming_offset, item); None } else { @@ -270,8 +277,8 @@ impl Arranging for OrderingStream { /// - Iterator mutates the `expected_index`. /// - You can't use this iterator for iterating trough all cached values. pub struct IterMut<'a, T> { - items: &'a mut HashMap, - expected_index: &'a mut usize, + items: &'a mut HashMap, + expected_index: &'a mut u16, } impl<'a, T> Iterator for IterMut<'a, T> { @@ -283,7 +290,7 @@ impl<'a, T> Iterator for IterMut<'a, T> { match self.items.remove(&self.expected_index) { None => None, Some(e) => { - *self.expected_index += 1; + *self.expected_index = self.expected_index.wrapping_add(1); Some(e) } } @@ -292,16 +299,16 @@ impl<'a, T> Iterator for IterMut<'a, T> { #[cfg(test)] mod tests { - use super::{Arranging, ArrangingSystem, OrderingSystem}; + use super::{is_u16_within_half_window_from_start, Arranging, ArrangingSystem, OrderingSystem}; #[derive(Debug, PartialEq, Clone)] struct Packet { - pub sequence: usize, + pub sequence: u16, pub ordering_stream: u8, } impl Packet { - fn new(sequence: usize, ordering_stream: u8) -> Packet { + fn new(sequence: u16, ordering_stream: u8) -> Packet { Packet { sequence, ordering_stream, @@ -314,7 +321,7 @@ mod tests { let mut system: OrderingSystem = OrderingSystem::new(); let stream = system.get_or_create_stream(1); - assert_eq!(stream.expected_index(), 1); + assert_eq!(stream.expected_index(), 0); assert_eq!(stream.stream_id(), 1); } @@ -328,6 +335,53 @@ mod tests { assert_eq!(stream.stream_id(), 1); } + #[test] + fn packet_wraps_around_offset() { + let mut system: OrderingSystem<()> = OrderingSystem::new(); + + let stream = system.get_or_create_stream(1); + for idx in 0..=65500 { + assert![stream.arrange(idx, ()).is_some()]; + } + assert![stream.arrange(123, ()).is_none()]; + for idx in 65501..=65535u16 { + assert![stream.arrange(idx, ()).is_some()]; + } + assert![stream.arrange(0, ()).is_some()]; + for idx in 1..123 { + assert![stream.arrange(idx, ()).is_some()]; + } + assert![stream.iter_mut().next().is_some()]; + } + + #[test] + fn exactly_half_u16_packet_is_stored() { + let mut system: OrderingSystem = OrderingSystem::new(); + + let stream = system.get_or_create_stream(1); + for idx in 0..=32766 { + assert![stream.arrange(idx, idx).is_some()]; + } + assert![stream.arrange(32768, 32768).is_none()]; + assert![stream.arrange(32767, 32767).is_some()]; + assert_eq![Some(32768), stream.iter_mut().next()]; + assert_eq![None, stream.iter_mut().next()]; + } + + #[test] + fn u16_forward_half() { + assert![!is_u16_within_half_window_from_start(0, 65535)]; + assert![!is_u16_within_half_window_from_start(0, 32769)]; + + assert![is_u16_within_half_window_from_start(0, 32768)]; + assert![is_u16_within_half_window_from_start(0, 32767)]; + + assert![is_u16_within_half_window_from_start(32767, 65535)]; + assert![!is_u16_within_half_window_from_start(32766, 65535)]; + assert![is_u16_within_half_window_from_start(32768, 65535)]; + assert![is_u16_within_half_window_from_start(32769, 0)]; + } + #[test] fn can_iterate() { let mut system: OrderingSystem = OrderingSystem::new(); @@ -335,21 +389,21 @@ mod tests { system.get_or_create_stream(1); let stream = system.get_or_create_stream(1); + let stub_packet0 = Packet::new(0, 1); let stub_packet1 = Packet::new(1, 1); let stub_packet2 = Packet::new(2, 1); let stub_packet3 = Packet::new(3, 1); let stub_packet4 = Packet::new(4, 1); - let stub_packet5 = Packet::new(5, 1); { assert_eq!( - stream.arrange(1, stub_packet1.clone()).unwrap(), - stub_packet1 + stream.arrange(0, stub_packet0.clone()).unwrap(), + stub_packet0 ); - stream.arrange(4, stub_packet4.clone()).is_none(); - stream.arrange(5, stub_packet5.clone()).is_none(); - stream.arrange(3, stub_packet3.clone()).is_none(); + assert![stream.arrange(3, stub_packet3.clone()).is_none()]; + assert![stream.arrange(4, stub_packet4.clone()).is_none()]; + assert![stream.arrange(2, stub_packet2.clone()).is_none()]; } { let mut iterator = stream.iter_mut(); @@ -359,17 +413,17 @@ mod tests { } { assert_eq!( - stream.arrange(2, stub_packet2.clone()).unwrap(), - stub_packet2 + stream.arrange(1, stub_packet1.clone()).unwrap(), + stub_packet1 ); } { // since we processed packet 2 by now we should be able to iterate and get back: 3,4,5; let mut iterator = stream.iter_mut(); + assert_eq!(iterator.next().unwrap(), stub_packet2); assert_eq!(iterator.next().unwrap(), stub_packet3); assert_eq!(iterator.next().unwrap(), stub_packet4); - assert_eq!(iterator.next().unwrap(), stub_packet5); } } @@ -378,13 +432,13 @@ mod tests { ( [$( $x:expr ),*] , [$( $y:expr),*] , $stream_id:expr) => { { // initialize vector of given range on the left. - let mut before: Vec = Vec::new(); + let mut before: Vec = Vec::new(); $( before.push($x); )* // initialize vector of given range on the right. - let mut after: Vec = Vec::new(); + let mut after: Vec = Vec::new(); $( after.push($y); )* @@ -428,26 +482,26 @@ mod tests { #[test] fn expect_right_order() { // we order on stream 1 - assert_order!([1, 3, 5, 4, 2], [1, 2, 3, 4, 5], 1); - assert_order!([1, 5, 4, 3, 2], [1, 2, 3, 4, 5], 1); - assert_order!([5, 3, 4, 2, 1], [1, 2, 3, 4, 5], 1); - assert_order!([4, 3, 2, 1, 5], [1, 2, 3, 4, 5], 1); - assert_order!([2, 1, 4, 3, 5], [1, 2, 3, 4, 5], 1); - assert_order!([5, 2, 1, 4, 3], [1, 2, 3, 4, 5], 1); - assert_order!([3, 2, 4, 1, 5], [1, 2, 3, 4, 5], 1); - assert_order!([2, 1, 4, 3, 5], [1, 2, 3, 4, 5], 1); + assert_order!([0, 2, 4, 3, 1], [0, 1, 2, 3, 4], 1); + assert_order!([0, 4, 3, 2, 1], [0, 1, 2, 3, 4], 1); + assert_order!([4, 2, 3, 1, 0], [0, 1, 2, 3, 4], 1); + assert_order!([3, 2, 1, 0, 4], [0, 1, 2, 3, 4], 1); + assert_order!([1, 0, 3, 2, 4], [0, 1, 2, 3, 4], 1); + assert_order!([4, 1, 0, 3, 2], [0, 1, 2, 3, 4], 1); + assert_order!([2, 1, 3, 0, 4], [0, 1, 2, 3, 4], 1); + assert_order!([1, 0, 3, 2, 4], [0, 1, 2, 3, 4], 1); } #[test] fn order_on_multiple_streams() { // we order on streams [1...8] - assert_order!([1, 3, 5, 4, 2], [1, 2, 3, 4, 5], 1); - assert_order!([1, 5, 4, 3, 2], [1, 2, 3, 4, 5], 2); - assert_order!([5, 3, 4, 2, 1], [1, 2, 3, 4, 5], 3); - assert_order!([4, 3, 2, 1, 5], [1, 2, 3, 4, 5], 4); - assert_order!([2, 1, 4, 3, 5], [1, 2, 3, 4, 5], 5); - assert_order!([5, 2, 1, 4, 3], [1, 2, 3, 4, 5], 6); - assert_order!([3, 2, 4, 1, 5], [1, 2, 3, 4, 5], 7); - assert_order!([2, 1, 4, 3, 5], [1, 2, 3, 4, 5], 8); + assert_order!([0, 2, 4, 3, 1], [0, 1, 2, 3, 4], 1); + assert_order!([0, 4, 3, 2, 1], [0, 1, 2, 3, 4], 2); + assert_order!([4, 2, 3, 1, 0], [0, 1, 2, 3, 4], 3); + assert_order!([3, 2, 1, 0, 4], [0, 1, 2, 3, 4], 4); + assert_order!([1, 0, 3, 2, 4], [0, 1, 2, 3, 4], 5); + assert_order!([4, 1, 0, 3, 2], [0, 1, 2, 3, 4], 6); + assert_order!([2, 1, 3, 0, 4], [0, 1, 2, 3, 4], 7); + assert_order!([1, 0, 3, 2, 4], [0, 1, 2, 3, 4], 8); } } diff --git a/src/infrastructure/arranging/sequencing.rs b/src/infrastructure/arranging/sequencing.rs index 429d48ea..465e8411 100644 --- a/src/infrastructure/arranging/sequencing.rs +++ b/src/infrastructure/arranging/sequencing.rs @@ -70,7 +70,7 @@ pub struct SequencingStream { // the id of this stream. _stream_id: u8, // the highest seen item index. - top_index: usize, + top_index: u16, // I need `PhantomData`, otherwise, I can't use a generic in the `Arranging` implementation because `T` is not constrained. phantom: PhantomData, // unique identifier which should be used for ordering on an other stream e.g. the remote endpoint. @@ -98,11 +98,18 @@ impl SequencingStream { /// Returns the unique identifier which should be used for ordering on an other stream e.g. the remote endpoint. pub fn new_item_identifier(&mut self) -> SequenceNumber { + let id = self.unique_item_identifier; self.unique_item_identifier = self.unique_item_identifier.wrapping_add(1); - self.unique_item_identifier + id } } +fn is_u16_within_half_window_from_start(start: u16, incoming: u16) -> bool { + // Check (with wrapping) if the incoming value lies within the next u16::max_value()/2 from + // start. + incoming.wrapping_sub(start) <= u16::max_value() / 2 + 1 +} + impl Arranging for SequencingStream { type ArrangingItem = T; @@ -125,10 +132,10 @@ impl Arranging for SequencingStream { /// - None is returned when an old packet is received. fn arrange( &mut self, - incoming_index: usize, + incoming_index: u16, item: Self::ArrangingItem, ) -> Option { - if incoming_index > self.top_index { + if is_u16_within_half_window_from_start(self.top_index, incoming_index) { self.top_index = incoming_index; return Some(item); } @@ -142,12 +149,12 @@ mod tests { #[derive(Debug, PartialEq, Clone)] struct Packet { - pub sequence: usize, + pub sequence: u16, pub ordering_stream: u8, } impl Packet { - fn new(sequence: usize, ordering_stream: u8) -> Packet { + fn new(sequence: u16, ordering_stream: u8) -> Packet { Packet { sequence, ordering_stream, @@ -178,13 +185,13 @@ mod tests { ( [$( $x:expr ),*], [$( $y:expr),*], $stream_id:expr) => { { // initialize vector of given range on the left. - let mut before: Vec = Vec::new(); + let mut before: Vec = Vec::new(); $( before.push($x); )* // initialize vector of given range on the right. - let mut after: Vec = Vec::new(); + let mut after: Vec = Vec::new(); $( after.push($y); )* diff --git a/src/lib.rs b/src/lib.rs index 1f830b77..d5574255 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ mod config; mod either; mod error; mod infrastructure; +pub mod log; mod net; mod packet; mod protocol_version; diff --git a/src/log.rs b/src/log.rs new file mode 100644 index 00000000..1b912d79 --- /dev/null +++ b/src/log.rs @@ -0,0 +1,141 @@ +//! Logging adapter for Laminar +//! +//! This module implements a simple, threaded-logger-friendly logging adapter. Logging adapters are +//! used to attach an arbitrary logger into Laminar. +use std::fmt; +use std::sync::Arc; + +/// Logger trait for laminar +/// +/// Any user of Laminar can implement this trait to attach their favorite logger to an instance of +/// laminar. The log levels correspond to the same log levels as in the `log` crate. +pub trait LaminarLogger { + /// Log a trace message + fn trace(&self, disp: Displayer); + /// Log a debug message + fn debug(&self, disp: Displayer); + /// Log an info message + fn info(&self, disp: Displayer); + /// Log a warning message + fn warn(&self, disp: Displayer); + /// Log an error message + fn error(&self, disp: Displayer); +} + +// --- + +/// Holds a handle to a formatter function while implementing the [fmt::Display] trait. +pub struct Displayer { + data: Arc ::std::fmt::Result + Send + Sync>, +} + +impl Displayer { + pub(crate) fn new( + delegate: Arc ::std::fmt::Result + Send + Sync>, + ) -> Self { + Self { data: delegate } + } +} + +impl fmt::Display for Displayer { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + (self.data)(f) + } +} + +// --- + +pub(crate) struct DefaultLogger; + +impl LaminarLogger for DefaultLogger { + fn trace(&self, _: Displayer) {} + fn debug(&self, _: Displayer) {} + fn info(&self, _: Displayer) {} + fn warn(&self, _: Displayer) {} + fn error(&self, _: Displayer) {} +} + +// --- + +impl fmt::Debug for dyn LaminarLogger { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write![f, "LaminarLogger"] + } +} + +// --- + +/// Format-friendly form of [log::LaminarLogger::trace] +#[macro_export] +macro_rules! trace { + ($logger:expr, $($fmt:expr),* $(,)?) => {{ + $logger.trace($crate::log::Displayer::new(::std::sync::Arc::new(move |f: &mut ::std::fmt::Formatter| { write![f, $($fmt),*] }) )); + }}; +} + +/// Format-friendly form of [log::LaminarLogger::debug] +#[macro_export] +macro_rules! debug { + ($logger:expr, $($fmt:expr),* $(,)?) => {{ + $logger.debug($crate::log::Displayer::new(::std::sync::Arc::new(move |f: &mut ::std::fmt::Formatter| { write![f, $($fmt),*] }) )); + }}; +} + +/// Format-friendly form of [log::LaminarLogger::info] +#[macro_export] +macro_rules! info { + ($logger:expr, $($fmt:expr),* $(,)?) => {{ + $logger.info($crate::log::Displayer::new(::std::sync::Arc::new(move |f: &mut ::std::fmt::Formatter| { write![f, $($fmt),*] }) )); + }}; +} + +/// Format-friendly form of [log::LaminarLogger::warn] +#[macro_export] +macro_rules! warn { + ($logger:expr, $($fmt:expr),* $(,)?) => {{ + $logger.warn($crate::log::Displayer::new(::std::sync::Arc::new(move |f: &mut ::std::fmt::Formatter| { write![f, $($fmt),*] }) )); + }}; +} + +/// Format-friendly form of [log::LaminarLogger::error] +#[macro_export] +macro_rules! error { + ($logger:expr, $($fmt:expr),* $(,)?) => {{ + $logger.error($crate::log::Displayer::new(::std::sync::Arc::new(move |f: &mut ::std::fmt::Formatter| { write![f, $($fmt),*] }) )); + }}; +} + +#[cfg(test)] +mod tests { + #[test] + fn log_adapter() { + use crate::log::{Displayer, LaminarLogger}; + use std::{rc::Rc, sync::Arc}; + + let mut cfg = Config::default(); + + struct MyAdapter {} + + impl LaminarLogger for MyAdapter { + fn trace(&self, disp: Displayer) { + println!["trace: {}", disp]; + } + fn debug(&self, disp: Displayer) { + println!["debug: {}", disp]; + } + fn info(&self, disp: Displayer) { + println!["info: {}", disp]; + } + fn warn(&self, disp: Displayer) { + println!["warn: {}", disp]; + } + fn error(&self, disp: Displayer) { + println!["An error! {}", disp]; + } + } + + cfg.logger = Rc::new(MyAdapter {}); + + Socket::bind_any_with_config(cfg).unwrap(); + } +} diff --git a/src/net/connection.rs b/src/net/connection.rs index d9a17789..2ee26360 100644 --- a/src/net/connection.rs +++ b/src/net/connection.rs @@ -67,6 +67,27 @@ impl ActiveConnections { .collect() } + /// Get a list of addresses of dead connections + pub fn dead_connections(&mut self) -> Vec { + self.connections + .iter() + .filter(|(_, connection)| connection.should_be_dropped()) + .map(|(address, _)| *address) + .collect() + } + + /// Check for and return `VirtualConnection`s which have not sent anything for a duration of at least `heartbeat_interval`. + pub fn heartbeat_required_connections( + &mut self, + heartbeat_interval: Duration, + time: Instant, + ) -> impl Iterator { + self.connections + .iter_mut() + .filter(move |(_, connection)| connection.last_sent(time) >= heartbeat_interval) + .map(|(_, connection)| connection) + } + /// Returns true if the given connection exists. pub fn exists(&self, address: &SocketAddr) -> bool { self.connections.contains_key(&address) diff --git a/src/net/socket.rs b/src/net/socket.rs index fa1e0015..1e80e2b0 100644 --- a/src/net/socket.rs +++ b/src/net/socket.rs @@ -1,4 +1,5 @@ use crate::either::Either::{Left, Right}; +use crate::error; use crate::{ config::Config, error::{ErrorKind, Result}, @@ -6,7 +7,6 @@ use crate::{ packet::{DeliveryGuarantee, Outgoing, Packet}, }; use crossbeam_channel::{self, unbounded, Receiver, SendError, Sender, TryRecvError}; -use log::error; use std::{ self, io, net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket}, @@ -142,7 +142,10 @@ impl Socket { match self.recv_from(time) { Ok(UdpSocketState::MaybeMore) => continue, Ok(UdpSocketState::MaybeEmpty) => break, - Err(e) => error!("Encountered an error receiving data: {:?}", e), + Err(e) => error!( + self.config.logger, + "Encountered an error receiving data: {:?}", e + ), } } @@ -151,14 +154,36 @@ impl Socket { if let Err(e) = self.send_to(p, time) { match e { ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - _ => error!("There was an error sending packet: {:?}", e), + _ => error!( + self.config.logger, + "There was an error sending packet: {:?}", e + ), } } } - // Finally check for idle clients + // Check for idle clients if let Err(e) = self.handle_idle_clients(time) { - error!("Encountered an error when sending TimeoutEvent: {:?}", e); + error!( + self.config.logger, + "Encountered an error when sending TimeoutEvent: {:?}", e + ); + } + + // Handle any dead clients + self.handle_dead_clients().expect("Internal laminar error"); + + // Finally send heartbeat packets to connections that require them, if enabled + if let Some(heartbeat_interval) = self.config.heartbeat_interval { + if let Err(e) = self.send_heartbeat_packets(heartbeat_interval, time) { + match e { + ErrorKind::IOError(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + _ => error!( + self.config.logger, + "There was an error sending a heartbeat packet: {:?}", e + ), + } + } } } @@ -172,6 +197,18 @@ impl Socket { Ok(self.socket.local_addr()?) } + /// Iterate through the dead connections and disconnect them by removing them from the + /// connection map while informing the user of this by sending an event. + fn handle_dead_clients(&mut self) -> Result<()> { + let dead_addresses = self.connections.dead_connections(); + for address in dead_addresses { + self.connections.remove_connection(&address); + self.event_sender.send(SocketEvent::Timeout(address))?; + } + + Ok(()) + } + /// Iterate through all of the idle connections based on `idle_connection_timeout` config and /// remove them from the active connections. For each connection removed, we will send a /// `SocketEvent::TimeOut` event to the `event_sender` channel. @@ -187,6 +224,35 @@ impl Socket { Ok(()) } + /// Iterate over all connections which have not sent a packet for a duration of at least + /// `heartbeat_interval` (from config), and send a heartbeat packet to each. + fn send_heartbeat_packets( + &mut self, + heartbeat_interval: Duration, + time: Instant, + ) -> Result { + let heartbeat_packets_and_addrs = self + .connections + .heartbeat_required_connections(heartbeat_interval, time) + .map(|connection| { + ( + connection.create_and_process_heartbeat(time), + connection.remote_address, + ) + }) + .collect::>(); + + let mut bytes_sent = 0; + + for (heartbeat_packet, address) in heartbeat_packets_and_addrs { + if self.should_send_packet() { + bytes_sent += self.send_packet(&address, &heartbeat_packet.contents())?; + } + } + + Ok(bytes_sent) + } + // Serializes and sends a `Packet` on the socket. On success, returns the number of bytes written. fn send_to(&mut self, packet: Packet, time: Instant) -> Result { let connection = @@ -266,7 +332,11 @@ impl Socket { } Err(e) => { if e.kind() != io::ErrorKind::WouldBlock { - error!("Encountered an error receiving data: {:?}", e); + let err = format!["{:?}", e]; + error!( + self.config.logger, + "Encountered an error receiving data: {}", err + ); return Err(e.into()); } else { return Ok(UdpSocketState::MaybeEmpty); @@ -554,7 +624,7 @@ mod tests { server.forget_all_incoming_packets(); // Send a packet that the server receives - for id in 0..36 { + for id in 0..35 { client .send(create_ordered_packet(id, "127.0.0.1:12333")) .unwrap(); @@ -579,11 +649,14 @@ mod tests { #[test] fn do_not_duplicate_sequenced_packets_when_received() { - let server_addr = "127.0.0.1:12325".parse::().unwrap(); - let client_addr = "127.0.0.1:12326".parse::().unwrap(); + let mut config = Config::default(); - let mut server = Socket::bind(server_addr).unwrap(); - let mut client = Socket::bind(client_addr).unwrap(); + let mut client = Socket::bind_any_with_config(config.clone()).unwrap(); + config.blocking_mode = true; + let mut server = Socket::bind_any_with_config(config).unwrap(); + + let server_addr = server.local_addr().unwrap(); + let client_addr = client.local_addr().unwrap(); let time = Instant::now(); @@ -592,10 +665,9 @@ mod tests { .send(Packet::reliable_sequenced(server_addr, vec![id], None)) .unwrap(); client.manual_poll(time); + server.manual_poll(time); } - server.manual_poll(time); - let mut seen = HashSet::new(); while let Some(message) = server.recv() { @@ -615,6 +687,91 @@ mod tests { assert_eq![100, seen.len()]; } + #[test] + fn more_than_65536_sequenced_packets() { + let mut config = Config::default(); + + let mut client = Socket::bind_any_with_config(config.clone()).unwrap(); + config.blocking_mode = true; + let mut server = Socket::bind_any_with_config(config).unwrap(); + + let server_addr = server.local_addr().unwrap(); + let client_addr = client.local_addr().unwrap(); + + // Acknowledge the client + server + .send(Packet::unreliable(client_addr, vec![0])) + .unwrap(); + + let time = Instant::now(); + + for id in 0..65536 + 100 { + client + .send(Packet::unreliable_sequenced( + server_addr, + id.to_string().as_bytes().to_vec(), + None, + )) + .unwrap(); + client.manual_poll(time); + server.manual_poll(time); + } + + let mut cnt = 0; + while let Some(message) = server.recv() { + match message { + SocketEvent::Connect(_) => {} + SocketEvent::Packet(packet) => { + cnt += 1; + } + SocketEvent::Timeout(_) => { + panic!["This should not happen, as we've not advanced time"]; + } + } + } + assert_eq![65536 + 100, cnt]; + } + + #[test] + fn sequenced_packets_pathological_case() { + let mut config = Config::default(); + + config.max_packets_in_flight = 100; + let mut client = Socket::bind_any_with_config(config.clone()).unwrap(); + config.blocking_mode = true; + let mut server = Socket::bind_any_with_config(config).unwrap(); + + let server_addr = server.local_addr().unwrap(); + + let time = Instant::now(); + + for id in 0..101 { + client + .send(Packet::reliable_sequenced( + server_addr, + id.to_string().as_bytes().to_vec(), + None, + )) + .unwrap(); + client.manual_poll(time); + + while let Some(event) = client.recv() { + match event { + SocketEvent::Timeout(remote_addr) => { + assert_eq![100, id]; + assert_eq![remote_addr, server_addr]; + return; + } + _ => { + panic!["No other event possible"]; + } + } + } + } + + panic!["Should have received a timeout event"]; + } + #[test] fn manual_polling_socket() { let mut server = Socket::bind("127.0.0.1:12339".parse::().unwrap()).unwrap(); @@ -737,52 +894,111 @@ mod tests { let mut config = Config::default(); config.idle_connection_timeout = Duration::from_millis(1); - let mut server = Socket::bind("127.0.0.1:12347".parse::().unwrap()).unwrap(); - let mut client = Socket::bind("127.0.0.1:12346".parse::().unwrap()).unwrap(); + let server_addr = "127.0.0.1:12347".parse::().unwrap(); + let client_addr = "127.0.0.1:12346".parse::().unwrap(); + + let mut server = Socket::bind_with_config(server_addr, config.clone()).unwrap(); + let mut client = Socket::bind_with_config(client_addr, config.clone()).unwrap(); client - .send(Packet::unreliable( - "127.0.0.1:12347".parse().unwrap(), - vec![0, 1, 2], - )) + .send(Packet::unreliable(server_addr, vec![0, 1, 2])) .unwrap(); let now = Instant::now(); client.manual_poll(now); server.manual_poll(now); + assert_eq!(server.recv().unwrap(), SocketEvent::Connect(client_addr)); assert_eq!( server.recv().unwrap(), - SocketEvent::Connect("127.0.0.1:12346".parse().unwrap()) + SocketEvent::Packet(Packet::unreliable(client_addr, vec![0, 1, 2])) + ); + + // Acknowledge the client + server + .send(Packet::unreliable(client_addr, vec![])) + .unwrap(); + + server.manual_poll(now); + client.manual_poll(now); + + // Make sure the connection was successful on the client side + assert_eq!( + client.recv().unwrap(), + SocketEvent::Packet(Packet::unreliable(server_addr, vec![])) ); + + // Give just enough time for no timeout events to occur (yet) + server.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1)); + client.manual_poll(now + config.idle_connection_timeout - Duration::from_millis(1)); + + assert_eq!(server.recv(), None); + assert_eq!(client.recv(), None); + + // Give enough time for timeouts to be detected + server.manual_poll(now + config.idle_connection_timeout); + client.manual_poll(now + config.idle_connection_timeout); + + assert_eq!(server.recv().unwrap(), SocketEvent::Timeout(client_addr)); + assert_eq!(client.recv().unwrap(), SocketEvent::Timeout(server_addr)); + } + + #[test] + fn heartbeats_work() { + let mut config = Config::default(); + config.idle_connection_timeout = Duration::from_millis(10); + config.heartbeat_interval = Some(Duration::from_millis(4)); + + let server_addr = "127.0.0.1:12351".parse::().unwrap(); + let client_addr = "127.0.0.1:12352".parse::().unwrap(); + + // Start up a server and a client. + let mut server = Socket::bind_with_config(server_addr, config.clone()).unwrap(); + let mut client = Socket::bind_with_config(client_addr, config.clone()).unwrap(); + + // Initiate a connection + client + .send(Packet::unreliable(server_addr, vec![0, 1, 2])) + .unwrap(); + + let now = Instant::now(); + client.manual_poll(now); + server.manual_poll(now); + + // Make sure the connection was successful on the server side + assert_eq!(server.recv().unwrap(), SocketEvent::Connect(client_addr)); assert_eq!( server.recv().unwrap(), - SocketEvent::Packet(Packet::unreliable( - "127.0.0.1:12346".parse().unwrap(), - vec![0, 1, 2] - )) + SocketEvent::Packet(Packet::unreliable(client_addr, vec![0, 1, 2])) ); // Acknowledge the client + // This way, the server also knows about the connection and sends heartbeats server - .send(Packet::unreliable( - "127.0.0.1:12346".parse().unwrap(), - vec![], - )) + .send(Packet::unreliable(client_addr, vec![])) .unwrap(); server.manual_poll(now); client.manual_poll(now); - server.manual_poll(now + Duration::new(5, 0)); + // Make sure the connection was successful on the client side assert_eq!( - server.recv().unwrap(), - SocketEvent::Timeout("127.0.0.1:12346".parse().unwrap()) + client.recv().unwrap(), + SocketEvent::Packet(Packet::unreliable(server_addr, vec![])) ); - } - const LOCAL_ADDR: &str = "127.0.0.1:13000"; - const REMOTE_ADDR: &str = "127.0.0.1:14000"; + // Give time to send heartbeats + client.manual_poll(now + config.heartbeat_interval.unwrap()); + server.manual_poll(now + config.heartbeat_interval.unwrap()); + + // Give time for timeouts to occur if no heartbeats were sent + client.manual_poll(now + config.idle_connection_timeout); + server.manual_poll(now + config.idle_connection_timeout); + + // Assert that no disconnection events occurred + assert_eq!(client.recv(), None); + assert_eq!(server.recv(), None); + } fn create_test_packet(id: u8, addr: &str) -> Packet { let payload = vec![id]; @@ -801,6 +1017,9 @@ mod tests { #[test] fn multiple_sends_should_start_sending_dropped() { + const LOCAL_ADDR: &str = "127.0.0.1:13000"; + const REMOTE_ADDR: &str = "127.0.0.1:14000"; + // Start up a server and a client. let mut server = Socket::bind(REMOTE_ADDR.parse::().unwrap()).unwrap(); let mut client = Socket::bind(LOCAL_ADDR.parse::().unwrap()).unwrap(); @@ -956,4 +1175,49 @@ mod tests { Socket::bind(format!("127.0.0.1:{}", port).parse::().unwrap()).unwrap(); assert_eq!(port, socket.local_addr().unwrap().port()); } + + #[test] + fn ordered_16_bit_overflow() { + let mut cfg = Config::default(); + + let mut client = Socket::bind_any_with_config(cfg.clone()).unwrap(); + let client_addr = client.local_addr().unwrap(); + + cfg.blocking_mode = false; + let mut server = Socket::bind_any_with_config(cfg).unwrap(); + let server_addr = server.local_addr().unwrap(); + + let time = Instant::now(); + + let mut last_payload = String::new(); + + for idx in 0..100_000u64 { + client + .send(Packet::reliable_ordered( + server_addr, + idx.to_string().as_bytes().to_vec(), + None, + )) + .unwrap(); + + client.manual_poll(time); + + while let Some(_) = client.recv() {} + server + .send(Packet::reliable_ordered(client_addr, vec![123], None)) + .unwrap(); + server.manual_poll(time); + + while let Some(msg) = server.recv() { + match msg { + SocketEvent::Packet(pkt) => { + last_payload = std::str::from_utf8(pkt.payload()).unwrap().to_string(); + } + _ => {} + } + } + } + + assert_eq!["99999", last_payload]; + } } diff --git a/src/net/virtual_connection.rs b/src/net/virtual_connection.rs index 3c323117..a56f5b21 100644 --- a/src/net/virtual_connection.rs +++ b/src/net/virtual_connection.rs @@ -10,8 +10,8 @@ use crate::{ STANDARD_HEADER_SIZE, }, packet::{ - DeliveryGuarantee, OrderingGuarantee, Outgoing, OutgoingPacketBuilder, Packet, - PacketReader, PacketType, SequenceNumber, + DeliveryGuarantee, OrderingGuarantee, Outgoing, OutgoingPacket, OutgoingPacketBuilder, + Packet, PacketReader, PacketType, SequenceNumber, }, SocketEvent, }; @@ -26,6 +26,8 @@ use std::time::{Duration, Instant}; pub struct VirtualConnection { /// Last time we received a packet from this client pub last_heard: Instant, + /// Last time we sent a packet to this client + pub last_sent: Instant, /// The address of the remote endpoint pub remote_address: SocketAddr, @@ -43,6 +45,7 @@ impl VirtualConnection { pub fn new(addr: SocketAddr, config: &Config, time: Instant) -> VirtualConnection { VirtualConnection { last_heard: time, + last_sent: time, remote_address: addr, ordering_system: OrderingSystem::new(), sequencing_system: SequencingSystem::new(), @@ -53,6 +56,11 @@ impl VirtualConnection { } } + /// Determine if this connection should be dropped due to its state + pub fn should_be_dropped(&self) -> bool { + self.acknowledge_handler.packets_in_flight() > self.config.max_packets_in_flight + } + /// Returns a [Duration] representing the interval since we last heard from the client pub fn last_heard(&self, time: Instant) -> Duration { // TODO: Replace with saturating_duration_since once it becomes stable. @@ -60,6 +68,28 @@ impl VirtualConnection { time.duration_since(self.last_heard) } + /// Returns a [Duration] representing the interval since we last sent to the client + pub fn last_sent(&self, time: Instant) -> Duration { + // TODO: Replace with saturating_duration_since once it becomes stable. + // This function panics if the user supplies a time instant earlier than last_heard. + time.duration_since(self.last_sent) + } + + /// This will create a heartbeat packet that is expected to be sent over the network + pub fn create_and_process_heartbeat(&mut self, time: Instant) -> OutgoingPacket<'static> { + self.last_sent = time; + self.congestion_handler + .process_outgoing(self.acknowledge_handler.local_sequence_num(), time); + + OutgoingPacketBuilder::new(&[]) + .with_default_header( + PacketType::Heartbeat, + DeliveryGuarantee::Unreliable, + OrderingGuarantee::None, + ) + .build() + } + /// This will pre-process the given buffer to be sent over the network. pub fn process_outgoing<'a>( &mut self, @@ -114,17 +144,16 @@ impl VirtualConnection { ); if let OrderingGuarantee::Ordered(stream_id) = ordering_guarantee { - let item_identifier = if let Some(item_identifier) = - last_item_identifier - { - item_identifier - } else { - self.ordering_system - .get_or_create_stream( - stream_id.unwrap_or(DEFAULT_ORDERING_STREAM), - ) - .new_item_identifier() as u16 - }; + let item_identifier = + if let Some(item_identifier) = last_item_identifier { + item_identifier + } else { + self.ordering_system + .get_or_create_stream( + stream_id.unwrap_or(DEFAULT_ORDERING_STREAM), + ) + .new_item_identifier() + }; item_identifier_value = Some(item_identifier); @@ -132,17 +161,16 @@ impl VirtualConnection { }; if let OrderingGuarantee::Sequenced(stream_id) = ordering_guarantee { - let item_identifier = if let Some(item_identifier) = - last_item_identifier - { - item_identifier - } else { - self.sequencing_system - .get_or_create_stream( - stream_id.unwrap_or(DEFAULT_SEQUENCING_STREAM), - ) - .new_item_identifier() as u16 - }; + let item_identifier = + if let Some(item_identifier) = last_item_identifier { + item_identifier + } else { + self.sequencing_system + .get_or_create_stream( + stream_id.unwrap_or(DEFAULT_SEQUENCING_STREAM), + ) + .new_item_identifier() + }; item_identifier_value = Some(item_identifier); @@ -190,6 +218,7 @@ impl VirtualConnection { } }; + self.last_sent = time; self.congestion_handler .process_outgoing(self.acknowledge_handler.local_sequence_num(), time); self.acknowledge_handler.process_outgoing( @@ -220,6 +249,12 @@ impl VirtualConnection { return Err(ErrorKind::ProtocolVersionMismatch); } + if header.is_heartbeat() { + // Heartbeat packets are unreliable, unordered and empty packets. + // We already updated our `self.last_heard` time, nothing else to be done. + return Ok(()); + } + match header.delivery_guarantee() { DeliveryGuarantee::Unreliable => { if let OrderingGuarantee::Sequenced(_id) = header.ordering_guarantee() { @@ -232,9 +267,7 @@ impl VirtualConnection { .sequencing_system .get_or_create_stream(arranging_header.stream_id()); - if let Some(packet) = - stream.arrange(arranging_header.arranging_id() as usize, payload) - { + if let Some(packet) = stream.arrange(arranging_header.arranging_id(), payload) { Self::queue_packet( sender, packet, @@ -302,7 +335,7 @@ impl VirtualConnection { .get_or_create_stream(arranging_header.stream_id()); if let Some(packet) = - stream.arrange(arranging_header.arranging_id() as usize, payload) + stream.arrange(arranging_header.arranging_id(), payload) { Self::queue_packet( sender, @@ -324,7 +357,7 @@ impl VirtualConnection { .get_or_create_stream(arranging_header.stream_id()); if let Some(packet) = - stream.arrange(arranging_header.arranging_id() as usize, payload) + stream.arrange(arranging_header.arranging_id(), payload) { Self::queue_packet( sender, @@ -648,7 +681,7 @@ mod tests { PAYLOAD.to_vec(), Some(1), ))), - 1, + 0, ); assert_incoming_with_order( @@ -656,7 +689,7 @@ mod tests { OrderingGuarantee::Ordered(Some(1)), &mut connection, Err(TryRecvError::Empty), - 3, + 2, ); assert_incoming_with_order( @@ -664,7 +697,7 @@ mod tests { OrderingGuarantee::Ordered(Some(1)), &mut connection, Err(TryRecvError::Empty), - 4, + 3, ); assert_incoming_with_order( @@ -676,7 +709,7 @@ mod tests { PAYLOAD.to_vec(), Some(1), ))), - 2, + 1, ); } @@ -720,7 +753,7 @@ mod tests { PAYLOAD.to_vec(), Some(1), ))), - 1, + 0, ); } diff --git a/src/packet/enums.rs b/src/packet/enums.rs index 18e23c1b..454f3634 100644 --- a/src/packet/enums.rs +++ b/src/packet/enums.rs @@ -88,6 +88,8 @@ pub enum PacketType { Packet = 0, /// Fragment of a full packet Fragment = 1, + /// Heartbeat packet + Heartbeat = 2, } impl EnumConverter for PacketType { @@ -104,6 +106,7 @@ impl TryFrom for PacketType { match value { 0 => Ok(PacketType::Packet), 1 => Ok(PacketType::Fragment), + 2 => Ok(PacketType::Heartbeat), _ => Err(ErrorKind::DecodingError(DecodingErrorKind::PacketType)), } } @@ -152,9 +155,10 @@ mod tests { } #[test] - fn assure_parsing_packet_id() { + fn assure_parsing_packet_type() { let packet = PacketType::Packet; let fragment = PacketType::Fragment; + let heartbeat = PacketType::Heartbeat; assert_eq!( PacketType::Packet, PacketType::try_from(packet.to_u8()).unwrap() @@ -163,5 +167,9 @@ mod tests { PacketType::Fragment, PacketType::try_from(fragment.to_u8()).unwrap() ); + assert_eq!( + PacketType::Heartbeat, + PacketType::try_from(heartbeat.to_u8()).unwrap() + ); } } diff --git a/src/packet/header/acked_packet_header.rs b/src/packet/header/acked_packet_header.rs index 23756838..a1994bd0 100644 --- a/src/packet/header/acked_packet_header.rs +++ b/src/packet/header/acked_packet_header.rs @@ -85,7 +85,7 @@ mod tests { fn serialize() { let mut buffer = Vec::new(); let header = AckedPacketHeader::new(1, 2, 3); - header.parse(&mut buffer).is_ok(); + assert![header.parse(&mut buffer).is_ok()]; assert_eq!(buffer[1], 1); assert_eq!(buffer[3], 2); diff --git a/src/packet/header/arranging_header.rs b/src/packet/header/arranging_header.rs index 287cd911..f9ac9c84 100644 --- a/src/packet/header/arranging_header.rs +++ b/src/packet/header/arranging_header.rs @@ -74,7 +74,7 @@ mod tests { fn serialize() { let mut buffer = Vec::new(); let header = ArrangingHeader::new(1, 2); - header.parse(&mut buffer).is_ok(); + assert![header.parse(&mut buffer).is_ok()]; assert_eq!(buffer[1], 1); assert_eq!(buffer[2], 2); diff --git a/src/packet/header/fragment_header.rs b/src/packet/header/fragment_header.rs index 233f402c..984741b3 100644 --- a/src/packet/header/fragment_header.rs +++ b/src/packet/header/fragment_header.rs @@ -83,7 +83,7 @@ mod tests { fn serialize() { let mut buffer = Vec::new(); let header = FragmentHeader::new(1, 2, 3); - header.parse(&mut buffer).is_ok(); + assert![header.parse(&mut buffer).is_ok()]; assert_eq!(buffer[1], 1); assert_eq!(buffer[2], 2); diff --git a/src/packet/header/standard_header.rs b/src/packet/header/standard_header.rs index 9619815a..4eccaafe 100644 --- a/src/packet/header/standard_header.rs +++ b/src/packet/header/standard_header.rs @@ -17,7 +17,7 @@ pub struct StandardHeader { } impl StandardHeader { - /// Create new heartbeat header. + /// Create new header. pub fn new( delivery_guarantee: DeliveryGuarantee, ordering_guarantee: OrderingGuarantee, @@ -53,6 +53,11 @@ impl StandardHeader { self.packet_type } + /// Returns true if the packet is a heartbeat packet, false otherwise + pub fn is_heartbeat(&self) -> bool { + self.packet_type == PacketType::Heartbeat + } + /// Returns true if the packet is a fragment, false if not pub fn is_fragment(&self) -> bool { self.packet_type == PacketType::Fragment @@ -126,7 +131,7 @@ mod tests { OrderingGuarantee::Sequenced(None), PacketType::Packet, ); - header.parse(&mut buffer).is_ok(); + assert![header.parse(&mut buffer).is_ok()]; // [0 .. 3] protocol version assert_eq!(buffer[2], PacketType::Packet.to_u8());