Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 104 additions & 14 deletions crates/libtortillas/src/peer/actor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
collections::HashMap,
collections::{HashMap, VecDeque},
sync::{Arc, atomic::AtomicU8},
time::Instant,
};
Expand All @@ -26,6 +26,8 @@ use crate::{
torrent::{TorrentActor, TorrentMessage, TorrentRequest, TorrentResponse},
};

const MAX_PENDING_MESSAGES: usize = 8;

const PEER_KEEPALIVE_TIMEOUT: u64 = 10;
const PEER_DISCONNECT_TIMEOUT: u64 = 20;

Expand All @@ -39,6 +41,7 @@ pub(crate) struct PeerActor {
supervisor: ActorRef<TorrentActor>,

pending_block_requests: Arc<DashSet<(usize, usize, usize)>>,
pending_message_requests: VecDeque<PeerMessages>,
}

impl PeerActor {
Expand Down Expand Up @@ -168,7 +171,7 @@ impl PeerActor {
None,
);

if let Err(e) = self.stream.send(message).await {
if let Err(e) = self.send_message(message).await {
trace!(error = %e, piece, "Failed to send metadata request");
}
} else {
Expand Down Expand Up @@ -223,6 +226,87 @@ impl PeerActor {

self.peer.set_am_interested(has_interesting_pieces);
}

/// Sends all queued messages to the peer. This sends synchronously, and will
/// not return until each message has been sent. This is because most of
/// the time we want the messages to be sent in their original order.
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
async fn flush_queue(&mut self) {
if self.pending_message_requests.is_empty() {
return;
}

let queued_messages = self.pending_message_requests.len();

while let Some(msg) = self.pending_message_requests.pop_back() {
self
.stream
.send(msg)
.await
.expect("Failed to send message to peer");
}

trace!(amount = queued_messages, "Flushed queued messages to peer");
}

/// Flushes/resends all pending block requests to the peer.
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
async fn flush_block_requests(&mut self) {
if self.pending_block_requests.is_empty() {
return;
}

let queued_block_requests = self.pending_block_requests.len();
let mut completed = 0usize;

for request in self.pending_block_requests.iter() {
let (index, begin, length) = *request;
if self
.stream
.send(PeerMessages::Request(
index as u32,
begin as u32,
length as u32,
))
.await
.is_ok()
{
completed += 1;
}
}
trace!(
amount = queued_block_requests,
amount_succussful = completed,
"Flushed queued block requests to peer"
);
}

/// Send a message to the peer. Checks if the peer is choked, and if so,
/// queues the message in [`self.pending_message_requests`]. This function
/// will NOT queue request messages since they have their own queue of
/// sorts.
///
/// Unless you're doing something like a `KeepAlive` message or a piece
/// request, you should use this function over [`Self::stream.send`].
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
async fn send_message(&mut self, msg: PeerMessages) -> Result<(), PeerActorError> {
if self.peer.am_choked() {
// Only push the message if it's not a request
if matches!(msg, PeerMessages::Request(..)) {
return Ok(());
}
if self.pending_message_requests.len() >= MAX_PENDING_MESSAGES {
self.pending_message_requests.pop_back();
}

self.pending_message_requests.push_front(msg);
trace!("Peer is choked, queueing message");

return Ok(());
}

self.stream.send(msg).await
}
}

impl Actor for PeerActor {
Expand Down Expand Up @@ -251,6 +335,7 @@ impl Actor for PeerActor {
stream,
supervisor,
pending_block_requests: Arc::new(DashSet::new()),
pending_message_requests: VecDeque::with_capacity(MAX_PENDING_MESSAGES),
})
}

Expand Down Expand Up @@ -336,6 +421,10 @@ impl Message<PeerMessages> for PeerActor {
PeerMessages::Unchoke => {
self.peer.update_last_optimistic_unchoke();
self.peer.set_am_choked(false);

// Send all pending messages
self.flush_queue().await;
self.flush_block_requests().await;
trace!("Peer unchoked us");
}
PeerMessages::Interested => {
Expand Down Expand Up @@ -450,15 +539,17 @@ impl Message<PeerTell> for PeerActor {
return;
}

self
.stream
.send(PeerMessages::Request(
index as u32,
begin as u32,
length as u32,
))
.await
.expect("Failed to send piece request");
if !self.peer.am_choked() {
self
.stream
.send(PeerMessages::Request(
index as u32,
begin as u32,
length as u32,
))
.await
.expect("Failed to send piece request");
}
self.pending_block_requests.insert((index, begin, length));
trace!(piece_index = index, "Sent piece request to peer");
}
Expand All @@ -485,14 +576,13 @@ impl Message<PeerTell> for PeerActor {
}
PeerTell::HaveInfoDict(bitfield) => {
self
.stream
.send(PeerMessages::Bitfield(bitfield))
.send_message(PeerMessages::Bitfield(bitfield))
.await
.expect("Failed to send bitfield");
trace!("Sent bitfield to peer");
}
PeerTell::Have(piece) => {
if let Err(e) = self.stream.send(PeerMessages::Have(piece as u32)).await {
if let Err(e) = self.send_message(PeerMessages::Have(piece as u32)).await {
trace!(piece_num = piece, error = %e, "Failed to send Have message to peer");
}
}
Expand Down
32 changes: 29 additions & 3 deletions crates/libtortillas/src/protocol/messages.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::hash;
use std::{
collections::HashMap,
fmt::Display,
Expand All @@ -23,7 +24,7 @@ use crate::{
peer::{MAGIC_STRING, PeerId},
};

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
/// Represents messages exchanged between peers in the BitTorrent protocol.
///
Expand Down Expand Up @@ -366,7 +367,8 @@ impl PeerMessages {
PartialEq,
Eq,
Deserialize_repr,
TryFromPrimitive
TryFromPrimitive,
Hash
)]
#[repr(u8)]
pub enum ExtendedMessageType {
Expand Down Expand Up @@ -456,6 +458,30 @@ pub struct ExtendedMessage {
pub total_size: Option<usize>,
}

impl hash::Hash for ExtendedMessage {
fn hash<H: hash::Hasher>(&self, state: &mut H) {
if let Some(extensions) = &self.supported_extensions {
let mut pairs: Vec<_> = extensions.iter().collect();
pairs.sort_by_key(|i| i.0);

pairs.hash(state);
}

self.local_port.hash(state);
self.version.hash(state);
self.your_ip.hash(state);
self.ipv6.hash(state);
self.ipv4.hash(state);
self.outstanding_requests.hash(state);
self.metadata_size.hash(state);
if let Some(msg_type) = &self.msg_type {
msg_type.hash(state);
}
self.piece.hash(state);
self.total_size.hash(state);
}
}

impl ExtendedMessage {
pub fn new() -> Self {
Self::default()
Expand Down Expand Up @@ -502,7 +528,7 @@ impl ExtendedMessage {
}

/// BitTorrent Handshake message structure
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Handshake {
/// Protocol identifier (typically "BitTorrent protocol")
pub protocol: Bytes,
Expand Down