diff --git a/crates/libtortillas/src/peer/actor.rs b/crates/libtortillas/src/peer/actor.rs index 232cd86..9d8dc2b 100644 --- a/crates/libtortillas/src/peer/actor.rs +++ b/crates/libtortillas/src/peer/actor.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, sync::{Arc, atomic::AtomicU8}, time::Instant, }; @@ -26,6 +26,8 @@ use crate::{ torrent::{TorrentActor, TorrentMessage, TorrentRequest, TorrentResponse}, }; +const MAX_PENDING_MESSAGES: usize = 8; + const PEER_KEEPALIVE_TIMEOUT: u64 = 10; const PEER_DISCONNECT_TIMEOUT: u64 = 20; @@ -39,6 +41,7 @@ pub(crate) struct PeerActor { supervisor: ActorRef, pending_block_requests: Arc>, + pending_message_requests: VecDeque, } impl PeerActor { @@ -168,7 +171,7 @@ impl PeerActor { None, ); - if let Err(e) = self.stream.send(message).await { + if let Err(e) = self.send_message(message).await { trace!(error = %e, piece, "Failed to send metadata request"); } } else { @@ -223,6 +226,87 @@ impl PeerActor { self.peer.set_am_interested(has_interesting_pieces); } + + /// Sends all queued messages to the peer. This sends synchronously, and will + /// not return until each message has been sent. This is because most of + /// the time we want the messages to be sent in their original order. + #[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))] + async fn flush_queue(&mut self) { + if self.pending_message_requests.is_empty() { + return; + } + + let queued_messages = self.pending_message_requests.len(); + + while let Some(msg) = self.pending_message_requests.pop_back() { + self + .stream + .send(msg) + .await + .expect("Failed to send message to peer"); + } + + trace!(amount = queued_messages, "Flushed queued messages to peer"); + } + + /// Flushes/resends all pending block requests to the peer. + #[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))] + async fn flush_block_requests(&mut self) { + if self.pending_block_requests.is_empty() { + return; + } + + let queued_block_requests = self.pending_block_requests.len(); + let mut completed = 0usize; + + for request in self.pending_block_requests.iter() { + let (index, begin, length) = *request; + if self + .stream + .send(PeerMessages::Request( + index as u32, + begin as u32, + length as u32, + )) + .await + .is_ok() + { + completed += 1; + } + } + trace!( + amount = queued_block_requests, + amount_succussful = completed, + "Flushed queued block requests to peer" + ); + } + + /// Send a message to the peer. Checks if the peer is choked, and if so, + /// queues the message in [`self.pending_message_requests`]. This function + /// will NOT queue request messages since they have their own queue of + /// sorts. + /// + /// Unless you're doing something like a `KeepAlive` message or a piece + /// request, you should use this function over [`Self::stream.send`]. + #[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))] + async fn send_message(&mut self, msg: PeerMessages) -> Result<(), PeerActorError> { + if self.peer.am_choked() { + // Only push the message if it's not a request + if matches!(msg, PeerMessages::Request(..)) { + return Ok(()); + } + if self.pending_message_requests.len() >= MAX_PENDING_MESSAGES { + self.pending_message_requests.pop_back(); + } + + self.pending_message_requests.push_front(msg); + trace!("Peer is choked, queueing message"); + + return Ok(()); + } + + self.stream.send(msg).await + } } impl Actor for PeerActor { @@ -251,6 +335,7 @@ impl Actor for PeerActor { stream, supervisor, pending_block_requests: Arc::new(DashSet::new()), + pending_message_requests: VecDeque::with_capacity(MAX_PENDING_MESSAGES), }) } @@ -336,6 +421,10 @@ impl Message for PeerActor { PeerMessages::Unchoke => { self.peer.update_last_optimistic_unchoke(); self.peer.set_am_choked(false); + + // Send all pending messages + self.flush_queue().await; + self.flush_block_requests().await; trace!("Peer unchoked us"); } PeerMessages::Interested => { @@ -450,15 +539,17 @@ impl Message for PeerActor { return; } - self - .stream - .send(PeerMessages::Request( - index as u32, - begin as u32, - length as u32, - )) - .await - .expect("Failed to send piece request"); + if !self.peer.am_choked() { + self + .stream + .send(PeerMessages::Request( + index as u32, + begin as u32, + length as u32, + )) + .await + .expect("Failed to send piece request"); + } self.pending_block_requests.insert((index, begin, length)); trace!(piece_index = index, "Sent piece request to peer"); } @@ -485,14 +576,13 @@ impl Message for PeerActor { } PeerTell::HaveInfoDict(bitfield) => { self - .stream - .send(PeerMessages::Bitfield(bitfield)) + .send_message(PeerMessages::Bitfield(bitfield)) .await .expect("Failed to send bitfield"); trace!("Sent bitfield to peer"); } PeerTell::Have(piece) => { - if let Err(e) = self.stream.send(PeerMessages::Have(piece as u32)).await { + if let Err(e) = self.send_message(PeerMessages::Have(piece as u32)).await { trace!(piece_num = piece, error = %e, "Failed to send Have message to peer"); } } diff --git a/crates/libtortillas/src/protocol/messages.rs b/crates/libtortillas/src/protocol/messages.rs index a037c38..fdf9715 100644 --- a/crates/libtortillas/src/protocol/messages.rs +++ b/crates/libtortillas/src/protocol/messages.rs @@ -1,3 +1,4 @@ +use core::hash; use std::{ collections::HashMap, fmt::Display, @@ -23,7 +24,7 @@ use crate::{ peer::{MAGIC_STRING, PeerId}, }; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[repr(u8)] /// Represents messages exchanged between peers in the BitTorrent protocol. /// @@ -366,7 +367,8 @@ impl PeerMessages { PartialEq, Eq, Deserialize_repr, - TryFromPrimitive + TryFromPrimitive, + Hash )] #[repr(u8)] pub enum ExtendedMessageType { @@ -456,6 +458,30 @@ pub struct ExtendedMessage { pub total_size: Option, } +impl hash::Hash for ExtendedMessage { + fn hash(&self, state: &mut H) { + if let Some(extensions) = &self.supported_extensions { + let mut pairs: Vec<_> = extensions.iter().collect(); + pairs.sort_by_key(|i| i.0); + + pairs.hash(state); + } + + self.local_port.hash(state); + self.version.hash(state); + self.your_ip.hash(state); + self.ipv6.hash(state); + self.ipv4.hash(state); + self.outstanding_requests.hash(state); + self.metadata_size.hash(state); + if let Some(msg_type) = &self.msg_type { + msg_type.hash(state); + } + self.piece.hash(state); + self.total_size.hash(state); + } +} + impl ExtendedMessage { pub fn new() -> Self { Self::default() @@ -502,7 +528,7 @@ impl ExtendedMessage { } /// BitTorrent Handshake message structure -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Handshake { /// Protocol identifier (typically "BitTorrent protocol") pub protocol: Bytes,