diff --git a/examples/protocol_id.rs b/examples/protocol_id.rs new file mode 100644 index 0000000..5e19f2f --- /dev/null +++ b/examples/protocol_id.rs @@ -0,0 +1,47 @@ +//! Example demonstrating network isolation using protocol IDs. +//! +//! This example shows how to create isolated DHT networks that don't interfere +//! with each other using the protocol ID. +//! +//! Run with: +//! ```bash +//! cargo run --example protocol_id +//! ``` + +use mainline::{Dht, DEFAULT_STAGING_PROTOCOL_ID}; + +fn main() -> Result<(), std::io::Error> { + println!("Protocol ID Example\n"); + + // Example 1: Default behavior - participates in main BitTorrent DHT + println!("1. Creating DHT node without protocol ID (default behavior):"); + println!(" This node will communicate with the main BitTorrent DHT network"); + let _default_dht = Dht::client()?; + println!(" ✓ Created\n"); + + // Example 2: Using a custom protocol ID for an isolated network + println!("2. Creating DHT node with custom protocol ID:"); + println!(" Protocol ID: /myapp/mainline/1.0.0"); + println!(" This node will only communicate with other nodes using the same protocol ID"); + let _custom_dht = Dht::builder() + .protocol_id("/myapp/mainline/1.0.0") + .build()?; + println!(" ✓ Created\n"); + + // Example 3: Using the default staging protocol ID constant + println!("3. 
Creating DHT node using DEFAULT_STAGING_PROTOCOL_ID:"); + println!(" Protocol ID: {}", DEFAULT_STAGING_PROTOCOL_ID); + println!(" Useful for creating isolated test networks"); + let _staging_dht = Dht::builder() + .protocol_id(DEFAULT_STAGING_PROTOCOL_ID) + .build()?; + println!(" ✓ Created\n"); + + println!("Network Isolation Rules:"); + println!("• Nodes with the same protocol ID can communicate"); + println!("• Nodes with different protocol IDs ignore each other's messages"); + println!("• Nodes without a protocol ID (None) accept all messages (backward compatible)"); + println!("• Nodes with a protocol ID ONLY accept messages with matching protocol ID"); + + Ok(()) +} diff --git a/src/common/messages.rs b/src/common/messages.rs index 4ac2259..9a29972 100644 --- a/src/common/messages.rs +++ b/src/common/messages.rs @@ -29,6 +29,9 @@ pub(crate) struct Message { /// For bep0043. When set true on a request, indicates that the requester can't reply to requests and that responders should not add requester to their routing tables. /// Should only be set on requests - undefined behavior when set on a response. 
pub read_only: bool, + + /// Optional protocol ID for network isolation + pub protocol_id: Option>, } #[derive(Debug, PartialEq, Clone)] @@ -220,6 +223,7 @@ impl Message { .requester_ip .map(|sockaddr| sockaddr_to_bytes(&sockaddr)), read_only: if self.read_only { Some(1) } else { Some(0) }, + protocol_id: self.protocol_id, variant: match self.message_type { MessageType::Request(RequestSpecific { requester_id, @@ -412,6 +416,7 @@ impl Message { } else { false }, + protocol_id: msg.protocol_id, message_type: match msg.variant { internal::DHTMessageVariant::Request(req_variant) => { MessageType::Request(match req_variant { @@ -664,6 +669,10 @@ impl Message { _ => None, } } + + pub fn protocol_id(&self) -> Option<&[u8]> { + self.protocol_id.as_deref() + } } fn bytes_to_sockaddr>(bytes: T) -> Result { @@ -781,6 +790,7 @@ mod tests { version: None, requester_ip: None, read_only: false, + protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::Ping, @@ -801,6 +811,7 @@ mod tests { version: Some([0xde, 0xad, 0, 1]), requester_ip: Some("99.100.101.102:1030".parse().unwrap()), read_only: false, + protocol_id: None, message_type: MessageType::Response(ResponseSpecific::Ping(PingResponseArguments { responder_id: Id::random(), })), @@ -820,6 +831,7 @@ mod tests { version: Some([0x62, 0x61, 0x72, 0x66]), requester_ip: None, read_only: false, + protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments { @@ -842,6 +854,7 @@ mod tests { version: Some([0x62, 0x61, 0x72, 0x66]), requester_ip: None, read_only: true, + protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments { @@ -864,6 +877,7 @@ mod tests { version: Some([1, 2, 3, 4]), requester_ip: 
Some("50.51.52.53:5455".parse().unwrap()), read_only: false, + protocol_id: None, message_type: MessageType::Response(ResponseSpecific::FindNode( FindNodeResponseArguments { responder_id: Id::random(), @@ -897,6 +911,7 @@ mod tests { version: Some([72, 73, 0, 1]), requester_ip: None, read_only: false, + protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::GetPeers(GetPeersRequestArguments { @@ -919,6 +934,7 @@ mod tests { version: Some([1, 2, 3, 4]), requester_ip: Some("50.51.52.53:5455".parse().unwrap()), read_only: true, + protocol_id: None, message_type: MessageType::Response(ResponseSpecific::NoValues( NoValuesResponseArguments { responder_id: Id::random(), @@ -959,6 +975,7 @@ mod tests { version: Some([1, 2, 3, 4]), requester_ip: Some("50.51.52.53:5455".parse().unwrap()), read_only: false, + protocol_id: None, message_type: MessageType::Response(ResponseSpecific::GetPeers( GetPeersResponseArguments { responder_id: Id::random(), @@ -983,6 +1000,7 @@ mod tests { read_only: None, transaction_id: [1, 2, 3, 4], version: None, + protocol_id: None, variant: internal::DHTMessageVariant::Response( internal::DHTResponseSpecific::NoValues { arguments: internal::DHTNoValuesResponseArguments { @@ -1007,6 +1025,7 @@ mod tests { version: Some([72, 73, 0, 1]), requester_ip: None, read_only: false, + protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::GetValue(GetValueRequestArguments { @@ -1031,6 +1050,7 @@ mod tests { version: Some([1, 2, 3, 4]), requester_ip: Some("50.51.52.53:5455".parse().unwrap()), read_only: false, + protocol_id: None, message_type: MessageType::Response(ResponseSpecific::GetImmutable( GetImmutableResponseArguments { responder_id: Id::random(), @@ -1055,6 +1075,7 @@ mod tests { version: Some([1, 2, 3, 4]), requester_ip: Some("50.51.52.53:5455".parse().unwrap()), read_only: false, + 
protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::Put(PutRequest { @@ -1083,6 +1104,7 @@ mod tests { version: Some([1, 2, 3, 4]), requester_ip: Some("50.51.52.53:5455".parse().unwrap()), read_only: false, + protocol_id: None, message_type: MessageType::Request(RequestSpecific { requester_id: Id::random(), request_type: RequestTypeSpecific::Put(PutRequest { @@ -1106,4 +1128,25 @@ mod tests { let parsed_msg = Message::from_serde_message(parsed_serde_msg).unwrap(); assert_eq!(parsed_msg, original_msg); } + + #[test] + fn test_protocol_id_request() { + // Round-trips a request that carries a protocol ID and checks that the ID survives. + // (Old nodes without the field are expected to skip it as an unknown key; not exercised here.) + let protocol_id = b"/pubky/mainline/1.0.0"; + let msg_with_protocol = Message { + transaction_id: 258, + version: Some([0x62, 0x61, 0x72, 0x66]), + requester_ip: None, + read_only: false, + protocol_id: Some(protocol_id.to_vec().into_boxed_slice()), + message_type: MessageType::Request(RequestSpecific { + requester_id: Id::random(), + request_type: RequestTypeSpecific::Ping, + }), + }; + let bytes = msg_with_protocol.to_bytes().unwrap(); + let parsed = Message::from_bytes(&bytes).unwrap(); + assert_eq!(parsed.protocol_id(), Some(protocol_id.as_ref())); + } } diff --git a/src/common/messages/internal.rs b/src/common/messages/internal.rs index 3674678..aff8c76 100644 --- a/src/common/messages/internal.rs +++ b/src/common/messages/internal.rs @@ -23,6 +23,11 @@ pub struct DHTMessage { #[serde(default)] #[serde(rename = "ro")] pub read_only: Option, + + #[serde(default)] + #[serde(rename = "p", with = "serde_bytes")] + /// protocol ID for network isolation + pub protocol_id: Option>, } impl DHTMessage { diff --git a/src/dht.rs b/src/dht.rs index 1c225c7..af72271 100644 --- a/src/dht.rs +++ b/src/dht.rs @@ -109,6 +109,20 @@ impl DhtBuilder { self } + /// Set the protocol ID for network isolation + /// + /// 
When set, this node will only communicate with other nodes using the same protocol ID. + /// Messages from nodes with different or no protocol IDs will be rejected. + /// + /// Format: "/prefix/mainline/version" (e.g., "/pubky/mainline/1.0.0") + /// + /// When None (default), accepts all messages for backward compatibility. + pub fn protocol_id(&mut self, protocol_id: impl Into) -> &mut Self { + self.0.protocol_id = Some(protocol_id.into()); + + self + } + /// Create a Dht node. pub fn build(&self) -> Result { Dht::new(self.0.clone()) @@ -1103,4 +1117,131 @@ mod test { .iter() .all(|n| n.to_bootstrap().len() == size - 1)); } + + #[test] + fn protocol_id_isolation_different_networks_cannot_communicate() { + let testnet = Testnet::new(5).unwrap(); + + // Create node A with protocol ID "/network_a/mainline/1.0.0" + let node_a = Dht::builder() + .protocol_id("/network_a/mainline/1.0.0") + .bootstrap(&testnet.bootstrap) + .build() + .unwrap(); + + // Create node B with different protocol ID "/network_b/mainline/1.0.0" + let node_b = Dht::builder() + .protocol_id("/network_b/mainline/1.0.0") + .bootstrap(&testnet.bootstrap) + .build() + .unwrap(); + + // Wait for nodes to attempt bootstrapping + std::thread::sleep(std::time::Duration::from_millis(500)); + + // Node A puts immutable data + let value = b"Hello from Network A"; + let target = node_a.put_immutable(value).unwrap(); + + // Node B (on different network) should NOT be able to get the data + // Because they have different protocol IDs, B's requests are ignored by A's network + let result = node_b.get_immutable(target); + assert!( + result.is_none(), + "Node B should not be able to retrieve data from Network A" + ); + } + + #[test] + fn protocol_id_isolation_same_network_can_communicate() { + let mut nodes: Vec = vec![]; + let mut bootstrap = vec![]; + + // Create first node with protocol ID + let node = Dht::builder() + .protocol_id("/test_network/mainline/1.0.0") + .server_mode() + .no_bootstrap() + .build() + 
.unwrap(); + + let info = node.info(); + bootstrap.push(format!("127.0.0.1:{}", info.local_addr().port())); + nodes.push(node); + + // Create more nodes with SAME protocol ID + for _ in 1..5 { + let node = Dht::builder() + .protocol_id("/test_network/mainline/1.0.0") + .server_mode() + .bootstrap(&bootstrap) + .build() + .unwrap(); + nodes.push(node); + } + + for node in &nodes { + node.bootstrapped(); + } + + // Create two client nodes with the SAME protocol ID + let node_a = Dht::builder() + .protocol_id("/test_network/mainline/1.0.0") + .bootstrap(&bootstrap) + .build() + .unwrap(); + let node_b = Dht::builder() + .protocol_id("/test_network/mainline/1.0.0") + .bootstrap(&bootstrap) + .build() + .unwrap(); + + // Node A puts immutable data + let value = b"Hello from same network"; + let target = node_a.put_immutable(value).unwrap(); + + // Node B (on same network) SHOULD be able to get the data + let result = node_b.get_immutable(target); + assert!( + result.is_some(), + "Node B should be able to retrieve data from Node A (same network)" + ); + assert_eq!(result.unwrap().as_ref(), value); + } + + #[test] + fn protocol_id_node_rejects_messages_without_protocol_id() { + let testnet = Testnet::new(5).unwrap(); + + // Create node with protocol ID + let node_with_protocol = Dht::builder() + .protocol_id("/custom_network/mainline/1.0.0") + .bootstrap(&testnet.bootstrap) + .build() + .unwrap(); + + // Create node WITHOUT protocol ID (default) + let node_without_protocol = Dht::builder() + .bootstrap(&testnet.bootstrap) + .build() + .unwrap(); + + // Wait for bootstrapping + std::thread::sleep(std::time::Duration::from_millis(500)); + + // Node without protocol ID puts data + let value = b"Hello from default network"; + let target = node_without_protocol.put_immutable(value).unwrap(); + + // Node with protocol ID will NOT be able to read from default network nodes + // because nodes with protocol IDs reject messages without protocol IDs + let result = 
node_with_protocol.get_immutable(target); + + // The node with protocol ID cannot retrieve data from nodes without protocol IDs + // because it rejects their responses (no protocol ID = rejected) + assert!( + result.is_none(), + "Node with protocol ID should not retrieve data from nodes without protocol ID" + ); + } } diff --git a/src/lib.rs b/src/lib.rs index 72335af..e543fdf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,9 @@ pub use rpc::{ pub use ed25519_dalek::SigningKey; +/// Default protocol ID for staging isolated network +pub const DEFAULT_STAGING_PROTOCOL_ID: &str = "/pubky_staging/mainline/1.0.0"; + pub mod errors { //! Exported errors #[cfg(feature = "node")] diff --git a/src/rpc/config.rs b/src/rpc/config.rs index a9c253e..8f84365 100644 --- a/src/rpc/config.rs +++ b/src/rpc/config.rs @@ -35,6 +35,13 @@ pub struct Config { /// /// Defaults to None, where we depend on suggestions from responding nodes. pub public_ip: Option, + /// Optional protocol ID for network isolation + /// + /// When Some(id) only accepts messages with matching protocol ID. + /// When None, accepts all messages. 
+ /// + /// Defaults to None + pub protocol_id: Option, } impl Default for Config { @@ -46,6 +53,7 @@ impl Default for Config { server_settings: Default::default(), server_mode: false, public_ip: None, + protocol_id: None, } } } diff --git a/src/rpc/socket.rs b/src/rpc/socket.rs index 39f3a83..20d0d2b 100644 --- a/src/rpc/socket.rs +++ b/src/rpc/socket.rs @@ -31,6 +31,7 @@ pub struct KrpcSocket { inflight_requests: InflightRequests, last_cleanup: Instant, local_addr: SocketAddrV4, + protocol_id: Option>, // poll_interval: Duration, } @@ -62,6 +63,7 @@ impl KrpcSocket { inflight_requests: InflightRequests::new(request_timeout), last_cleanup: Instant::now(), local_addr, + protocol_id: config.protocol_id.as_ref().map(|s| s.as_bytes().to_vec()), }) } @@ -154,6 +156,10 @@ impl KrpcSocket { match Message::from_bytes(bytes) { Ok(message) => { + if !self.validate_protocol_id(&message) { + return None; + } + let should_return = match message.message_type { MessageType::Request(_) => { trace!( @@ -231,6 +237,34 @@ impl KrpcSocket { false } + /// Returns true if message should be accepted: + /// - If we have no protocol_id set (None), accept all messages (backward compatible) + /// - If we have protocol_id set and message has matching protocol_id, accept + /// - Otherwise, reject + fn validate_protocol_id(&self, message: &Message) -> bool { + match &self.protocol_id { + None => true, // No validation - accept all (default) + Some(local_id) => { + match message.protocol_id() { + Some(msg_id) => { + if msg_id == local_id.as_slice() { + true + } else { + trace!( + context = "socket_validation", + local_protocol = ?String::from_utf8_lossy(local_id), + message_protocol = ?String::from_utf8_lossy(msg_id), + "Protocol ID mismatch" + ); + false + } + } + None => false, // Message carries no protocol ID while we require one: reject + } + } + } + } + /// Increments self.next_tid and returns the previous value. 
fn tid(&mut self) -> u32 { // We don't bother much with reusing freed transaction ids, @@ -251,6 +285,10 @@ impl KrpcSocket { version: Some(VERSION), read_only: !self.server_mode, requester_ip: None, + protocol_id: self + .protocol_id + .as_ref() + .map(|v| v.clone().into_boxed_slice()), } } @@ -268,6 +306,10 @@ impl KrpcSocket { read_only: !self.server_mode, // BEP_0042 Only relevant in responses. requester_ip: Some(requester_ip), + protocol_id: self + .protocol_id + .as_ref() + .map(|v| v.clone().into_boxed_slice()), } }