From 74c30024cd600832c48834fa9a52c64ec393f531 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 12 May 2025 13:58:49 -0700 Subject: [PATCH 1/8] Add KTlsInnerStream --- Cargo.lock | 8 +- ktls/Cargo.toml | 2 +- ktls/src/client.rs | 24 ++ ktls/src/ffi.rs | 98 ++++- ktls/src/lib.rs | 4 + ktls/src/protocol.rs | 85 +++++ ktls/src/stream.rs | 888 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1099 insertions(+), 10 deletions(-) create mode 100644 ktls/src/client.rs create mode 100644 ktls/src/protocol.rs create mode 100644 ktls/src/stream.rs diff --git a/Cargo.lock b/Cargo.lock index f34db36..f0dc86c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -745,9 +745,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.25" +version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ee9188ac4ec04a2f0531e55d035fb2de73f18b41a63c70c2712503b6fb13c" +checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "aws-lc-rs", "once_cell", @@ -766,9 +766,9 @@ checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" [[package]] name = "rustls-webpki" -version = "0.103.1" +version = "0.103.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" +checksum = "7149975849f1abb3832b246010ef62ccc80d3a76169517ada7188252b9cfb437" dependencies = [ "aws-lc-rs", "ring", diff --git a/ktls/Cargo.toml b/ktls/Cargo.toml index d0656fa..f48187f 100644 --- a/ktls/Cargo.toml +++ b/ktls/Cargo.toml @@ -17,7 +17,7 @@ libc = { version = "0.2.155", features = ["const-extern-fn"] } thiserror = "2" tracing = "0.1.40" tokio-rustls = { default-features = false, version = "0.26.0" } -rustls = { version = "0.23.12", default-features = false } +rustls = { version = "0.23.27", default-features = false } smallvec = "1.13.2" memoffset = "0.9.1" pin-project-lite = "0.2.14" diff --git a/ktls/src/client.rs b/ktls/src/client.rs new file mode 100644 index 0000000..815fab7 --- /dev/null +++ b/ktls/src/client.rs @@ -0,0 +1,24 @@ +use std::sync::Arc; + + +pub struct KTlsConnector { + config: Arc, +} + +impl KTlsConnector { + pub fn new(config: Arc) -> Self { + Self { config } + } + + // pub async fn try_connect( + // &self, + // domain: ServerName<'static>, + // stream: IO + // ) -> Result< +} + +// pub struct KTlsClientStream + +// pub struct TryConnectError { +// connection: +// } diff --git a/ktls/src/ffi.rs b/ktls/src/ffi.rs index 77af19b..be36e9d 100644 --- a/ktls/src/ffi.rs +++ b/ktls/src/ffi.rs @@ -1,4 +1,4 @@ -use std::os::unix::prelude::RawFd; +use std::{ffi::c_void, io, os::unix::prelude::RawFd}; use ktls_sys::bindings as ktls; use rustls::{ @@ -243,15 +243,14 @@ const TLS_SET_RECORD_TYPE: libc::c_int = 1; const ALERT: u8 = 0x15; // Yes, really. cmsg components are aligned to [libc::c_long] -#[cfg_attr(target_pointer_width = "32", repr(C, align(4)))] -#[cfg_attr(target_pointer_width = "64", repr(C, align(8)))] -struct Cmsg { +pub(crate) struct Cmsg { + _align: [libc::c_ulong; 0], hdr: libc::cmsghdr, data: [u8; N], } impl Cmsg { - fn new(level: i32, typ: i32, data: [u8; N]) -> Self { + pub(crate) fn new(level: i32, typ: i32, data: [u8; N]) -> Self { Self { hdr: libc::cmsghdr { // on Linux this is a usize, on macOS this is a u32 @@ -261,8 +260,21 @@ impl Cmsg { cmsg_type: typ, }, data, + _align: [], } } + + pub(crate) fn level(&self) -> i32 { + self.hdr.cmsg_level + } + + pub(crate) fn typ(&self) -> i32 { + self.hdr.cmsg_type + } + + pub(crate) fn data(&self) -> &[u8] { + &self.data[..self.hdr.cmsg_len.min(N)] + } } pub fn send_close_notify(fd: RawFd) -> std::io::Result<()> { @@ -292,3 +304,79 @@ pub fn send_close_notify(fd: RawFd) -> std::io::Result<()> { } Ok(()) } + +/// A wrapper around [`libc::sendmsg`]. +pub(crate) fn sendmsg( + fd: RawFd, + data: &[io::IoSlice<'_>], + cmsg: Option<&Cmsg>, + flags: i32, +) -> io::Result { + let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; + + if let Some(cmsg) = cmsg { + msg.msg_control = cmsg as *const _ as *mut c_void; + msg.msg_controllen = std::mem::size_of_val(cmsg); + } + + msg.msg_iov = data.as_ptr() as *const _ as *mut libc::iovec; + msg.msg_iovlen = data.len(); + + let ret = unsafe { libc::sendmsg(fd, &msg, flags) }; + match ret { + -1 => Err(io::Error::last_os_error()), + len => Ok(len as usize), + } +} + +/// Use [`libc::recvmsg`] to receive a whole message (with optional control +/// message). +/// +/// This will repeatedly call `recvmsg` until it reaches the end of the current +/// record. +pub(crate) fn recvmsg_whole( + fd: RawFd, + data: &mut Vec, + mut cmsg: Option<&mut Cmsg>, + flags: i32, +) -> io::Result { + if data.capacity() < 16 { + data.reserve(16); + } + + loop { + let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; + if let Some(cmsg) = cmsg.as_deref_mut() { + msg.msg_control = cmsg as *mut _ as *mut c_void; + msg.msg_controllen = std::mem::size_of_val(cmsg); + } + + if data.spare_capacity_mut().is_empty() { + data.reserve(128); + } + + let spare = data.spare_capacity_mut(); + let mut iov = libc::iovec { + iov_base: spare.as_mut_ptr() as *mut c_void, + iov_len: spare.len(), + }; + + msg.msg_iov = &mut iov; + msg.msg_iovlen = 1; + + // SAFETY: We have made sure to initialize msg with valid pointers (or NULL). + let ret = unsafe { libc::recvmsg(fd, &mut msg, flags) }; + let count = match ret { + -1 => return Err(io::Error::last_os_error()), + len => len as usize, + }; + + // SAFETY: recvmsg has just written count to the bytes in the spare capacity of + // the vector. + unsafe { data.set_len(data.len() + count) }; + + if msg.msg_flags & libc::MSG_EOR != 0 { + break Ok(msg.msg_flags); + } + } +} diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index c428fe5..6b0aadb 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -36,6 +36,10 @@ pub use ktls_stream::KtlsStream; mod cork_stream; pub use cork_stream::CorkStream; +mod client; +mod stream; +mod protocol; + #[derive(Debug, Default)] pub struct CompatibleCiphers { pub tls12: CompatibleCiphersForVersion, diff --git a/ktls/src/protocol.rs b/ktls/src/protocol.rs new file mode 100644 index 0000000..833dcca --- /dev/null +++ b/ktls/src/protocol.rs @@ -0,0 +1,85 @@ +//! TLS protocol enums that are not publically exposed by rustls. + +#![allow(non_upper_case_globals)] + +use std::fmt; + +macro_rules! c_enum { + { + $( #[$attr:meta] )* + $vis:vis enum $name:ident: $repr:ty { + $( + $( #[$vattr:meta] )* + $variant:ident = $value:expr + ),* $(,)? + } + } => { + $( #[$attr] )* + #[repr(transparent)] + $vis struct $name(pub $repr); + + impl $name { + $( + $( #[$vattr] )* + pub const $variant: Self = Self($value); + )* + } + + impl fmt::Debug for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + $( const $variant: $repr = $name::$variant.0; )* + + let text = match self.0 { + $( $variant => concat!(stringify!($name), "::", stringify!($variant)), )* + _ => return f.debug_tuple(stringify!($name)).field(&self.0).finish() + }; + + f.write_str(text) + } + } + + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + $( const $variant: $repr = $name::$variant.0; )* + + let text = match self.0 { + $( $variant => stringify!($variant), )* + _ => return <$repr as fmt::Display>::fmt(&self.0, f) + }; + + f.write_str(text) + } + } + + impl From<$repr> for $name { + fn from(value: $repr) -> Self { + Self(value) + } + } + + impl From<$name> for $repr { + fn from(value: $name) -> Self { + value.0 + } + } + } +} + + +c_enum! { + #[derive(Copy, Clone, Eq, PartialEq)] + pub(crate) enum AlertLevel: u8 { + Warning = 1, + Fatal = 2, + } +} + +c_enum! { + #[derive(Copy, Clone, Eq, PartialEq)] + pub(crate) enum KeyUpdateRequest: u8 { + UpdateNotRequested = 0, + UpdateRequested = 1 + } +} + + diff --git a/ktls/src/stream.rs b/ktls/src/stream.rs new file mode 100644 index 0000000..6fb9d31 --- /dev/null +++ b/ktls/src/stream.rs @@ -0,0 +1,888 @@ +use rustls::{ + AlertDescription, ConnectionTrafficSecrets, ContentType, HandshakeType, InvalidMessage, + PeerMisbehaved, ProtocolVersion, SupportedCipherSuite, +}; +use std::io; +use std::ops::{Deref, DerefMut}; +use std::os::unix::io::AsRawFd; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use rustls::kernel::KernelConnection; + +use crate::ffi::{setup_tls_info, Cmsg, Direction}; +use crate::protocol::{AlertLevel, KeyUpdateRequest}; +use crate::CryptoInfo; + +type KernelClientConnection = KernelConnection; +type KernelServerConnection = KernelConnection; + +struct ConnectionData { + rx_messages_since_last_key_update: u64, + tx_messages_since_last_key_update: u64, + + awaiting_key_update: bool, + + confidentiality_limit: u64, + + conn: Conn, +} + +type ClientConnectionData = ConnectionData; +type ServerConnectionData = ConnectionData; + +enum Side { + Client(Client), + Server(Server), +} + +enum BufferedData { + EarlyData(OffsetVec), + Scratch(Vec), +} + +#[derive(Default)] +struct KTlsStreamState { + write_closed: bool, + read_closed: bool, +} + +trait KTlsConnection: Send + Sync + 'static { + fn as_side(&self) -> Side<&KernelClientConnection, &KernelServerConnection>; + fn as_side_mut(&mut self) -> Side<&mut KernelClientConnection, &mut KernelServerConnection>; + + fn protocol_version(&self) -> ProtocolVersion { + match self.as_side() { + Side::Client(client) => client.protocol_version(), + Side::Server(server) => server.protocol_version(), + } + } + + fn update_rx_secret(&mut self) -> Result<(u64, ConnectionTrafficSecrets), rustls::Error> { + match self.as_side_mut() { + Side::Client(client) => client.update_rx_secret(), + Side::Server(server) => server.update_rx_secret(), + } + } + + fn update_tx_secret(&mut self) -> Result<(u64, ConnectionTrafficSecrets), rustls::Error> { + match self.as_side_mut() { + Side::Client(client) => client.update_tx_secret(), + Side::Server(server) => server.update_tx_secret(), + } + } + + fn negotiated_cipher_suite(&self) -> SupportedCipherSuite { + match self.as_side() { + Side::Client(client) => client.negotiated_cipher_suite(), + Side::Server(server) => server.negotiated_cipher_suite(), + } + } +} + +impl KTlsConnection for KernelClientConnection { + fn as_side(&self) -> Side<&KernelClientConnection, &KernelServerConnection> { + Side::Client(self) + } + + fn as_side_mut(&mut self) -> Side<&mut KernelClientConnection, &mut KernelServerConnection> { + Side::Client(self) + } +} + +impl KTlsConnection for KernelServerConnection { + fn as_side(&self) -> Side<&KernelClientConnection, &KernelServerConnection> { + Side::Server(self) + } + + fn as_side_mut(&mut self) -> Side<&mut KernelClientConnection, &mut KernelServerConnection> { + Side::Server(self) + } +} + +pin_project_lite::pin_project! { + #[project = KTlsStreamProject] + pub(crate) struct KTlsStreamInner { + #[pin] + socket: IO, + data: BufferedData, + state: KTlsStreamState, + + // KernelConnection is quite large so we box it here to avoid excessively + // increasing the size of `KTlsStream`. + conn: Box>, + } +} + +/// Everything in [`KTlsStreamProject`] except `data`. +/// +/// Due to the way we reuse the buffer in `data` we frequently need to be able +/// to borrow "everything except `data`" when implementing handling for control +/// messages. +struct KTlsStreamCoreProject<'a, IO, Conn: ?Sized> { + socket: Pin<&'a mut IO>, + state: &'a mut KTlsStreamState, + conn: &'a mut Box>, +} + +impl KTlsStreamInner { + /// Create a new client stream from a socket and [`KernelConnection`]. + /// + /// This assumes that `socket` has already been initialized as a kTLS + /// socket. + pub(crate) fn new_client(socket: IO, conn: KernelClientConnection) -> Self { + Self::new_inner(socket, Vec::new(), conn) + } +} + +impl KTlsStreamInner { + /// Create a new client stream from a socket and [`KernelConnection`]. + /// + /// This assumes that `socket` has already been initialized as a kTLS + /// socket. If early data was recieved in the handshake, then it should be + /// passed in `early`, otherwise it should be empty. + pub(crate) fn new_server(socket: IO, early: Vec, conn: KernelServerConnection) -> Self { + Self::new_inner(socket, early, conn) + } +} + +impl KTlsStreamInner> { + fn new_inner(socket: IO, early: Vec, conn: KernelConnection) -> Self { + let suite_common = match conn.negotiated_cipher_suite() { + #[cfg(feature = "tls12")] + rustls::SupportedCipherSuite::Tls12(suite) => &suite.common, + rustls::SupportedCipherSuite::Tls13(suite) => &suite.common, + _ => panic!("rustls has feature tls12 enabled but ktls does not"), + }; + + let data = if early.is_empty() { + BufferedData::Scratch(early) + } else { + BufferedData::EarlyData(OffsetVec::new(early)) + }; + + Self { + socket, + data, + state: KTlsStreamState::default(), + conn: Box::new(ConnectionData { + // Use 16 as a safety margin to deal with messages that have + // been sent after the handshake has been established. + rx_messages_since_last_key_update: 16, + tx_messages_since_last_key_update: 16, + awaiting_key_update: false, + confidentiality_limit: suite_common.confidentiality_limit, + + conn, + }), + } + } +} + +impl KTlsStreamInner { + fn read_early_data(buffer: &mut BufferedData, buf: &mut ReadBuf<'_>) -> usize { + let cursor = match buffer { + BufferedData::EarlyData(cursor) => cursor, + _ => return 0, + }; + + let count = cursor.read_buf(buf); + if cursor.is_empty() { + let mut scratch = std::mem::take(cursor).into_cleared_vec(); + scratch.shrink_to(DEFAULT_SCRATCH_CAPACITY); + + *buffer = BufferedData::Scratch(scratch); + } + + count + } +} + +impl KTlsStreamInner +where + IO: AsyncRead + AsyncWrite + AsRawFd, + Conn: KTlsConnection, +{ + pub(crate) fn handle_control_message(self: Pin<&mut Self>) -> io::Result<()> { + let mut this = self.project(); + let (mut core, data) = this.as_core_parts(); + core.handle_control_message(data) + } +} + +impl KTlsStreamProject<'_, IO, Conn> { + fn as_core_parts<'a>( + &'a mut self, + ) -> (KTlsStreamCoreProject<'a, IO, Conn>, &'a mut BufferedData) { + ( + KTlsStreamCoreProject { + socket: self.socket.as_mut(), + state: self.state, + conn: self.conn, + }, + self.data, + ) + } +} + +impl KTlsStreamCoreProject<'_, IO, Conn> +where + IO: AsyncRead + AsyncWrite + AsRawFd, + Conn: KTlsConnection, +{ + fn key_update(&mut self, request: KeyUpdateRequest) -> io::Result<()> { + #[rustfmt::skip] + let message = [ + HandshakeType::KeyUpdate.into(), //typ + 0, 0, 1, // length + request.into() + ]; + + self.send_cmsg(ContentType::Handshake, &[io::IoSlice::new(&message)])?; + + if request == KeyUpdateRequest::UpdateRequested { + self.conn.awaiting_key_update = true; + } + + self.conn.tx_messages_since_last_key_update = 1; + + let (seq, secrets) = match self.conn.update_tx_secret() { + Ok(secrets) => secrets, + Err(e) => { + return Err(self.abort_with_alert( + AlertDescription::InternalError, + KTlsError::KeyUpdateFailed(e), + )); + } + }; + + let crypto = + match CryptoInfo::from_rustls(self.conn.conn.negotiated_cipher_suite(), (seq, secrets)) + { + Ok(crypto) => crypto, + Err(e) => { + let _ = self.abort(AlertDescription::InternalError); + + // This should be impossible. We have already validated + // that the cipher is compatible during connection setup + // so it should not fail now. + panic!("negotiated TLS cipher is no longer compatible for key update: {e}") + } + }; + + if let Err(e) = setup_tls_info(self.socket.as_raw_fd(), Direction::Tx, crypto) { + // The other side of the connection won't be able to decrypt this but it will + // cause them to abort the connection, which is good enough. + let _ = self.abort(AlertDescription::InternalError); + + return Err(e); + } + + Ok(()) + } + + fn handle_read_complete(&mut self, bytes: usize) -> io::Result<()> { + let count = written.div_ceil(TLS_MAX_MESSAGE_LEN); + self.conn.rx_messages_since_last_key_update += count; + + if self.conn.confidentiality_limit == u64::MAX { + return Ok(()); + } + + let hard_limit = self.conn.confidentiality_limit - self.conn.confidentiality_limit / 32; + let soft_limit = self.conn.confidentiality_limit / 2; + + if self.conn.rx_messages_since_last_key_update > hard_limit { + let _ = self.abort(AlertDescription::InternalError); + return Err(io::Error::other( + KTlsStreamError::ConfidentialityLimitReached, + )); + } + + if !self.conn.awaiting_key_update + && self.conn.rx_messages_since_last_key_update > soft_limit + { + // We actually need the peer to update their keys + self.key_update(KeyUpdateRequest::UpdateRequested)?; + } + + Ok(()) + } + + fn handle_write_complete(&mut self, bytes: usize) -> io::Result<()> { + let count = written.div_ceil(TLS_MAX_MESSAGE_LEN); + self.conn.tx_messages_since_last_key_update += count; + + if self.conn.confidentiality_limit == u64::MAX { + return Ok(()); + } + + let hard_limit = self.conn.confidentiality_limit - self.conn.confidentiality_limit / 32; + let soft_limit = self.conn.confidentiality_limit / 2; + + if self.conn.tx_messages_since_last_key_update > hard_limit { + let _ = self.abort(AlertDescription::InternalError); + return Err(io::Error::other( + KTlsStreamError::ConfidentialityLimitReached, + )); + } + + if self.conn.rx_messages_since_last_key_update > soft_limit { + let request = if self.conn.rx_messages_since_last_key_update > soft_limit / 2 { + KeyUpdateRequest::UpdateRequested + } else { + KeyUpdateRequest::UpdateNotRequested + }; + + self.key_update(request)?; + } + + Ok(()) + } + + fn handle_control_message(&mut self, buffered_data: &mut BufferedData) -> io::Result<()> { + if self.state.read_closed { + return Err(io::Error::other(KTlsStreamError::ConnectionShutDown)); + } + + let mut data = match buffered_data { + BufferedData::EarlyData(_) => { + panic!("all buffered application data must be handled before processing control messages") + } + BufferedData::Scratch(data) => ClearOnDrop(data), + }; + + let mut cmsg = Cmsg::new(0, 0, [0]); + + let flags = match crate::ffi::recvmsg_whole( + self.socket.as_raw_fd(), + &mut data, + Some(&mut cmsg), + libc::MSG_DONTWAIT, + ) { + Ok(flags) => flags, + Err(e) if e.raw_os_error() == Some(libc::EAGAIN) => { + // We should only ever get EAGAIN if there is no message available. + assert!( + data.is_empty(), + "recvmsg returned EAGAIN after reading a partial record" + ); + return Ok(()); + } + Err(e) => return Err(e), + }; + + if cmsg.level() != libc::SOL_TLS || cmsg.typ() != libc::TLS_GET_RECORD_TYPE { + panic!( + "recvmsg returned an unexpected control message (level = {}, type = {})", + cmsg.level(), + cmsg.typ() + ); + } + + // This should never happen, since TLS_GET_RECORD_TYPE messages are always 1 + // byte + debug_assert!( + flags & libc::MSG_CTRUNC == 0, + "recvmsg control message was truncated" + ); + + match ContentType::from(cmsg.data()[0]) { + ContentType::ApplicationData => { + // This shouldn't happen in normal operation but can happen when + // users are directly calling handle_control_message. + // + // It's not ideal, but we can handle it. + + let buffer = std::mem::take(&mut *data); + drop(data); + *buffered_data = BufferedData::EarlyData(OffsetVec::new(buffer)); + return Ok(()); + } + + ContentType::Alert => { + let (level, desc) = match &data[..] { + &[level, desc] => (level.into(), desc.into()), + _ => { + // The peer sent an invalid alert. We send back an error + // and close the connection. + return Err(self.abort_with_error( + AlertDescription::DecodeError, + InvalidMessage::MessageTooLarge, + )); + } + }; + + self.handle_alert(level, desc)?; + } + + ContentType::Handshake => { + self.handle_handshake(&data)?; + } + + ContentType::ChangeCipherSpec => { + // ChangeCipherSpec should only be sent under the following conditions: + // - TLS 1.2: during a handshake or a rehandshake + // - TLS 1.3: during a handshake + // + // We don't have to worry about handling messages during a handshake + // and rustls does not support TLS 1.2 rehandshakes so we just emit + // an error here and abort the connection. + return Err(self.abort_with_error( + AlertDescription::UnexpectedMessage, + PeerMisbehaved::IllegalMiddleboxChangeCipherSpec, + )); + } + + // Any other message results in an error + _ => { + return Err(self.abort_with_error( + AlertDescription::UnexpectedMessage, + InvalidMessage::InvalidContentType, + )) + } + } + + Ok(()) + } + + fn handle_alert(&mut self, level: AlertLevel, desc: AlertDescription) -> io::Result<()> { + match desc { + // The peer has closed their end of the connection. We close the read half + // of the connection since we will receive no more data frames. + AlertDescription::CloseNotify => { + self.state.read_closed = true; + } + + // TLS 1.2 allows alerts to be sent with a warning level without terminating + // the connection. In this case we ignore the alert. + _ if self.conn.conn.protocol_version() == ProtocolVersion::TLSv1_2 + && level == AlertLevel::Warning => {} + + // All other alerts are treated as fatal and result in us immediately shutting + // down the connection and emitting an error. + _ => { + self.state.read_closed = true; + self.state.write_closed = true; + + return Err(io::Error::other(KTlsStreamError::Alert(desc))); + } + } + + Ok(()) + } + + fn handle_handshake(&mut self, mut data: &[u8]) -> io::Result<()> { + let mut first = true; + + while !data.is_empty() { + let (ty, len, rest) = match data { + &[ty, a, b, c, ref rest @ ..] => ( + HandshakeType::from(ty), + u32::from_be_bytes([0, a, b, c]) as usize, + rest, + ), + _ => { + return Err(self.abort_with_error( + AlertDescription::DecodeError, + InvalidMessage::MessageTooShort, + )) + } + }; + + if len > rest.len() { + return Err(self.abort_with_error( + AlertDescription::DecodeError, + InvalidMessage::MessageTooShort, + )); + } + + let (msg, rest) = rest.split_at(len); + data = rest; + + // KeyUpdate messages must be the only sub-message within their message. + if ty == HandshakeType::KeyUpdate + && self.conn.protocol_version() == ProtocolVersion::TLSv1_3 + { + if !first || !data.is_empty() { + return Err(self.abort_with_alert( + AlertDescription::UnexpectedMessage, + PeerMisbehaved::KeyEpochWithPendingFragment, + )); + } + } + + self.handle_single_handshake(typ, msg)?; + first = false; + } + + Ok(()) + } + + fn handle_single_handshake(&mut self, typ: HandshakeType, data: &[u8]) -> io::Result<()> { + match typ { + HandshakeType::KeyUpdate + if self.conn.conn.protocol_version() == ProtocolVersion::TLSv1_3 => + { + let req = match data { + &[req] => KeyUpdateRequest::from(req), + _ => { + return Err(self.abort_with_error( + AlertDescription::DecodeError, + InvalidMessage::InvalidKeyUpdate, + )) + } + }; + + let (seq, secrets) = match self.conn.conn.update_rx_secret() { + Ok(secrets) => secrets, + Err(e) => { + return Err(self.abort_with_error( + AlertDescription::InternalError, + KTlsStreamError::KeyUpdateFailed(e), + )) + } + }; + + let crypto = match CryptoInfo::from_rustls( + self.conn.conn.negotiated_cipher_suite(), + (seq, secrets), + ) { + Ok(crypto) => crypto, + Err(e) => { + let _ = self.abort(AlertDescription::InternalError); + + // This should be impossible. We have already validated + // that the cipher is compatible during connection setup + // so it should not fail now. + panic!("negotiated TLS cipher is no longer compatible for key update: {e}") + } + }; + + if let Err(e) = setup_tls_info(self.socket.as_raw_fd(), Direction::Rx, crypto) { + // If setup_tls_info fails then the connection is done for, + // so we just an alert. + let _ = self.abort(AlertDescription::InternalError); + return Err(e); + } + + match req { + KeyUpdateRequest::UpdateNotRequested => return Ok(()), + KeyUpdateRequest::UpdateRequested => (), + _ => { + return Err(self.abort_with_error( + AlertDescription::IllegalParameter, + InvalidMessage::InvalidKeyUpdate, + )); + } + } + + self.key_update(KeyUpdateRequest::UpdateNotRequested)?; + } + + HandshakeType::NewSessionTicket + if self.conn.conn.protocol_version() == ProtocolVersion::TLSv1_3 => + { + match self.conn.conn.as_side() { + Side::Client(conn) => match conn.handle_new_session_ticket(data) { + Ok(()) => (), + // Convert some messages into their higher-level equivalents + Err(rustls::Error::InvalidMessage(err)) => { + return Err(self.abort_with_error(AlertDescription::DecodeError, err)); + } + Err(rustls::Error::PeerMisbehaved(err)) => { + return Err( + self.abort_with_error(AlertDescription::UnexpectedMessage, err) + ); + } + + // Other errors are not necessarily fatal + Err(e) => return Err(KTlsStreamError::SessionTicketFailed(e)), + }, + Side::Server(_) => { + return Err(self.abort_with_error( + AlertDescription::UnexpectedMessage, + InvalidMessage::UnexpectedMessage( + "TLS 1.2 peer sent a TLS 1.3 NewSessionTicket message", + ), + )) + } + } + } + + _ => { + return match self.conn.conn.protocol_version() { + ProtocolVersion::TLSv1_3 => self.abort_with_error( + AlertDescription::UnexpectedMessage, + InvalidMessage::UnexpectedMessage( + "expected KeyUpdate or NewSessionTicket handshake messages only", + ), + ), + _ => self.abort_with_error( + AlertDescription::UnexpectedMessage, + InvalidMessage::UnexpectedMessage( + "handshake messages are not expected on TLS 1.2 connections", + ), + ), + } + } + } + + Ok(()) + } + + fn abort(&mut self, alert: AlertDescription) -> io::Result<()> { + let write_closed = self.state.write_closed; + + self.state.read_closed = true; + self.state.write_closed = true; + + if !write_closed { + self.send_alert(AlertLevel::Fatal, alert)?; + } + + Ok(()) + } + + fn abort_with_error( + &mut self, + alert: AlertDescription, + error: impl Into, + ) -> io::Error { + // We don't propagate any errors here since we already have an existing error. + let _ = self.abort(alert); + + io::Error::other(error.into()) + } + + fn send_alert(&self, level: AlertLevel, desc: AlertDescription) -> io::Result<()> { + let message = [level.into(), desc.into()]; + let iov = [io::IoSlice::new(&message)]; + + self.send_cmsg(ContentType::Alert, &iov) + } + + fn shutdown(&self) -> io::Result<()> { + self.state.write_closed = true; + self.send_alert(AlertLevel::Warning, AlertDescription::CloseNotify)?; + Ok(()) + } + + fn send_cmsg(&self, typ: ContentType, data: &[io::IoSlice<'_>]) -> io::Result<()> { + self.conn.tx_messages_since_last_key_update += 1; + + let cmsg = Cmsg::new(libc::SOL_TLS, libc::TLS_SET_RECORD_TYPE, [typ.into()]); + // TODO: Should an error here abort the whole connection? + crate::ffi::sendmsg(self.socket.as_raw_fd(), data, Some(&cmsg), 0)?; + Ok(()) + } +} + +impl AsyncRead for KTlsStreamInner +where + IO: AsyncRead + AsyncWrite + AsRawFd, + Conn: KTlsConnection, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let mut this = self.project(); + + if matches!(this.data, BufferedData::EarlyData(_)) { + match Self::read_early_data(this.data, buf) { + 0 => (), + _ => return Poll::Ready(Ok(())), + } + } + + // We want to gracefully handle control messages, but we don't want to + // hold up the task if there are lots of them. + for _ in 0..4 { + if this.state.read_closed { + return Poll::Ready(Ok(())); + } + + let start = buf.filled().len(); + match this.socket.as_mut().poll_read(cx, buf) { + // Linux returns EIO when there is a control message to be read + // but there is no CMsg space to write to. + // + // If we get this as an error it means there is a control message + // that we need to handle. + Poll::Ready(Err(e)) if e.raw_os_error() == Some(libc::EIO) => (), + poll @ Poll::Ready(Ok(())) => { + let end = buf.filled().len(); + let written = end.checked_sub(start).unwrap_or(buf.capacity()); + + this.as_core_parts().0.handle_read_complete(written)?; + } + poll => return poll, + } + + let (mut core, data) = this.as_core_parts(); + core.handle_control_message(data)?; + } + + cx.waker().wake_by_ref(); + Poll::Pending + } +} + +impl AsyncWrite for KTlsStreamInner +where + IO: AsyncRead + AsyncWrite + AsRawFd, + Conn: KTlsConnection, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + + if this.state.write_closed { + return Poll::Ready(Ok(0)); + } + + match this.socket.poll_write(cx, buf) { + poll @ Poll::Ready(Ok(bytes)) => { + this.as_core_parts().0.handle_write_complete(bytes)?; + } + poll => poll, + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let mut this = self.project(); + + if this.state.write_closed { + return Poll::Ready(Ok(0)); + } + + match this.socket.poll_write_vectored(cx, buf) { + poll @ Poll::Ready(Ok(bytes)) => { + this.as_core_parts().0.handle_write_complete(bytes)?; + } + poll => poll, + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + this.socket.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.state.write_closed { + if let Err(e) = this.as_core_parts().0.shutdown() { + return Poll::Ready(Err(e)); + } + } + + this.socket.poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.socket.is_write_vectored() + } +} + +#[derive(Default)] +struct OffsetVec { + data: Vec, + offset: usize, +} + +impl OffsetVec { + pub fn new(data: Vec) -> Self { + Self { data, offset: 0 } + } + + pub fn is_empty(&self) -> bool { + self.offset == self.data.len() + } + + pub fn into_cleared_vec(mut self) -> Vec { + self.data.clear(); + self.data + } + + pub fn read_buf(&mut self, buf: &mut ReadBuf<'_>) -> usize { + let tail = &self.data[self.offset..]; + let removed = &tail[..tail.len().min(buf.remaining())]; + buf.put_slice(removed); + self.offset += removed.len(); + removed.len() + } +} + +struct ClearOnDrop<'a>(&'a mut Vec); + +impl Deref for ClearOnDrop<'_> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl DerefMut for ClearOnDrop<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.0 + } +} + +impl Drop for ClearOnDrop<'_> { + fn drop(&mut self) { + self.0.clear(); + } +} + +const DEFAULT_SCRATCH_CAPACITY: usize = 256; +const TLS_MAX_MESSAGE_LEN: usize = 1 << 14; + +#[derive(Debug, thiserror::Error)] +enum KTlsStreamError { + #[error("received corrupt message of type {0:?}")] + InvalidMessage(InvalidMessage), + + #[error("peer misbehaved: {0:?}")] + PeerMisbehaved(PeerMisbehaved), + + #[error("{0}")] + KeyUpdateFailed(#[source] rustls::Error), + + #[error("failed to handle a provided session ticket: {0}")] + SessionTicketFailed(#[source] rustls::Error), + + #[error("the connection has been shut down")] + ConnectionShutDown, + + #[error("the connection has reached its confidentiality limit and has been shut down")] + ConfidentialityLimitReached, + + #[error("connection peer closed the connection with an alert: {0:?}")] + Alert(AlertDescription), +} + +impl From for KTlsStreamError { + fn from(error: InvalidMessage) -> Self { + Self::InvalidMessage(error) + } +} + +impl From for KTlsStreamError { + fn from(error: PeerMisbehaved) -> Self { + Self::PeerMisbehaved(error) + } +} From 7640230258e5bb2f289bcaad00f7946bcb176a34 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 19 May 2025 16:52:44 -0700 Subject: [PATCH 2/8] Add generic KTlsStreamImpl type This type contains all the tricky logic WRT to actually implementing the TLS protocol on top of kTLS. Everything else can be a (mostly) straightforward wrapper around it. --- ktls/src/ffi.rs | 6 +- ktls/src/lib.rs | 9 +- ktls/src/stream.rs | 717 +++++++++++++++++++-------------------------- 3 files changed, 314 insertions(+), 418 deletions(-) diff --git a/ktls/src/ffi.rs b/ktls/src/ffi.rs index be36e9d..cb29f0a 100644 --- a/ktls/src/ffi.rs +++ b/ktls/src/ffi.rs @@ -229,12 +229,10 @@ impl CryptoInfo { } } -pub fn setup_tls_info(fd: RawFd, dir: Direction, info: CryptoInfo) -> Result<(), crate::Error> { +pub fn setup_tls_info(fd: RawFd, dir: Direction, info: CryptoInfo) -> io::Result<()> { let ret = unsafe { libc::setsockopt(fd, SOL_TLS, dir.into(), info.as_ptr(), info.size() as _) }; if ret < 0 { - return Err(crate::Error::TlsCryptoInfoError( - std::io::Error::last_os_error(), - )); + return Err(std::io::Error::last_os_error()); } Ok(()) } diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index 6b0aadb..4847c05 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -36,9 +36,8 @@ pub use ktls_stream::KtlsStream; mod cork_stream; pub use cork_stream::CorkStream; -mod client; -mod stream; mod protocol; +mod stream; #[derive(Debug, Default)] pub struct CompatibleCiphers { @@ -216,7 +215,7 @@ fn sample_cipher_setup(sock: &TcpStream, cipher_suite: SupportedCipherSuite) -> setup_ulp(fd).map_err(Error::UlpError)?; - setup_tls_info(fd, ffi::Direction::Tx, crypto_info)?; + setup_tls_info(fd, ffi::Direction::Tx, crypto_info).map_err(Error::TlsCryptoInfoError)?; Ok(()) } @@ -340,10 +339,10 @@ fn setup_inner(fd: RawFd, conn: Connection) -> Result<(), Error> { ffi::setup_ulp(fd).map_err(Error::UlpError)?; let tx = CryptoInfo::from_rustls(cipher_suite, secrets.tx)?; - setup_tls_info(fd, ffi::Direction::Tx, tx)?; + setup_tls_info(fd, ffi::Direction::Tx, tx).map_err(Error::TlsCryptoInfoError)?; let rx = CryptoInfo::from_rustls(cipher_suite, secrets.rx)?; - setup_tls_info(fd, ffi::Direction::Rx, rx)?; + setup_tls_info(fd, ffi::Direction::Rx, rx).map_err(Error::TlsCryptoInfoError)?; Ok(()) } diff --git a/ktls/src/stream.rs b/ktls/src/stream.rs index 6fb9d31..bead648 100644 --- a/ktls/src/stream.rs +++ b/ktls/src/stream.rs @@ -1,15 +1,15 @@ -use rustls::{ - AlertDescription, ConnectionTrafficSecrets, ContentType, HandshakeType, InvalidMessage, - PeerMisbehaved, ProtocolVersion, SupportedCipherSuite, -}; use std::io; use std::ops::{Deref, DerefMut}; -use std::os::unix::io::AsRawFd; +use std::os::fd::AsRawFd; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use rustls::kernel::KernelConnection; +use rustls::{ + AlertDescription, ConnectionTrafficSecrets, ContentType, HandshakeType, InvalidMessage, + PeerMisbehaved, ProtocolVersion, SupportedCipherSuite, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::ffi::{setup_tls_info, Cmsg, Direction}; use crate::protocol::{AlertLevel, KeyUpdateRequest}; @@ -18,219 +18,142 @@ use crate::CryptoInfo; type KernelClientConnection = KernelConnection; type KernelServerConnection = KernelConnection; -struct ConnectionData { - rx_messages_since_last_key_update: u64, - tx_messages_since_last_key_update: u64, +pin_project_lite::pin_project! { + #[project = KTlsStreamProject] + pub(crate) struct KTlsStreamImpl { + #[pin] + socket: IO, + state: StreamState, + data: StreamData, + } +} - awaiting_key_update: bool, +impl KTlsStreamProject<'_, IO, Conn> +where + IO: AsyncRead + AsyncWrite + AsRawFd, + Conn: ?Sized, + StreamData: StreamSide, +{ + fn poll_read(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + if self.state.early_data() { + if self.poll_read_early_data(buf) != 0 { + return Poll::Ready(Ok(())); + } + } - confidentiality_limit: u64, + for _ in 0..4 { + if self.state.read_closed() { + return Poll::Ready(Ok(())); + } - conn: Conn, -} + match self.socket.as_mut().poll_read(cx, buf) { + // Linux returns EIO when there is a control message to be read + // but there is no CMsg space to write to. + // + // If we get this as an error it means there is a control message + // that we need to handle. + Poll::Ready(Err(e)) if e.raw_os_error() == Some(libc::EIO) => (), + poll => return poll, + } -type ClientConnectionData = ConnectionData; -type ServerConnectionData = ConnectionData; + self.handle_control_message()?; + } -enum Side { - Client(Client), - Server(Server), -} + // We've already handled multiple control messages with this poll, yield + // for now but arrange to be woken up right away. + cx.waker().wake_by_ref(); + Poll::Pending + } -enum BufferedData { - EarlyData(OffsetVec), - Scratch(Vec), -} + fn poll_read_early_data(&mut self, buf: &mut ReadBuf<'_>) -> usize { + let data = &self.data.buffer[self.data.offset..]; -#[derive(Default)] -struct KTlsStreamState { - write_closed: bool, - read_closed: bool, -} + let available = buf.remaining(); + let data = &data[..available.min(data.len())]; + buf.put_slice(data); -trait KTlsConnection: Send + Sync + 'static { - fn as_side(&self) -> Side<&KernelClientConnection, &KernelServerConnection>; - fn as_side_mut(&mut self) -> Side<&mut KernelClientConnection, &mut KernelServerConnection>; + let len = data.len(); + self.data.offset += data.len(); + if self.data.offset == self.data.buffer.len() { + self.data.buffer.clear(); + self.data.offset = 0; + self.state.0 &= !StreamState::EARLY_DATA; - fn protocol_version(&self) -> ProtocolVersion { - match self.as_side() { - Side::Client(client) => client.protocol_version(), - Side::Server(server) => server.protocol_version(), + self.data.buffer.shrink_to(MAX_SCRATCH_CAPACITY); } - } - fn update_rx_secret(&mut self) -> Result<(u64, ConnectionTrafficSecrets), rustls::Error> { - match self.as_side_mut() { - Side::Client(client) => client.update_rx_secret(), - Side::Server(server) => server.update_rx_secret(), - } + len } - fn update_tx_secret(&mut self) -> Result<(u64, ConnectionTrafficSecrets), rustls::Error> { - match self.as_side_mut() { - Side::Client(client) => client.update_tx_secret(), - Side::Server(server) => server.update_tx_secret(), + fn poll_write(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + if self.state.contains(StreamState::PENDING_CLOSE) { + std::task::ready!(self.poll_do_close(cx))?; } - } - fn negotiated_cipher_suite(&self) -> SupportedCipherSuite { - match self.as_side() { - Side::Client(client) => client.negotiated_cipher_suite(), - Side::Server(server) => server.negotiated_cipher_suite(), + if self.state.write_closed() { + return Poll::Ready(Ok(0)); } - } -} -impl KTlsConnection for KernelClientConnection { - fn as_side(&self) -> Side<&KernelClientConnection, &KernelServerConnection> { - Side::Client(self) + self.socket.as_mut().poll_write(cx, buf) } - fn as_side_mut(&mut self) -> Side<&mut KernelClientConnection, &mut KernelServerConnection> { - Side::Client(self) - } -} - -impl KTlsConnection for KernelServerConnection { - fn as_side(&self) -> Side<&KernelClientConnection, &KernelServerConnection> { - Side::Server(self) - } - - fn as_side_mut(&mut self) -> Side<&mut KernelClientConnection, &mut KernelServerConnection> { - Side::Server(self) - } -} + fn poll_write_vectored( + &mut self, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + if self.state.contains(StreamState::PENDING_CLOSE) { + std::task::ready!(self.poll_do_close(cx))?; + } -pin_project_lite::pin_project! { - #[project = KTlsStreamProject] - pub(crate) struct KTlsStreamInner { - #[pin] - socket: IO, - data: BufferedData, - state: KTlsStreamState, + if self.state.write_closed() { + return Poll::Ready(Ok(0)); + } - // KernelConnection is quite large so we box it here to avoid excessively - // increasing the size of `KTlsStream`. - conn: Box>, + self.socket.as_mut().poll_write_vectored(cx, bufs) } -} -/// Everything in [`KTlsStreamProject`] except `data`. -/// -/// Due to the way we reuse the buffer in `data` we frequently need to be able -/// to borrow "everything except `data`" when implementing handling for control -/// messages. -struct KTlsStreamCoreProject<'a, IO, Conn: ?Sized> { - socket: Pin<&'a mut IO>, - state: &'a mut KTlsStreamState, - conn: &'a mut Box>, -} + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state.contains(StreamState::PENDING_CLOSE) { + std::task::ready!(self.poll_do_close(cx))?; + } -impl KTlsStreamInner { - /// Create a new client stream from a socket and [`KernelConnection`]. - /// - /// This assumes that `socket` has already been initialized as a kTLS - /// socket. - pub(crate) fn new_client(socket: IO, conn: KernelClientConnection) -> Self { - Self::new_inner(socket, Vec::new(), conn) - } -} + if self.state.write_closed() { + return Poll::Ready(Ok(())); + } -impl KTlsStreamInner { - /// Create a new client stream from a socket and [`KernelConnection`]. - /// - /// This assumes that `socket` has already been initialized as a kTLS - /// socket. If early data was recieved in the handshake, then it should be - /// passed in `early`, otherwise it should be empty. - pub(crate) fn new_server(socket: IO, early: Vec, conn: KernelServerConnection) -> Self { - Self::new_inner(socket, early, conn) + self.socket.as_mut().poll_flush(cx) } -} -impl KTlsStreamInner> { - fn new_inner(socket: IO, early: Vec, conn: KernelConnection) -> Self { - let suite_common = match conn.negotiated_cipher_suite() { - #[cfg(feature = "tls12")] - rustls::SupportedCipherSuite::Tls12(suite) => &suite.common, - rustls::SupportedCipherSuite::Tls13(suite) => &suite.common, - _ => panic!("rustls has feature tls12 enabled but ktls does not"), - }; - - let data = if early.is_empty() { - BufferedData::Scratch(early) - } else { - BufferedData::EarlyData(OffsetVec::new(early)) - }; - - Self { - socket, - data, - state: KTlsStreamState::default(), - conn: Box::new(ConnectionData { - // Use 16 as a safety margin to deal with messages that have - // been sent after the handshake has been established. - rx_messages_since_last_key_update: 16, - tx_messages_since_last_key_update: 16, - awaiting_key_update: false, - confidentiality_limit: suite_common.confidentiality_limit, - - conn, - }), + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state.contains(StreamState::PENDING_CLOSE) { + std::task::ready!(self.poll_do_close(cx))?; } + + self.state.0 |= StreamState::WRITE_CLOSED; + self.socket.as_mut().poll_shutdown(cx) } -} -impl KTlsStreamInner { - fn read_early_data(buffer: &mut BufferedData, buf: &mut ReadBuf<'_>) -> usize { - let cursor = match buffer { - BufferedData::EarlyData(cursor) => cursor, - _ => return 0, - }; + fn poll_do_close(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.socket.as_mut().poll_flush(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => { + self.state.0 &= !StreamState::PENDING_CLOSE; - let count = cursor.read_buf(buf); - if cursor.is_empty() { - let mut scratch = std::mem::take(cursor).into_cleared_vec(); - scratch.shrink_to(DEFAULT_SCRATCH_CAPACITY); + if result.is_ok() { + if let Err(e) = + self.send_alert(AlertLevel::Warning, AlertDescription::CloseNotify) + { + return Poll::Ready(Err(e)); + } + } - *buffer = BufferedData::Scratch(scratch); + self.state.0 |= StreamState::WRITE_CLOSED; + Poll::Ready(result) + } } - - count - } -} - -impl KTlsStreamInner -where - IO: AsyncRead + AsyncWrite + AsRawFd, - Conn: KTlsConnection, -{ - pub(crate) fn handle_control_message(self: Pin<&mut Self>) -> io::Result<()> { - let mut this = self.project(); - let (mut core, data) = this.as_core_parts(); - core.handle_control_message(data) } -} -impl KTlsStreamProject<'_, IO, Conn> { - fn as_core_parts<'a>( - &'a mut self, - ) -> (KTlsStreamCoreProject<'a, IO, Conn>, &'a mut BufferedData) { - ( - KTlsStreamCoreProject { - socket: self.socket.as_mut(), - state: self.state, - conn: self.conn, - }, - self.data, - ) - } -} - -impl KTlsStreamCoreProject<'_, IO, Conn> -where - IO: AsyncRead + AsyncWrite + AsRawFd, - Conn: KTlsConnection, -{ fn key_update(&mut self, request: KeyUpdateRequest) -> io::Result<()> { #[rustfmt::skip] let message = [ @@ -241,25 +164,18 @@ where self.send_cmsg(ContentType::Handshake, &[io::IoSlice::new(&message)])?; - if request == KeyUpdateRequest::UpdateRequested { - self.conn.awaiting_key_update = true; - } - - self.conn.tx_messages_since_last_key_update = 1; - - let (seq, secrets) = match self.conn.update_tx_secret() { + let (seq, secrets) = match self.data.update_tx_secret() { Ok(secrets) => secrets, Err(e) => { - return Err(self.abort_with_alert( + return Err(self.abort_with_error( AlertDescription::InternalError, - KTlsError::KeyUpdateFailed(e), + KTlsStreamError::KeyUpdateFailed(e), )); } }; let crypto = - match CryptoInfo::from_rustls(self.conn.conn.negotiated_cipher_suite(), (seq, secrets)) - { + match CryptoInfo::from_rustls(self.data.negotiated_cipher_suite(), (seq, secrets)) { Ok(crypto) => crypto, Err(e) => { let _ = self.abort(AlertDescription::InternalError); @@ -282,79 +198,28 @@ where Ok(()) } - fn handle_read_complete(&mut self, bytes: usize) -> io::Result<()> { - let count = written.div_ceil(TLS_MAX_MESSAGE_LEN); - self.conn.rx_messages_since_last_key_update += count; - - if self.conn.confidentiality_limit == u64::MAX { - return Ok(()); - } - - let hard_limit = self.conn.confidentiality_limit - self.conn.confidentiality_limit / 32; - let soft_limit = self.conn.confidentiality_limit / 2; - - if self.conn.rx_messages_since_last_key_update > hard_limit { - let _ = self.abort(AlertDescription::InternalError); - return Err(io::Error::other( - KTlsStreamError::ConfidentialityLimitReached, - )); - } + fn handle_control_message(&mut self) -> io::Result<()> { + let mut take = TakeBuffer::new(self); + let (this, data) = take.as_parts_mut(); - if !self.conn.awaiting_key_update - && self.conn.rx_messages_since_last_key_update > soft_limit - { - // We actually need the peer to update their keys - self.key_update(KeyUpdateRequest::UpdateRequested)?; - } - - Ok(()) + this.handle_control_message_impl(data) } - fn handle_write_complete(&mut self, bytes: usize) -> io::Result<()> { - let count = written.div_ceil(TLS_MAX_MESSAGE_LEN); - self.conn.tx_messages_since_last_key_update += count; - - if self.conn.confidentiality_limit == u64::MAX { - return Ok(()); + fn handle_control_message_impl(&mut self, buffer: &mut Vec) -> io::Result<()> { + if self.state.read_closed() { + return Err(io::Error::other(KTlsStreamError::Closed)) } - let hard_limit = self.conn.confidentiality_limit - self.conn.confidentiality_limit / 32; - let soft_limit = self.conn.confidentiality_limit / 2; - - if self.conn.tx_messages_since_last_key_update > hard_limit { - let _ = self.abort(AlertDescription::InternalError); - return Err(io::Error::other( - KTlsStreamError::ConfidentialityLimitReached, - )); + // We reuse the early data buffer to read the control message so it is + // an error to attempt to do so without having handled all the early + // data beforehand. + if self.state.early_data() { + return Err(io::Error::other(KTlsStreamError::ControlMessageWithBufferedData)); } - if self.conn.rx_messages_since_last_key_update > soft_limit { - let request = if self.conn.rx_messages_since_last_key_update > soft_limit / 2 { - KeyUpdateRequest::UpdateRequested - } else { - KeyUpdateRequest::UpdateNotRequested - }; - - self.key_update(request)?; - } - - Ok(()) - } - - fn handle_control_message(&mut self, buffered_data: &mut BufferedData) -> io::Result<()> { - if self.state.read_closed { - return Err(io::Error::other(KTlsStreamError::ConnectionShutDown)); - } - - let mut data = match buffered_data { - BufferedData::EarlyData(_) => { - panic!("all buffered application data must be handled before processing control messages") - } - BufferedData::Scratch(data) => ClearOnDrop(data), - }; + let mut data = ClearOnDrop(buffer); let mut cmsg = Cmsg::new(0, 0, [0]); - let flags = match crate::ffi::recvmsg_whole( self.socket.as_raw_fd(), &mut data, @@ -395,9 +260,9 @@ where // // It's not ideal, but we can handle it. - let buffer = std::mem::take(&mut *data); - drop(data); - *buffered_data = BufferedData::EarlyData(OffsetVec::new(buffer)); + std::mem::forget(data); + self.state.0 |= StreamState::EARLY_DATA; + return Ok(()); } @@ -452,20 +317,18 @@ where // The peer has closed their end of the connection. We close the read half // of the connection since we will receive no more data frames. AlertDescription::CloseNotify => { - self.state.read_closed = true; + self.state.0 |= StreamState::READ_CLOSED; } // TLS 1.2 allows alerts to be sent with a warning level without terminating // the connection. In this case we ignore the alert. - _ if self.conn.conn.protocol_version() == ProtocolVersion::TLSv1_2 + _ if self.data.protocol_version() == ProtocolVersion::TLSv1_2 && level == AlertLevel::Warning => {} // All other alerts are treated as fatal and result in us immediately shutting // down the connection and emitting an error. _ => { - self.state.read_closed = true; - self.state.write_closed = true; - + self.state.0 = StreamState::CLOSED; return Err(io::Error::other(KTlsStreamError::Alert(desc))); } } @@ -503,17 +366,17 @@ where // KeyUpdate messages must be the only sub-message within their message. if ty == HandshakeType::KeyUpdate - && self.conn.protocol_version() == ProtocolVersion::TLSv1_3 + && self.data.protocol_version() == ProtocolVersion::TLSv1_3 { if !first || !data.is_empty() { - return Err(self.abort_with_alert( + return Err(self.abort_with_error( AlertDescription::UnexpectedMessage, PeerMisbehaved::KeyEpochWithPendingFragment, )); } } - self.handle_single_handshake(typ, msg)?; + self.handle_single_handshake(ty, msg)?; first = false; } @@ -523,7 +386,7 @@ where fn handle_single_handshake(&mut self, typ: HandshakeType, data: &[u8]) -> io::Result<()> { match typ { HandshakeType::KeyUpdate - if self.conn.conn.protocol_version() == ProtocolVersion::TLSv1_3 => + if self.data.protocol_version() == ProtocolVersion::TLSv1_3 => { let req = match data { &[req] => KeyUpdateRequest::from(req), @@ -535,7 +398,7 @@ where } }; - let (seq, secrets) = match self.conn.conn.update_rx_secret() { + let (seq, secrets) = match self.data.update_rx_secret() { Ok(secrets) => secrets, Err(e) => { return Err(self.abort_with_error( @@ -546,7 +409,7 @@ where }; let crypto = match CryptoInfo::from_rustls( - self.conn.conn.negotiated_cipher_suite(), + self.data.negotiated_cipher_suite(), (seq, secrets), ) { Ok(crypto) => crypto, @@ -582,10 +445,10 @@ where } HandshakeType::NewSessionTicket - if self.conn.conn.protocol_version() == ProtocolVersion::TLSv1_3 => + if self.data.protocol_version() == ProtocolVersion::TLSv1_3 => { - match self.conn.conn.as_side() { - Side::Client(conn) => match conn.handle_new_session_ticket(data) { + match self.data.as_side_mut() { + Side::Client(conn) => match conn.conn.handle_new_session_ticket(data) { Ok(()) => (), // Convert some messages into their higher-level equivalents Err(rustls::Error::InvalidMessage(err)) => { @@ -598,7 +461,9 @@ where } // Other errors are not necessarily fatal - Err(e) => return Err(KTlsStreamError::SessionTicketFailed(e)), + Err(e) => { + return Err(io::Error::other(KTlsStreamError::SessionTicketFailed(e))) + } }, Side::Server(_) => { return Err(self.abort_with_error( @@ -612,7 +477,7 @@ where } _ => { - return match self.conn.conn.protocol_version() { + return Err(match self.data.protocol_version() { ProtocolVersion::TLSv1_3 => self.abort_with_error( AlertDescription::UnexpectedMessage, InvalidMessage::UnexpectedMessage( @@ -625,7 +490,7 @@ where "handshake messages are not expected on TLS 1.2 connections", ), ), - } + }) } } @@ -633,10 +498,8 @@ where } fn abort(&mut self, alert: AlertDescription) -> io::Result<()> { - let write_closed = self.state.write_closed; - - self.state.read_closed = true; - self.state.write_closed = true; + let write_closed = self.state.write_closed(); + self.state.0 = StreamState::WRITE_CLOSED | StreamState::READ_CLOSED; if !write_closed { self.send_alert(AlertLevel::Fatal, alert)?; @@ -663,15 +526,7 @@ where self.send_cmsg(ContentType::Alert, &iov) } - fn shutdown(&self) -> io::Result<()> { - self.state.write_closed = true; - self.send_alert(AlertLevel::Warning, AlertDescription::CloseNotify)?; - Ok(()) - } - fn send_cmsg(&self, typ: ContentType, data: &[io::IoSlice<'_>]) -> io::Result<()> { - self.conn.tx_messages_since_last_key_update += 1; - let cmsg = Cmsg::new(libc::SOL_TLS, libc::TLS_SET_RECORD_TYPE, [typ.into()]); // TODO: Should an error here abort the whole connection? crate::ffi::sendmsg(self.socket.as_raw_fd(), data, Some(&cmsg), 0)?; @@ -679,117 +534,49 @@ where } } -impl AsyncRead for KTlsStreamInner +impl AsyncRead for KTlsStreamImpl where IO: AsyncRead + AsyncWrite + AsRawFd, - Conn: KTlsConnection, + Conn: ?Sized, + StreamData: StreamSide, { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut this = self.project(); - - if matches!(this.data, BufferedData::EarlyData(_)) { - match Self::read_early_data(this.data, buf) { - 0 => (), - _ => return Poll::Ready(Ok(())), - } - } - - // We want to gracefully handle control messages, but we don't want to - // hold up the task if there are lots of them. - for _ in 0..4 { - if this.state.read_closed { - return Poll::Ready(Ok(())); - } - - let start = buf.filled().len(); - match this.socket.as_mut().poll_read(cx, buf) { - // Linux returns EIO when there is a control message to be read - // but there is no CMsg space to write to. - // - // If we get this as an error it means there is a control message - // that we need to handle. - Poll::Ready(Err(e)) if e.raw_os_error() == Some(libc::EIO) => (), - poll @ Poll::Ready(Ok(())) => { - let end = buf.filled().len(); - let written = end.checked_sub(start).unwrap_or(buf.capacity()); - - this.as_core_parts().0.handle_read_complete(written)?; - } - poll => return poll, - } - - let (mut core, data) = this.as_core_parts(); - core.handle_control_message(data)?; - } - - cx.waker().wake_by_ref(); - Poll::Pending + self.project().poll_read(cx, buf) } } -impl AsyncWrite for KTlsStreamInner +impl AsyncWrite for KTlsStreamImpl where IO: AsyncRead + AsyncWrite + AsRawFd, - Conn: KTlsConnection, + Conn: ?Sized, + StreamData: StreamSide, { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { - let mut this = self.project(); - - if this.state.write_closed { - return Poll::Ready(Ok(0)); - } - - match this.socket.poll_write(cx, buf) { - poll @ Poll::Ready(Ok(bytes)) => { - this.as_core_parts().0.handle_write_complete(bytes)?; - } - poll => poll, - } + ) -> Poll> { + self.project().poll_write(cx, buf) } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], - ) -> Poll> { - let mut this = self.project(); - - if this.state.write_closed { - return Poll::Ready(Ok(0)); - } - - match this.socket.poll_write_vectored(cx, buf) { - poll @ Poll::Ready(Ok(bytes)) => { - this.as_core_parts().0.handle_write_complete(bytes)?; - } - poll => poll, - } + ) -> Poll> { + self.project().poll_write_vectored(cx, bufs) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - this.socket.poll_flush(cx) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().poll_flush(cx) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - if !this.state.write_closed { - if let Err(e) = this.as_core_parts().0.shutdown() { - return Poll::Ready(Err(e)); - } - } - - this.socket.poll_shutdown(cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().poll_shutdown(cx) } fn is_write_vectored(&self) -> bool { @@ -797,62 +584,123 @@ where } } -#[derive(Default)] -struct OffsetVec { - data: Vec, +pub(crate) struct StreamData { + /// This buffer is used to store early data and also as a buffer to store + /// received control messages. + buffer: Vec, offset: usize, + + conn: Conn, } -impl OffsetVec { - pub fn new(data: Vec) -> Self { - Self { data, offset: 0 } +impl StreamData +where + Self: StreamSide, + Conn: ?Sized, +{ + fn protocol_version(&self) -> ProtocolVersion { + match self.as_side() { + Side::Client(client) => client.conn.protocol_version(), + Side::Server(server) => server.conn.protocol_version(), + } } - pub fn is_empty(&self) -> bool { - self.offset == self.data.len() + fn negotiated_cipher_suite(&self) -> SupportedCipherSuite { + match self.as_side() { + Side::Client(client) => client.conn.negotiated_cipher_suite(), + Side::Server(server) => server.conn.negotiated_cipher_suite(), + } } - pub fn into_cleared_vec(mut self) -> Vec { - self.data.clear(); - self.data + fn update_tx_secret(&mut self) -> Result<(u64, ConnectionTrafficSecrets), rustls::Error> { + match self.as_side_mut() { + Side::Client(client) => client.conn.update_tx_secret(), + Side::Server(server) => server.conn.update_tx_secret(), + } } - pub fn read_buf(&mut self, buf: &mut ReadBuf<'_>) -> usize { - let tail = &self.data[self.offset..]; - let removed = &tail[..tail.len().min(buf.remaining())]; - buf.put_slice(removed); - self.offset += removed.len(); - removed.len() + fn update_rx_secret(&mut self) -> Result<(u64, ConnectionTrafficSecrets), rustls::Error> { + match self.as_side_mut() { + Side::Client(client) => client.conn.update_rx_secret(), + Side::Server(server) => server.conn.update_rx_secret(), + } } } -struct ClearOnDrop<'a>(&'a mut Vec); +pub(crate) trait StreamSide: 'static { + fn as_side( + &self, + ) -> Side<&StreamData, &StreamData>; -impl Deref for ClearOnDrop<'_> { - type Target = Vec; + fn as_side_mut( + &mut self, + ) -> Side<&mut StreamData, &mut StreamData>; +} - fn deref(&self) -> &Self::Target { - &*self.0 +impl StreamSide for StreamData { + fn as_side( + &self, + ) -> Side<&StreamData, &StreamData> { + Side::Client(self) + } + + fn as_side_mut( + &mut self, + ) -> Side<&mut StreamData, &mut StreamData> + { + Side::Client(self) } } -impl DerefMut for ClearOnDrop<'_> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut *self.0 +impl StreamSide for StreamData { + fn as_side( + &self, + ) -> Side<&StreamData, &StreamData> { + Side::Server(self) + } + + fn as_side_mut( + &mut self, + ) -> Side<&mut StreamData, &mut StreamData> + { + Side::Server(self) } } -impl Drop for ClearOnDrop<'_> { - fn drop(&mut self) { - self.0.clear(); +#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Hash)] +struct StreamState(u8); + +#[rustfmt::skip] +impl StreamState { + const READ_CLOSED: u8 = 0b00001; + const WRITE_CLOSED: u8 = 0b00010; + const CLOSED: u8 = 0b00011; + const EARLY_DATA: u8 = 0b00100; + const PENDING_CLOSE: u8 = 0b01000; +} + +impl StreamState { + fn contains(self, flags: u8) -> bool { + self.0 & flags == flags + } + + fn read_closed(self) -> bool { + self.contains(Self::READ_CLOSED) + } + + fn write_closed(self) -> bool { + self.contains(Self::WRITE_CLOSED) + } + + fn early_data(self) -> bool { + self.contains(Self::EARLY_DATA) } } -const DEFAULT_SCRATCH_CAPACITY: usize = 256; -const TLS_MAX_MESSAGE_LEN: usize = 1 << 14; +const MAX_SCRATCH_CAPACITY: usize = 1024; #[derive(Debug, thiserror::Error)] -enum KTlsStreamError { +pub enum KTlsStreamError { #[error("received corrupt message of type {0:?}")] InvalidMessage(InvalidMessage), @@ -865,11 +713,11 @@ enum KTlsStreamError { #[error("failed to handle a provided session ticket: {0}")] SessionTicketFailed(#[source] rustls::Error), - #[error("the connection has been shut down")] - ConnectionShutDown, + #[error("the connection has been closed by the peer")] + Closed, - #[error("the connection has reached its confidentiality limit and has been shut down")] - ConfidentialityLimitReached, + #[error("cannot handle control messages while there is buffered data to read")] + ControlMessageWithBufferedData, #[error("connection peer closed the connection with an alert: {0:?}")] Alert(AlertDescription), @@ -886,3 +734,54 @@ impl From for KTlsStreamError { Self::PeerMisbehaved(error) } } + +pub(crate) enum Side { + Client(Client), + Server(Server), +} + +struct ClearOnDrop<'a>(&'a mut Vec); + +impl Deref for ClearOnDrop<'_> { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &*self.0 + } +} + +impl DerefMut for ClearOnDrop<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.0 + } +} + +impl Drop for ClearOnDrop<'_> { + fn drop(&mut self) { + self.0.clear(); + } +} + +struct TakeBuffer<'a, 'b, IO, Conn: ?Sized> { + stream: &'a mut KTlsStreamProject<'b, IO, Conn>, + buffer: Vec, +} + +impl<'a, 'b, IO, Conn: ?Sized> TakeBuffer<'a, 'b, IO, Conn> { + pub fn new(stream: &'a mut KTlsStreamProject<'b, IO, Conn>) -> Self { + Self { + buffer: std::mem::take(&mut stream.data.buffer), + stream, + } + } + + pub fn as_parts_mut(&mut self) -> (&mut KTlsStreamProject<'b, IO, Conn>, &mut Vec) { + (&mut *self.stream, &mut self.buffer) + } +} + +impl<'a, 'b, IO, Conn: ?Sized> Drop for TakeBuffer<'a, 'b, IO, Conn> { + fn drop(&mut self) { + self.stream.data.buffer = std::mem::take(&mut self.buffer); + } +} From 7ce238564597fa3e7ffaea524f803faf25b4fee0 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 26 May 2025 16:58:02 -0700 Subject: [PATCH 3/8] Add client stream --- ktls/src/client.rs | 131 +++++++++++++++++++++++++++++++++++++------ ktls/src/lib.rs | 4 ++ ktls/src/protocol.rs | 57 ++++++++++++++++++- 3 files changed, 174 insertions(+), 18 deletions(-) diff --git a/ktls/src/client.rs b/ktls/src/client.rs index 815fab7..42e906d 100644 --- a/ktls/src/client.rs +++ b/ktls/src/client.rs @@ -1,24 +1,123 @@ -use std::sync::Arc; +use std::os::fd::AsRawFd; +use std::{fmt, io}; +use rustls::client::{ClientConnectionData, UnbufferedClientConnection}; +use rustls::kernel::KernelConnection; +use tokio::io::{AsyncRead, AsyncWrite}; -pub struct KTlsConnector { - config: Arc, -} +use crate::ffi::Direction; +use crate::stream::KTlsStreamImpl; +use crate::CryptoInfo; + +pub struct KTlsClientStream(KTlsStreamImpl>); + +impl KTlsClientStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + pub fn from_unbuffered_connnection( + socket: IO, + conn: UnbufferedClientConnection, + ) -> Result> { + // We attempt to set up the TLS ULP before doing anything else so that + // we can indicate that the kernel doesn't support kTLS before returning + // any other error. + if let Err(e) = crate::ffi::setup_ulp(socket.as_raw_fd()) { + let error = if e.raw_os_error() == Some(libc::ENOENT) { + ConnectError::KTlsUnsupported + } else { + ConnectError::IO(e) + }; -impl KTlsConnector { - pub fn new(config: Arc) -> Self { - Self { config } + return Err(TryConnectError { + error, + socket: Some(socket), + conn: Some(conn), + }); + } + + // TODO: Validate that the negotiated connection is actually + // supported by kTLS on the current machine. + + Ok(Self::from_unbuffered_connnection_with_tls_ulp_enabled( + socket, conn, + )?) } - // pub async fn try_connect( - // &self, - // domain: ServerName<'static>, - // stream: IO - // ) -> Result< + /// Create a new `KTlsClientStream` from a socket that already has had the TLS ULP + /// enabled on it. + fn from_unbuffered_connnection_with_tls_ulp_enabled( + socket: IO, + conn: UnbufferedClientConnection, + ) -> Result { + let (secrets, kconn) = match conn.dangerous_into_kernel_connection() { + Ok(secrets) => secrets, + Err(e) => return Err(ConnectError::ExtractSecrets(e)), + }; + + let suite = kconn.negotiated_cipher_suite(); + let tx = CryptoInfo::from_rustls(suite, secrets.tx) + .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + let rx = CryptoInfo::from_rustls(suite, secrets.rx) + .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + + crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Tx, tx) + .map_err(ConnectError::IO)?; + crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Rx, rx) + .map_err(ConnectError::IO)?; + + todo!() + } } -// pub struct KTlsClientStream +#[derive(Debug, thiserror::Error)] +pub enum ConnectError { + /// kTLS is not supported by the current kernel + #[error("kTLS is not supported by the current kernel")] + KTlsUnsupported, + + #[error("the negotiated cipher suite is not supported by kTLS")] + UnsupportedCipherSuite, + + #[error("the peer closed the connection before the TLS handshake could be completed")] + PeerClosedBeforeHandshakeCompleted, + + #[error("{0}")] + IO(#[source] io::Error), -// pub struct TryConnectError { -// connection: -// } + #[error("failed to create rustls client connection: {0}")] + Config(#[source] rustls::Error), + + #[error("an error occurred during the handshake: {0}")] + Handshake(#[source] rustls::Error), + + #[error("unable to extract connection secrets from rustls connection: {0}")] + ExtractSecrets(#[source] rustls::Error), +} + +#[derive(thiserror::Error)] +#[error("{error}")] +pub struct TryConnectError { + #[source] + pub error: ConnectError, + pub socket: Option, + pub conn: Option, +} + +impl fmt::Debug for TryConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryConnectError") + .field("error", &self.error) + .finish_non_exhaustive() + } +} + +impl From for TryConnectError { + fn from(error: ConnectError) -> Self { + Self { + error, + socket: None, + conn: None, + } + } +} diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index 4847c05..01f95af 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -36,9 +36,13 @@ pub use ktls_stream::KtlsStream; mod cork_stream; pub use cork_stream::CorkStream; +mod client; mod protocol; mod stream; +pub use crate::stream::KTlsStreamError; +pub use crate::client::KTlsClientStream; + #[derive(Debug, Default)] pub struct CompatibleCiphers { pub tls12: CompatibleCiphersForVersion, diff --git a/ktls/src/protocol.rs b/ktls/src/protocol.rs index 833dcca..5f8e634 100644 --- a/ktls/src/protocol.rs +++ b/ktls/src/protocol.rs @@ -2,7 +2,9 @@ #![allow(non_upper_case_globals)] -use std::fmt; +use std::{fmt, io}; + +use tokio::io::{AsyncRead, ReadBuf}; macro_rules! c_enum { { @@ -65,7 +67,6 @@ macro_rules! c_enum { } } - c_enum! { #[derive(Copy, Clone, Eq, PartialEq)] pub(crate) enum AlertLevel: u8 { @@ -82,4 +83,56 @@ c_enum! { } } +pub(crate) async fn read_record(stream: &mut IO, buf: &mut Vec) -> io::Result<()> +where + IO: AsyncRead + Unpin, +{ + use tokio::io::AsyncReadExt; + + let mut header = [0u8; 5]; + stream.read_exact(&mut header).await?; + let bytes: [u8; 5] = header.try_into().unwrap(); + buf.extend_from_slice(&header); + + let header = TlsHeader::decode(bytes); + + buf.reserve(header.len as usize); + let new_len = buf.len() + header.len as usize; + let mut rdbuf = ReadBuf::uninit(&mut buf.spare_capacity_mut()[..header.len as usize]); + + loop { + let remaining = rdbuf.remaining(); + if remaining == 0 { + break; + } + + stream.read_buf(&mut rdbuf).await?; + if rdbuf.remaining() == remaining { + return Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + } + } + + // SAFETY: If we get here we guarantee that rdbuf.remaining() == 0. + // ReadBuf's contract means that we can assume that it has + // been fully initialized under those conditions. + unsafe { buf.set_len(new_len) }; + + Ok(()) +} + +#[allow(dead_code)] +struct TlsHeader { + ty: rustls::ContentType, + version: rustls::ProtocolVersion, + len: u16, +} +impl TlsHeader { + pub fn decode(bytes: [u8; 5]) -> Self { + let ty = rustls::ContentType::from(bytes[0]); + let version = rustls::ProtocolVersion::from(u16::from_be_bytes([bytes[1], bytes[2]])); + let len = u16::from_be_bytes([bytes[3], bytes[4]]); + + Self { ty, version, len } + } +} From 1bfd1bb8484b8052dcaaa26787bd4adf9f89dce0 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 26 May 2025 17:56:28 -0700 Subject: [PATCH 4/8] Add KTlsServerStream and trait impls --- Cargo.lock | 5 +- ktls/Cargo.toml | 1 + ktls/src/client.rs | 107 +++++++++++++++++++--------------- ktls/src/error.rs | 53 +++++++++++++++++ ktls/src/lib.rs | 6 +- ktls/src/protocol.rs | 4 +- ktls/src/server.rs | 135 +++++++++++++++++++++++++++++++++++++++++++ ktls/src/stream.rs | 95 +++++++++++++++++++++--------- 8 files changed, 328 insertions(+), 78 deletions(-) create mode 100644 ktls/src/error.rs create mode 100644 ktls/src/server.rs diff --git a/Cargo.lock b/Cargo.lock index f0dc86c..8b8736c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,9 +101,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bytes" @@ -326,6 +326,7 @@ dependencies = [ name = "ktls" version = "6.0.2" dependencies = [ + "bitflags", "futures-util", "ktls-sys 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static", diff --git a/ktls/Cargo.toml b/ktls/Cargo.toml index f48187f..f20f817 100644 --- a/ktls/Cargo.toml +++ b/ktls/Cargo.toml @@ -26,6 +26,7 @@ ktls-sys = "1.0.1" num_enum = "0.7.3" futures-util = "0.3.30" nix = { version = "0.29.0", features = ["socket", "uio", "net"] } +bitflags = "2.9.1" [dev-dependencies] lazy_static = "1.5.0" diff --git a/ktls/src/client.rs b/ktls/src/client.rs index 42e906d..7f6f81b 100644 --- a/ktls/src/client.rs +++ b/ktls/src/client.rs @@ -1,15 +1,23 @@ -use std::os::fd::AsRawFd; -use std::{fmt, io}; +use std::io; +use std::os::fd::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; use rustls::client::{ClientConnectionData, UnbufferedClientConnection}; use rustls::kernel::KernelConnection; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use crate::ffi::Direction; use crate::stream::KTlsStreamImpl; use crate::CryptoInfo; +use crate::{ConnectError, TryConnectError}; -pub struct KTlsClientStream(KTlsStreamImpl>); +pin_project_lite::pin_project! { + pub struct KTlsClientStream { + #[pin] + stream: KTlsStreamImpl> + } +} impl KTlsClientStream where @@ -18,7 +26,7 @@ where pub fn from_unbuffered_connnection( socket: IO, conn: UnbufferedClientConnection, - ) -> Result> { + ) -> Result> { // We attempt to set up the TLS ULP before doing anything else so that // we can indicate that the kernel doesn't support kTLS before returning // any other error. @@ -66,58 +74,63 @@ where crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Rx, rx) .map_err(ConnectError::IO)?; - todo!() + Ok(Self { + stream: KTlsStreamImpl::new(socket, Vec::new(), kconn), + }) } } -#[derive(Debug, thiserror::Error)] -pub enum ConnectError { - /// kTLS is not supported by the current kernel - #[error("kTLS is not supported by the current kernel")] - KTlsUnsupported, - - #[error("the negotiated cipher suite is not supported by kTLS")] - UnsupportedCipherSuite, - - #[error("the peer closed the connection before the TLS handshake could be completed")] - PeerClosedBeforeHandshakeCompleted, - - #[error("{0}")] - IO(#[source] io::Error), +impl AsyncRead for KTlsClientStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().stream.poll_read(cx, buf) + } +} - #[error("failed to create rustls client connection: {0}")] - Config(#[source] rustls::Error), +impl AsyncWrite for KTlsClientStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().stream.poll_write(cx, buf) + } - #[error("an error occurred during the handshake: {0}")] - Handshake(#[source] rustls::Error), + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_flush(cx) + } - #[error("unable to extract connection secrets from rustls connection: {0}")] - ExtractSecrets(#[source] rustls::Error), -} + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_shutdown(cx) + } -#[derive(thiserror::Error)] -#[error("{error}")] -pub struct TryConnectError { - #[source] - pub error: ConnectError, - pub socket: Option, - pub conn: Option, -} + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().stream.poll_write_vectored(cx, bufs) + } -impl fmt::Debug for TryConnectError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TryConnectError") - .field("error", &self.error) - .finish_non_exhaustive() + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() } } -impl From for TryConnectError { - fn from(error: ConnectError) -> Self { - Self { - error, - socket: None, - conn: None, - } +impl AsRawFd for KTlsClientStream +where + IO: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.stream.as_raw_fd() } } diff --git a/ktls/src/error.rs b/ktls/src/error.rs new file mode 100644 index 0000000..3a7f338 --- /dev/null +++ b/ktls/src/error.rs @@ -0,0 +1,53 @@ +use std::{fmt, io}; + +#[derive(Debug, thiserror::Error)] +pub enum ConnectError { + /// kTLS is not supported by the current kernel + #[error("kTLS is not supported by the current kernel")] + KTlsUnsupported, + + #[error("the negotiated cipher suite is not supported by kTLS")] + UnsupportedCipherSuite, + + #[error("the peer closed the connection before the TLS handshake could be completed")] + PeerClosedBeforeHandshakeCompleted, + + #[error("{0}")] + IO(#[source] io::Error), + + #[error("failed to create rustls client connection: {0}")] + Config(#[source] rustls::Error), + + #[error("an error occurred during the handshake: {0}")] + Handshake(#[source] rustls::Error), + + #[error("unable to extract connection secrets from rustls connection: {0}")] + ExtractSecrets(#[source] rustls::Error), +} + +#[derive(thiserror::Error)] +#[error("{error}")] +pub struct TryConnectError { + #[source] + pub error: ConnectError, + pub socket: Option, + pub conn: Option, +} + +impl fmt::Debug for TryConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TryConnectError") + .field("error", &self.error) + .finish_non_exhaustive() + } +} + +impl From for TryConnectError { + fn from(error: ConnectError) -> Self { + Self { + error, + socket: None, + conn: None, + } + } +} diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index 01f95af..a866088 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -37,11 +37,15 @@ mod cork_stream; pub use cork_stream::CorkStream; mod client; +mod error; mod protocol; +mod server; mod stream; -pub use crate::stream::KTlsStreamError; pub use crate::client::KTlsClientStream; +pub use crate::error::{ConnectError, TryConnectError}; +pub use crate::server::KTlsServerStream; +pub use crate::stream::KTlsStreamError; #[derive(Debug, Default)] pub struct CompatibleCiphers { diff --git a/ktls/src/protocol.rs b/ktls/src/protocol.rs index 5f8e634..9b40429 100644 --- a/ktls/src/protocol.rs +++ b/ktls/src/protocol.rs @@ -99,7 +99,7 @@ where buf.reserve(header.len as usize); let new_len = buf.len() + header.len as usize; let mut rdbuf = ReadBuf::uninit(&mut buf.spare_capacity_mut()[..header.len as usize]); - + loop { let remaining = rdbuf.remaining(); if remaining == 0 { @@ -108,7 +108,7 @@ where stream.read_buf(&mut rdbuf).await?; if rdbuf.remaining() == remaining { - return Err(io::Error::from(io::ErrorKind::UnexpectedEof)) + return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } } diff --git a/ktls/src/server.rs b/ktls/src/server.rs new file mode 100644 index 0000000..ae4776c --- /dev/null +++ b/ktls/src/server.rs @@ -0,0 +1,135 @@ +use std::io; +use std::os::fd::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use rustls::kernel::KernelConnection; +use rustls::server::{ServerConnectionData, UnbufferedServerConnection}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::ffi::Direction; +use crate::stream::KTlsStreamImpl; +use crate::{ConnectError, CryptoInfo, TryConnectError}; + +pin_project_lite::pin_project! { + pub struct KTlsServerStream { + #[pin] + stream: KTlsStreamImpl> + } +} + +impl KTlsServerStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + pub fn from_unbuffered_connnection( + socket: IO, + conn: UnbufferedServerConnection, + ) -> Result> { + // We attempt to set up the TLS ULP before doing anything else so that + // we can indicate that the kernel doesn't support kTLS before returning + // any other error. + if let Err(e) = crate::ffi::setup_ulp(socket.as_raw_fd()) { + let error = if e.raw_os_error() == Some(libc::ENOENT) { + ConnectError::KTlsUnsupported + } else { + ConnectError::IO(e) + }; + + return Err(TryConnectError { + error, + socket: Some(socket), + conn: Some(conn), + }); + } + + Ok(Self::from_unbuffered_connnection_with_tls_ulp_enabled( + socket, + Vec::new(), + conn, + )?) + } + + /// Create a new `KTlsServerStream` from a socket that already has had the TLS ULP + /// enabled on it. + pub fn from_unbuffered_connnection_with_tls_ulp_enabled( + socket: IO, + early_data: Vec, + conn: UnbufferedServerConnection, + ) -> Result { + let (secrets, kconn) = match conn.dangerous_into_kernel_connection() { + Ok(secrets) => secrets, + Err(e) => return Err(ConnectError::ExtractSecrets(e)), + }; + + let suite = kconn.negotiated_cipher_suite(); + let tx = CryptoInfo::from_rustls(suite, secrets.tx) + .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + let rx = CryptoInfo::from_rustls(suite, secrets.rx) + .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + + crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Tx, tx) + .map_err(ConnectError::IO)?; + crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Rx, rx) + .map_err(ConnectError::IO)?; + + Ok(Self { + stream: KTlsStreamImpl::new(socket, early_data, kconn), + }) + } +} + +impl AsyncRead for KTlsServerStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().stream.poll_read(cx, buf) + } +} + +impl AsyncWrite for KTlsServerStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().stream.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().stream.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } +} + +impl AsRawFd for KTlsServerStream +where + IO: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.stream.as_raw_fd() + } +} diff --git a/ktls/src/stream.rs b/ktls/src/stream.rs index bead648..ff8fcbc 100644 --- a/ktls/src/stream.rs +++ b/ktls/src/stream.rs @@ -1,6 +1,6 @@ use std::io; use std::ops::{Deref, DerefMut}; -use std::os::fd::AsRawFd; +use std::os::fd::{AsRawFd, RawFd}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -19,12 +19,48 @@ type KernelClientConnection = KernelConnection; pin_project_lite::pin_project! { + /// A generic kTLS stream. + /// + /// Most of the behaviour is identical between client and server streams so + /// this type allows either. In the cases where there is a difference, the + /// [`StreamSide`] trait is used to get the correct side of the stream. #[project = KTlsStreamProject] pub(crate) struct KTlsStreamImpl { #[pin] socket: IO, state: StreamState, - data: StreamData, + data: Box>, + } +} + +impl KTlsStreamImpl +where + IO: AsyncRead + AsyncWrite + AsRawFd, + Conn: ?Sized, + StreamData: StreamSide, +{ + pub(crate) fn new(socket: IO, early_data: Vec, conn: Conn) -> Self + where + Conn: Sized, + { + let (state, buffer) = match () { + _ if !early_data.is_empty() => (StreamState::EARLY_DATA, early_data), + _ if early_data.capacity() != 0 => (StreamState::default(), early_data), + _ => ( + StreamState::default(), + Vec::with_capacity(DEFAULT_SCRATCH_CAPACITY), + ), + }; + + Self { + socket, + state, + data: Box::new(StreamData { + buffer, + offset: 0, + conn, + }), + } } } @@ -77,7 +113,7 @@ where if self.data.offset == self.data.buffer.len() { self.data.buffer.clear(); self.data.offset = 0; - self.state.0 &= !StreamState::EARLY_DATA; + *self.state &= !StreamState::EARLY_DATA; self.data.buffer.shrink_to(MAX_SCRATCH_CAPACITY); } @@ -130,7 +166,7 @@ where std::task::ready!(self.poll_do_close(cx))?; } - self.state.0 |= StreamState::WRITE_CLOSED; + *self.state |= StreamState::WRITE_CLOSED; self.socket.as_mut().poll_shutdown(cx) } @@ -138,7 +174,7 @@ where match self.socket.as_mut().poll_flush(cx) { Poll::Pending => Poll::Pending, Poll::Ready(result) => { - self.state.0 &= !StreamState::PENDING_CLOSE; + *self.state &= !StreamState::PENDING_CLOSE; if result.is_ok() { if let Err(e) = @@ -148,7 +184,7 @@ where } } - self.state.0 |= StreamState::WRITE_CLOSED; + *self.state |= StreamState::WRITE_CLOSED; Poll::Ready(result) } } @@ -207,14 +243,16 @@ where fn handle_control_message_impl(&mut self, buffer: &mut Vec) -> io::Result<()> { if self.state.read_closed() { - return Err(io::Error::other(KTlsStreamError::Closed)) + return Err(io::Error::other(KTlsStreamError::Closed)); } // We reuse the early data buffer to read the control message so it is // an error to attempt to do so without having handled all the early // data beforehand. if self.state.early_data() { - return Err(io::Error::other(KTlsStreamError::ControlMessageWithBufferedData)); + return Err(io::Error::other( + KTlsStreamError::ControlMessageWithBufferedData, + )); } let mut data = ClearOnDrop(buffer); @@ -261,7 +299,7 @@ where // It's not ideal, but we can handle it. std::mem::forget(data); - self.state.0 |= StreamState::EARLY_DATA; + *self.state |= StreamState::EARLY_DATA; return Ok(()); } @@ -317,7 +355,7 @@ where // The peer has closed their end of the connection. We close the read half // of the connection since we will receive no more data frames. AlertDescription::CloseNotify => { - self.state.0 |= StreamState::READ_CLOSED; + *self.state |= StreamState::READ_CLOSED; } // TLS 1.2 allows alerts to be sent with a warning level without terminating @@ -328,7 +366,7 @@ where // All other alerts are treated as fatal and result in us immediately shutting // down the connection and emitting an error. _ => { - self.state.0 = StreamState::CLOSED; + *self.state = StreamState::CLOSED; return Err(io::Error::other(KTlsStreamError::Alert(desc))); } } @@ -499,7 +537,7 @@ where fn abort(&mut self, alert: AlertDescription) -> io::Result<()> { let write_closed = self.state.write_closed(); - self.state.0 = StreamState::WRITE_CLOSED | StreamState::READ_CLOSED; + *self.state = StreamState::WRITE_CLOSED | StreamState::READ_CLOSED; if !write_closed { self.send_alert(AlertLevel::Fatal, alert)?; @@ -584,6 +622,15 @@ where } } +impl AsRawFd for KTlsStreamImpl +where + IO: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.socket.as_raw_fd() + } +} + pub(crate) struct StreamData { /// This buffer is used to store early data and also as a buffer to store /// received control messages. @@ -667,23 +714,18 @@ impl StreamSide for StreamData { } } -#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Hash)] -struct StreamState(u8); - -#[rustfmt::skip] -impl StreamState { - const READ_CLOSED: u8 = 0b00001; - const WRITE_CLOSED: u8 = 0b00010; - const CLOSED: u8 = 0b00011; - const EARLY_DATA: u8 = 0b00100; - const PENDING_CLOSE: u8 = 0b01000; +bitflags::bitflags! { + #[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Hash)] + struct StreamState : u8 { + const READ_CLOSED = 0b00001; + const WRITE_CLOSED = 0b00010; + const CLOSED = 0b00011; + const EARLY_DATA = 0b00100; + const PENDING_CLOSE = 0b01000; + } } impl StreamState { - fn contains(self, flags: u8) -> bool { - self.0 & flags == flags - } - fn read_closed(self) -> bool { self.contains(Self::READ_CLOSED) } @@ -697,6 +739,7 @@ impl StreamState { } } +const DEFAULT_SCRATCH_CAPACITY: usize = 64; const MAX_SCRATCH_CAPACITY: usize = 1024; #[derive(Debug, thiserror::Error)] From 8ee477db04dcc1b28c4a4d188825c6151a196949 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 26 May 2025 18:47:43 -0700 Subject: [PATCH 5/8] Add generic KTlsStream and conversions --- ktls/src/client.rs | 2 +- ktls/src/generic.rs | 96 ++++++++++++++++++++++++++++++++++ ktls/src/lib.rs | 4 +- ktls/src/server.rs | 2 +- ktls/src/stream.rs | 123 +++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 222 insertions(+), 5 deletions(-) create mode 100644 ktls/src/generic.rs diff --git a/ktls/src/client.rs b/ktls/src/client.rs index 7f6f81b..06fa6ba 100644 --- a/ktls/src/client.rs +++ b/ktls/src/client.rs @@ -15,7 +15,7 @@ use crate::{ConnectError, TryConnectError}; pin_project_lite::pin_project! { pub struct KTlsClientStream { #[pin] - stream: KTlsStreamImpl> + pub(crate) stream: KTlsStreamImpl> } } diff --git a/ktls/src/generic.rs b/ktls/src/generic.rs new file mode 100644 index 0000000..005f8fa --- /dev/null +++ b/ktls/src/generic.rs @@ -0,0 +1,96 @@ +use std::io; +use std::os::fd::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::stream::{DynConn, KTlsStreamImpl}; +use crate::{KTlsClientStream, KTlsServerStream, Side}; + +pin_project_lite::pin_project! { + pub struct KTlsStream { + #[pin] + stream: KTlsStreamImpl + } +} + +impl KTlsStream { + pub fn into_side(self) -> Side, KTlsServerStream> { + match self.stream.into_side() { + Side::Client(stream) => Side::Client(KTlsClientStream { stream }), + Side::Server(stream) => Side::Server(KTlsServerStream { stream }), + } + } +} + +impl From> for KTlsStream { + fn from(value: KTlsClientStream) -> Self { + Self { + stream: value.stream.into_dyn(), + } + } +} + +impl From> for KTlsStream { + fn from(value: KTlsServerStream) -> Self { + Self { + stream: value.stream.into_dyn(), + } + } +} + +impl AsyncRead for KTlsStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().stream.poll_read(cx, buf) + } +} + +impl AsyncWrite for KTlsStream +where + IO: AsyncWrite + AsyncRead + AsRawFd, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().stream.poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().stream.poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + self.project().stream.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } +} + +impl AsRawFd for KTlsStream +where + IO: AsRawFd, +{ + fn as_raw_fd(&self) -> RawFd { + self.stream.as_raw_fd() + } +} diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index a866088..ca03196 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -38,14 +38,16 @@ pub use cork_stream::CorkStream; mod client; mod error; +mod generic; mod protocol; mod server; mod stream; pub use crate::client::KTlsClientStream; pub use crate::error::{ConnectError, TryConnectError}; +pub use crate::generic::KTlsStream; pub use crate::server::KTlsServerStream; -pub use crate::stream::KTlsStreamError; +pub use crate::stream::{KTlsStreamError, Side}; #[derive(Debug, Default)] pub struct CompatibleCiphers { diff --git a/ktls/src/server.rs b/ktls/src/server.rs index ae4776c..af78ec3 100644 --- a/ktls/src/server.rs +++ b/ktls/src/server.rs @@ -14,7 +14,7 @@ use crate::{ConnectError, CryptoInfo, TryConnectError}; pin_project_lite::pin_project! { pub struct KTlsServerStream { #[pin] - stream: KTlsStreamImpl> + pub(crate) stream: KTlsStreamImpl> } } diff --git a/ktls/src/stream.rs b/ktls/src/stream.rs index ff8fcbc..14fecdc 100644 --- a/ktls/src/stream.rs +++ b/ktls/src/stream.rs @@ -35,7 +35,6 @@ pin_project_lite::pin_project! { impl KTlsStreamImpl where - IO: AsyncRead + AsyncWrite + AsRawFd, Conn: ?Sized, StreamData: StreamSide, { @@ -62,6 +61,43 @@ where }), } } + + pub(crate) fn into_side( + self, + ) -> Side, KTlsStreamImpl> + { + let Self { + socket, + state, + data, + } = self; + + match data.into_side() { + Side::Client(data) => Side::Client(KTlsStreamImpl { + socket, + state, + data, + }), + Side::Server(data) => Side::Server(KTlsStreamImpl { + socket, + state, + data, + }), + } + } +} + +impl KTlsStreamImpl +where + Conn: DynConn + 'static, +{ + pub(crate) fn into_dyn(self) -> KTlsStreamImpl { + KTlsStreamImpl { + socket: self.socket, + state: self.state, + data: self.data, + } + } } impl KTlsStreamProject<'_, IO, Conn> @@ -625,6 +661,7 @@ where impl AsRawFd for KTlsStreamImpl where IO: AsRawFd, + Conn: ?Sized, { fn as_raw_fd(&self) -> RawFd { self.socket.as_raw_fd() @@ -682,6 +719,10 @@ pub(crate) trait StreamSide: 'static { fn as_side_mut( &mut self, ) -> Side<&mut StreamData, &mut StreamData>; + + fn into_side( + self: Box, + ) -> Side>, Box>>; } impl StreamSide for StreamData { @@ -697,6 +738,13 @@ impl StreamSide for StreamData { { Side::Client(self) } + + fn into_side( + self: Box, + ) -> Side>, Box>> + { + Side::Client(self) + } } impl StreamSide for StreamData { @@ -712,6 +760,76 @@ impl StreamSide for StreamData { { Side::Server(self) } + + fn into_side( + self: Box, + ) -> Side>, Box>> + { + Side::Server(self) + } +} + +impl StreamSide for StreamData { + fn as_side( + &self, + ) -> Side<&StreamData, &StreamData> { + match self.conn.side() { + // SAFETY: The implementor of DynConn guarantees that it is safe to downcast here. + Side::Client(_) => Side::Client(unsafe { &*(self as *const Self as *const _) }), + Side::Server(_) => Side::Server(unsafe { &*(self as *const Self as *const _) }), + } + } + + fn as_side_mut( + &mut self, + ) -> Side<&mut StreamData, &mut StreamData> + { + match self.conn.side() { + // SAFETY: The implementor of DynConn guarantees that it is safe to downcast here. + Side::Client(_) => Side::Client(unsafe { &mut *(self as *mut Self as *mut _) }), + Side::Server(_) => Side::Server(unsafe { &mut *(self as *mut Self as *mut _) }), + } + } + + fn into_side( + self: Box, + ) -> Side>, Box>> + { + // SAFETY: The implementation of DynConn guarantees that it is safe to downcast here. + match self.conn.side() { + Side::Client(_) => { + Side::Client(unsafe { Box::from_raw(Box::into_raw(self) as *mut _) }) + } + Side::Server(_) => { + Side::Server(unsafe { Box::from_raw(Box::into_raw(self) as *mut _) }) + } + } + } +} + +// SAFETY: We are a `KernelClientConnection` so it is safe `&dyn DynConn` references to us +// to be downcasted. +unsafe impl DynConn for KernelClientConnection { + fn side(&self) -> Side<(), ()> { + Side::Client(()) + } +} + +// SAFETY: We are a `KernelServerConnection` so it is safe `&dyn DynConn` references to us +// to be downcasted. +unsafe impl DynConn for KernelServerConnection { + fn side(&self) -> Side<(), ()> { + Side::Server(()) + } +} + +/// A trait that indicates which side of the connection this instance is. +/// +/// # Safety +/// Depending on what `side()` returns, it must be safe to downcast this trait +/// to either a [`KernelClientConnection`] or a [`KernelServerConnection`]. +pub(crate) unsafe trait DynConn { + fn side(&self) -> Side<(), ()>; } bitflags::bitflags! { @@ -778,7 +896,8 @@ impl From for KTlsStreamError { } } -pub(crate) enum Side { +/// An enum splitting things by connection side. +pub enum Side { Client(Client), Server(Server), } From d3fa1185579578ede033e437dadb6cf39f83560b Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Mon, 26 May 2025 18:59:16 -0700 Subject: [PATCH 6/8] Add some documentation --- ktls/src/client.rs | 1 + ktls/src/generic.rs | 6 ++++++ ktls/src/server.rs | 1 + 3 files changed, 8 insertions(+) diff --git a/ktls/src/client.rs b/ktls/src/client.rs index 06fa6ba..705c62f 100644 --- a/ktls/src/client.rs +++ b/ktls/src/client.rs @@ -13,6 +13,7 @@ use crate::CryptoInfo; use crate::{ConnectError, TryConnectError}; pin_project_lite::pin_project! { + /// The client half of a kTLS stream. pub struct KTlsClientStream { #[pin] pub(crate) stream: KTlsStreamImpl> diff --git a/ktls/src/generic.rs b/ktls/src/generic.rs index 005f8fa..0558ef7 100644 --- a/ktls/src/generic.rs +++ b/ktls/src/generic.rs @@ -9,6 +9,12 @@ use crate::stream::{DynConn, KTlsStreamImpl}; use crate::{KTlsClientStream, KTlsServerStream, Side}; pin_project_lite::pin_project! { + /// A wrapper around an `IO` that takes care of managing kTLS state for + /// its underlying fd. + /// + /// This is the generic version of [`KTlsClientStream`] and [`KTlsServerStream`]. + /// It cannot be constructed directly. Instead, construct on of the two more + /// specific streams above and then convert them into [`KTlsStream`]. pub struct KTlsStream { #[pin] stream: KTlsStreamImpl diff --git a/ktls/src/server.rs b/ktls/src/server.rs index af78ec3..507192e 100644 --- a/ktls/src/server.rs +++ b/ktls/src/server.rs @@ -12,6 +12,7 @@ use crate::stream::KTlsStreamImpl; use crate::{ConnectError, CryptoInfo, TryConnectError}; pin_project_lite::pin_project! { + /// The server half of a kTLS stream. pub struct KTlsServerStream { #[pin] pub(crate) stream: KTlsStreamImpl> From 049d2197b8a441ed9995efb9bc9114369fd764a5 Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Fri, 6 Jun 2025 16:51:54 -0700 Subject: [PATCH 7/8] Cache support probe results and validate that ciphers are supported --- ktls/src/error.rs | 34 ++++++- ktls/src/generic.rs | 2 +- ktls/src/lib.rs | 3 +- ktls/src/server.rs | 201 ++++++++++++++++++++++++++++++++++-- ktls/src/suite.rs | 242 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 469 insertions(+), 13 deletions(-) create mode 100644 ktls/src/suite.rs diff --git a/ktls/src/error.rs b/ktls/src/error.rs index 3a7f338..d12209d 100644 --- a/ktls/src/error.rs +++ b/ktls/src/error.rs @@ -1,21 +1,26 @@ use std::{fmt, io}; +use rustls::SupportedCipherSuite; + +use crate::suite::CipherProbeError; + #[derive(Debug, thiserror::Error)] +#[non_exhaustive] pub enum ConnectError { - /// kTLS is not supported by the current kernel + /// kTLS is not supported by the current kernel. #[error("kTLS is not supported by the current kernel")] KTlsUnsupported, #[error("the negotiated cipher suite is not supported by kTLS")] - UnsupportedCipherSuite, + UnsupportedCipherSuite(SupportedCipherSuite), #[error("the peer closed the connection before the TLS handshake could be completed")] - PeerClosedBeforeHandshakeCompleted, + ConnectionClosedBeforeHandshakeCompleted, #[error("{0}")] IO(#[source] io::Error), - #[error("failed to create rustls client connection: {0}")] + #[error("failed to create rustls connection: {0}")] Config(#[source] rustls::Error), #[error("an error occurred during the handshake: {0}")] @@ -25,6 +30,15 @@ pub enum ConnectError { ExtractSecrets(#[source] rustls::Error), } +impl From for io::Error { + fn from(error: ConnectError) -> Self { + match error { + ConnectError::IO(error) => error, + _ => io::Error::other(error), + } + } +} + #[derive(thiserror::Error)] #[error("{error}")] pub struct TryConnectError { @@ -51,3 +65,15 @@ impl From for TryConnectError { } } } + +impl From> for ConnectError { + fn from(value: TryConnectError) -> Self { + value.error + } +} + +impl From> for io::Error { + fn from(error: TryConnectError) -> Self { + error.error.into() + } +} diff --git a/ktls/src/generic.rs b/ktls/src/generic.rs index 0558ef7..d466180 100644 --- a/ktls/src/generic.rs +++ b/ktls/src/generic.rs @@ -11,7 +11,7 @@ use crate::{KTlsClientStream, KTlsServerStream, Side}; pin_project_lite::pin_project! { /// A wrapper around an `IO` that takes care of managing kTLS state for /// its underlying fd. - /// + /// /// This is the generic version of [`KTlsClientStream`] and [`KTlsServerStream`]. /// It cannot be constructed directly. Instead, construct on of the two more /// specific streams above and then convert them into [`KTlsStream`]. diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index ca03196..207ab2c 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -42,11 +42,12 @@ mod generic; mod protocol; mod server; mod stream; +mod suite; pub use crate::client::KTlsClientStream; pub use crate::error::{ConnectError, TryConnectError}; pub use crate::generic::KTlsStream; -pub use crate::server::KTlsServerStream; +pub use crate::server::{KTlsAcceptor, KTlsServerStream}; pub use crate::stream::{KTlsStreamError, Side}; #[derive(Debug, Default)] diff --git a/ktls/src/server.rs b/ktls/src/server.rs index 507192e..dbf32da 100644 --- a/ktls/src/server.rs +++ b/ktls/src/server.rs @@ -1,15 +1,176 @@ use std::io; use std::os::fd::{AsRawFd, RawFd}; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; use rustls::kernel::KernelConnection; use rustls::server::{ServerConnectionData, UnbufferedServerConnection}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use rustls::unbuffered::{ConnectionState, EncodeError, UnbufferedStatus}; +use rustls::ServerConfig; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use crate::ffi::Direction; +use crate::protocol::read_record; use crate::stream::KTlsStreamImpl; -use crate::{ConnectError, CryptoInfo, TryConnectError}; +use crate::{CompatibleCiphers, ConnectError, CryptoInfo, TryConnectError}; + +/// A wrapper around [`rustls::ServerConfig`] which provides an async `accept` +/// method using kTLS. +/// +/// # Cipher Support +/// kTLS only has supposrt for a limited set of TLS ciphers. These can differ +/// based on the current kernel version and whether support for kTLS was +/// compiled in to the running kernel. If cipher negotiation selects a cipher +/// which is not supported by the current kernel, then you will get an error +/// when accepting the connection. +pub struct KTlsAcceptor { + config: Arc, +} + +impl KTlsAcceptor { + pub fn new(config: Arc) -> Self { + Self { config } + } + + pub async fn accept(&self, socket: IO) -> Result, ConnectError> + where + IO: AsyncWrite + AsyncRead + AsRawFd + Unpin, + { + Ok(self.try_accept(socket).await?) + } + + pub async fn try_accept( + &self, + mut socket: IO, + ) -> Result, TryConnectError> + where + IO: AsyncWrite + AsyncRead + AsRawFd + Unpin, + { + let mut conn = match UnbufferedServerConnection::new(self.config.clone()) { + Ok(conn) => conn, + Err(e) => { + return Err(TryConnectError { + error: ConnectError::Config(e), + socket: Some(socket), + conn: None, + }) + } + }; + + // We attempt to set up the TLS ULP before doing anything else so that + // we can indicate that the kernel doesn't support kTLS before returning + // any other error. + // + // This is also needed to prevent errors in one specific case: if we set + // up the ULP after the handshake has completed then a peer connected on + // localhost can immediately send its data and close the connection + // before we can call setsockopt. In this case we would get an error, even + // though no error has actually occurred. + if let Err(e) = crate::ffi::setup_ulp(socket.as_raw_fd()) { + let error = if e.raw_os_error() == Some(libc::ENOENT) { + ConnectError::KTlsUnsupported + } else { + ConnectError::IO(e) + }; + + return Err(TryConnectError { + error, + socket: Some(socket), + conn: Some(conn), + }); + } + + let mut incoming = Vec::with_capacity(1024); + let mut outgoing = Vec::with_capacity(1024); + let mut outgoing_used = 0usize; + let mut early = Vec::new(); + + loop { + let UnbufferedStatus { mut discard, state } = conn.process_tls_records(&mut incoming); + let state = match state { + Ok(state) => state, + Err(e) => return Err(ConnectError::Handshake(e).into()), + }; + + match state { + ConnectionState::BlockedHandshake => { + read_record(&mut socket, &mut incoming) + .await + .map_err(ConnectError::IO)?; + } + ConnectionState::PeerClosed | ConnectionState::Closed => { + return Err(TryConnectError { + error: ConnectError::ConnectionClosedBeforeHandshakeCompleted, + socket: Some(socket), + conn: None, + }) + } + ConnectionState::ReadEarlyData(mut data) => { + while let Some(record) = data.next_record() { + let record = record.map_err(ConnectError::Handshake)?; + discard += record.discard; + early.extend_from_slice(record.payload); + } + } + ConnectionState::EncodeTlsData(mut data) => { + match data.encode(&mut outgoing[outgoing_used..]) { + Ok(count) => outgoing_used += count, + Err(EncodeError::AlreadyEncoded) => unreachable!(), + Err(EncodeError::InsufficientSize(e)) => { + outgoing.resize(outgoing_used + e.required_size, 0u8); + + match data.encode(&mut outgoing[outgoing_used..]) { + Ok(count) => outgoing_used += count, + Err(e) => unreachable!("encode failed after resizing buffer: {e}"), + } + } + } + } + ConnectionState::TransmitTlsData(data) => { + socket + .write_all(&outgoing[..outgoing_used]) + .await + .map_err(ConnectError::IO)?; + outgoing_used = 0; + data.done(); + } + ConnectionState::WriteTraffic(_) => { + incoming.drain(..discard); + break; + } + ConnectionState::ReadTraffic(_) => unreachable!( + "ReadTraffic should not be encountered during the handshake process" + ), + _ => unreachable!("unexpected connection state"), + } + + incoming.drain(..discard); + } + + // We validate ciphers here as a convenience to produce better errors. + // We explicitly don't want to fail to create a kTLS cipher if probing + // fails, since the probe failing doesn't necessarily mean that creating + // this connection will fail. + if let Ok(support) = CompatibleCiphers::new().await { + let suite = conn.negotiated_cipher_suite().ok_or_else(|| { + ConnectError::Handshake(rustls::Error::General( + "handshake completed but no negotiated cipher suite is present".into(), + )) + })?; + + if !support.is_compatible(suite) { + return Err(TryConnectError { + error: ConnectError::UnsupportedCipherSuite(suite), + socket: Some(socket), + conn: Some(conn), + }); + } + } + + KTlsServerStream::from_unbuffered_connnection_validate(socket, early, conn).await + } +} pin_project_lite::pin_project! { /// The server half of a kTLS stream. @@ -23,7 +184,7 @@ impl KTlsServerStream where IO: AsyncWrite + AsyncRead + AsRawFd, { - pub fn from_unbuffered_connnection( + pub async fn from_unbuffered_connnection( socket: IO, conn: UnbufferedServerConnection, ) -> Result> { @@ -44,16 +205,42 @@ where }); } + Self::from_unbuffered_connection_validate(socket, Vec::new(), conn).await + } + + async fn from_unbuffered_connection_validate( + socket: IO, + early_data: Vec, + conn: UnbufferedServerConnection, + ) -> Result> { + // We validate ciphers here as a convenience to produce better errors. + // We explicitly don't want to fail to create a kTLS cipher if probing + // fails, since the probe failing doesn't necessarily mean that creating + // this connection will fail. + if let Ok(support) = CompatibleCiphers::new().await { + let suite = conn.negotiated_cipher_suite().ok_or_else(|| { + ConnectError::Handshake(rustls::Error::General( + "handshake completed but no negotiated cipher suite is present".into(), + )) + })?; + + if !support.is_compatible(suite) { + return Err(TryConnectError { + error: ConnectError::UnsupportedCipherSuite(suite), + socket: Some(socket), + conn: Some(conn), + }); + } + } + Ok(Self::from_unbuffered_connnection_with_tls_ulp_enabled( - socket, - Vec::new(), - conn, + socket, early_data, conn, )?) } /// Create a new `KTlsServerStream` from a socket that already has had the TLS ULP /// enabled on it. - pub fn from_unbuffered_connnection_with_tls_ulp_enabled( + fn from_unbuffered_connnection_with_tls_ulp_enabled( socket: IO, early_data: Vec, conn: UnbufferedServerConnection, diff --git a/ktls/src/suite.rs b/ktls/src/suite.rs new file mode 100644 index 0000000..f5cf6a0 --- /dev/null +++ b/ktls/src/suite.rs @@ -0,0 +1,242 @@ +use std::io; +use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::os::fd::AsRawFd; +use std::sync::atomic::AtomicU32; +use std::sync::OnceLock; + +use ktls_sys::bindings as sys; +use rustls::CipherSuite; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::OnceCell; + +use crate::ffi::{CryptoInfo, Direction}; +use crate::{KtlsCipherSuite, KtlsCipherType, KtlsVersion}; + +#[derive(Debug, Default)] +pub struct CompatibleCiphers { + pub tls12: CompatibleCiphersForVersion, + pub tls13: CompatibleCiphersForVersion, +} + +#[derive(Debug, Default)] +pub struct CompatibleCiphersForVersion { + pub aes_gcm_128: bool, + pub aes_gcm_256: bool, + pub chacha20_poly1305: bool, +} + +static COMPATIBLE: OnceCell = OnceCell::new(); + +impl CompatibleCiphers { + /// List compatible ciphers. This listens on a TCP socket and blocks for a + /// little while. Do once at the very start of a program. Should probably be + /// behind a lazy_static / once_cell + pub async fn new() -> Result { + COMPATIBLE.get_or_try_init(Self::probe()).await + } + + /// Returns true if we're reasonably confident that functions like + /// [config_ktls_client] and [config_ktls_server] will succeed. + pub fn is_compatible(&self, suite: SupportedCipherSuite) -> bool { + let kcs = match KtlsCipherSuite::try_from(suite) { + Ok(kcs) => kcs, + Err(_) => return false, + }; + + let fields = match kcs.version { + KtlsVersion::TLS12 => &self.tls12, + KtlsVersion::TLS13 => &self.tls13, + }; + + match kcs.typ { + KtlsCipherType::AesGcm128 => fields.aes_gcm_128, + KtlsCipherType::AesGcm256 => fields.aes_gcm_256, + KtlsCipherType::Chacha20Poly1305 => fields.chacha20_poly1305, + } + } + + async fn probe() -> Result { + let mut listener = + TcpListener::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))) + .await + .map_err(CipherProbeError::Listener)?; + let local_addr = listener.local_addr().map_err(CipherProbeError::Listener)?; + + Ok(Self { + tls12: CompatibleCiphersForVersion { + aes_gcm_128: Self::probe_suite( + &mut listener, + local_addr, + KtlsVersion::TLS12, + KtlsCipherType::AesGcm128, + ) + .await?, + aes_gcm_256: Self::probe_suite( + &mut listener, + local_addr, + KtlsVersion::TLS12, + KtlsCipherType::AesGcm256, + ) + .await?, + chacha20_poly1305: Self::probe_suite( + &mut listener, + local_addr, + KtlsVersion::TLS12, + KtlsCipherType::Chacha20Poly1305, + ) + .await?, + }, + tls13: CompatibleCiphersForVersion { + aes_gcm_128: Self::probe_suite( + &mut listener, + local_addr, + KtlsVersion::TLS13, + KtlsCipherType::AesGcm128, + ) + .await?, + aes_gcm_256: Self::probe_suite( + &mut listener, + local_addr, + KtlsVersion::TLS13, + KtlsCipherType::AesGcm256, + ) + .await?, + chacha20_poly1305: Self::probe_suite( + &mut listener, + local_addr, + KtlsVersion::TLS13, + KtlsCipherType::Chacha20Poly1305, + ) + .await?, + }, + }) + } + + async fn probe_suite( + listener: &mut TcpListener, + local_addr: SocketAddr, + version: KtlsVersion, + suite: KtlsCipherType, + ) -> Result { + let stream = TcpStream::connect(local_addr) + .await + .map_err(CipherProbeError::Connect)?; + let _other = listener + .accept() + .await + .map_err(CipherProbeError::Listener)?; + + let version = match version { + KtlsVersion::TLS12 => crate::ffi::TLS_1_2_VERSION_NUMBER, + KtlsVersion::TLS13 => crate::ffi::TLS_1_3_VERSION_NUMBER, + }; + + let crypto_info = match kcs.typ { + KtlsCipherType::AesGcm128 => { + CryptoInfo::AesGcm128(sys::tls12_crypto_info_aes_gcm_128 { + info: sys::tls_crypto_info { + version: ffi_version, + cipher_type: sys::TLS_CIPHER_AES_GCM_128 as _, + }, + ..Default::default() + }) + } + KtlsCipherType::AesGcm256 => { + CryptoInfo::AesGcm256(sys::tls12_crypto_info_aes_gcm_256 { + info: sys::tls_crypto_info { + version: ffi_version, + cipher_type: sys::TLS_CIPHER_AES_GCM_256 as _, + }, + ..Default::default() + }) + } + KtlsCipherType::Chacha20Poly1305 => { + CryptoInfo::Chacha20Poly1305(sys::tls12_crypto_info_chacha20_poly1305 { + info: sys::tls_crypto_info { + version: ffi_version, + cipher_type: sys::TLS_CIPHER_CHACHA20_POLY1305 as _, + }, + ..Default::default() + }) + } + }; + + let fd = stream.as_raw_fd(); + + match crate::ffi::setup_ulp(fd) { + Ok(()) => (), + // Interpret the kernel not supporting kTLS as the suite not being supported. + Err(e) if e.raw_os_error() == Some(libc::ENOENT) => return Ok(false), + Err(e) => return Err(CipherProbeError::Ulp(e)), + } + + Ok(crate::ffi::setup_tls_info(fd, Direction::Tx, crypto_info).is_ok()) + } +} + +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum CipherProbeError { + #[error("failed to listen on a local socket: {0}")] + Listener(#[source] io::Error), + + #[error("failed to connect to the local socket: {0}")] + Connect(#[source] io::Error), + + #[error("failed to set up the TLS upper-level-protocol on the socket: {0}")] + Ulp(#[source] io::Error), +} + +fn sample_cipher_setup(sock: &TcpStream, cipher_suite: SupportedCipherSuite) -> Result<(), Error> { + let kcs = match KtlsCipherSuite::try_from(cipher_suite) { + Ok(kcs) => kcs, + Err(_) => panic!("unsupported cipher suite"), + }; + + let ffi_version = match kcs.version { + KtlsVersion::TLS12 => ffi::TLS_1_2_VERSION_NUMBER, + KtlsVersion::TLS13 => ffi::TLS_1_3_VERSION_NUMBER, + }; + + let crypto_info = match kcs.typ { + KtlsCipherType::AesGcm128 => CryptoInfo::AesGcm128(sys::tls12_crypto_info_aes_gcm_128 { + info: sys::tls_crypto_info { + version: ffi_version, + cipher_type: sys::TLS_CIPHER_AES_GCM_128 as _, + }, + iv: Default::default(), + key: Default::default(), + salt: Default::default(), + rec_seq: Default::default(), + }), + KtlsCipherType::AesGcm256 => CryptoInfo::AesGcm256(sys::tls12_crypto_info_aes_gcm_256 { + info: sys::tls_crypto_info { + version: ffi_version, + cipher_type: sys::TLS_CIPHER_AES_GCM_256 as _, + }, + iv: Default::default(), + key: Default::default(), + salt: Default::default(), + rec_seq: Default::default(), + }), + KtlsCipherType::Chacha20Poly1305 => { + CryptoInfo::Chacha20Poly1305(sys::tls12_crypto_info_chacha20_poly1305 { + info: sys::tls_crypto_info { + version: ffi_version, + cipher_type: sys::TLS_CIPHER_CHACHA20_POLY1305 as _, + }, + iv: Default::default(), + key: Default::default(), + salt: Default::default(), + rec_seq: Default::default(), + }) + } + }; + let fd = sock.as_raw_fd(); + + setup_ulp(fd).map_err(Error::UlpError)?; + + setup_tls_info(fd, ffi::Direction::Tx, crypto_info).map_err(Error::TlsCryptoInfoError)?; + + Ok(()) +} From fc0baafd17c81b03bc0b2add09376e94c5e185ea Mon Sep 17 00:00:00 2001 From: Sean Lynch <42618346+swlynch99@users.noreply.github.com> Date: Fri, 6 Jun 2025 17:57:59 -0700 Subject: [PATCH 8/8] Fix compilation errors and remove some code --- ktls/Cargo.toml | 2 +- ktls/src/client.rs | 4 +- ktls/src/error.rs | 2 - ktls/src/lib.rs | 326 +-------------------------------------------- ktls/src/server.rs | 6 +- ktls/src/suite.rs | 94 ++++--------- 6 files changed, 33 insertions(+), 401 deletions(-) diff --git a/ktls/Cargo.toml b/ktls/Cargo.toml index f20f817..97e0e30 100644 --- a/ktls/Cargo.toml +++ b/ktls/Cargo.toml @@ -21,7 +21,7 @@ rustls = { version = "0.23.27", default-features = false } smallvec = "1.13.2" memoffset = "0.9.1" pin-project-lite = "0.2.14" -tokio = { version = "1.39.2", features = ["net", "macros", "io-util"] } +tokio = { version = "1.39.2", features = ["net", "macros", "io-util", "sync"] } ktls-sys = "1.0.1" num_enum = "0.7.3" futures-util = "0.3.30" diff --git a/ktls/src/client.rs b/ktls/src/client.rs index 705c62f..36ba090 100644 --- a/ktls/src/client.rs +++ b/ktls/src/client.rs @@ -66,9 +66,9 @@ where let suite = kconn.negotiated_cipher_suite(); let tx = CryptoInfo::from_rustls(suite, secrets.tx) - .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + .map_err(|_| ConnectError::UnsupportedCipherSuite(suite))?; let rx = CryptoInfo::from_rustls(suite, secrets.rx) - .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + .map_err(|_| ConnectError::UnsupportedCipherSuite(suite))?; crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Tx, tx) .map_err(ConnectError::IO)?; diff --git a/ktls/src/error.rs b/ktls/src/error.rs index d12209d..6792e9e 100644 --- a/ktls/src/error.rs +++ b/ktls/src/error.rs @@ -2,8 +2,6 @@ use std::{fmt, io}; use rustls::SupportedCipherSuite; -use crate::suite::CipherProbeError; - #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum ConnectError { diff --git a/ktls/src/lib.rs b/ktls/src/lib.rs index 207ab2c..3634636 100644 --- a/ktls/src/lib.rs +++ b/ktls/src/lib.rs @@ -1,7 +1,4 @@ -use ffi::{setup_tls_info, setup_ulp, KtlsCompatibilityError}; -use futures_util::future::try_join_all; -use ktls_sys::bindings as sys; -use rustls::{Connection, SupportedCipherSuite, SupportedProtocolVersion}; +use rustls::{SupportedCipherSuite, SupportedProtocolVersion}; #[cfg(all(not(feature = "ring"), not(feature = "aws_lc_rs")))] compile_error!("This crate needs wither the 'ring' or 'aws_lc_rs' feature enabled"); @@ -12,18 +9,6 @@ use rustls::crypto::aws_lc_rs::cipher_suite; #[cfg(feature = "ring")] use rustls::crypto::ring::cipher_suite; -use smallvec::SmallVec; -use std::{ - future::Future, - io, - net::SocketAddr, - os::unix::prelude::{AsRawFd, RawFd}, -}; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite}, - net::{TcpListener, TcpStream}, -}; - mod ffi; pub use crate::ffi::CryptoInfo; @@ -49,314 +34,7 @@ pub use crate::error::{ConnectError, TryConnectError}; pub use crate::generic::KTlsStream; pub use crate::server::{KTlsAcceptor, KTlsServerStream}; pub use crate::stream::{KTlsStreamError, Side}; - -#[derive(Debug, Default)] -pub struct CompatibleCiphers { - pub tls12: CompatibleCiphersForVersion, - pub tls13: CompatibleCiphersForVersion, -} - -#[derive(Debug, Default)] -pub struct CompatibleCiphersForVersion { - pub aes_gcm_128: bool, - pub aes_gcm_256: bool, - pub chacha20_poly1305: bool, -} - -impl CompatibleCiphers { - /// List compatible ciphers. This listens on a TCP socket and blocks for a - /// little while. Do once at the very start of a program. Should probably be - /// behind a lazy_static / once_cell - pub async fn new() -> io::Result { - let mut ciphers = CompatibleCiphers::default(); - - let ln = TcpListener::bind("0.0.0.0:0").await?; - let local_addr = ln.local_addr()?; - - // Accepted conns of ln - let mut accepted_conns: SmallVec<[TcpStream; 12]> = SmallVec::new(); - - let accept_conns_fut = async { - loop { - if let Ok((conn, _addr)) = ln.accept().await { - accepted_conns.push(conn); - } - } - }; - - ciphers.test_ciphers(local_addr, accept_conns_fut).await?; - - Ok(ciphers) - } - - async fn test_ciphers( - &mut self, - local_addr: SocketAddr, - accept_conns_fut: impl Future, - ) -> io::Result<()> { - let ciphers: Vec<(SupportedCipherSuite, &mut bool)> = vec![ - ( - cipher_suite::TLS13_AES_128_GCM_SHA256, - &mut self.tls13.aes_gcm_128, - ), - ( - cipher_suite::TLS13_AES_256_GCM_SHA384, - &mut self.tls13.aes_gcm_256, - ), - ( - cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, - &mut self.tls13.chacha20_poly1305, - ), - ( - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - &mut self.tls12.aes_gcm_128, - ), - ( - cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - &mut self.tls12.aes_gcm_256, - ), - ( - cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - &mut self.tls12.chacha20_poly1305, - ), - ]; - - let create_connections_fut = - try_join_all((0..ciphers.len()).map(|_| TcpStream::connect(local_addr))); - - let socks = tokio::select! { - // Use biased here to optimize performance. - // - // With biased, tokio::select! would first poll create_connections_fut, - // which would poll all `TcpStream::connect` futures and requests - // new connections to `ln` then returns `Poll::Pending`. - // - // Then accept_conns_fut would be polled, which accepts all pending - // connections, wake up create_connections_fut then returns - // `Poll::Pending`. - // - // Finally, create_connections_fut wakes up and all connections - // are ready, the result is collected into a Vec and ends - // the tokio::select!. - biased; - - res = create_connections_fut => res?, - _ = accept_conns_fut => unreachable!(), - }; - - assert_eq!(ciphers.len(), socks.len()); - - ciphers - .into_iter() - .zip(socks) - .for_each(|((cipher_suite, field), sock)| { - *field = sample_cipher_setup(&sock, cipher_suite).is_ok(); - }); - - Ok(()) - } - - /// Returns true if we're reasonably confident that functions like - /// [config_ktls_client] and [config_ktls_server] will succeed. - pub fn is_compatible(&self, suite: SupportedCipherSuite) -> bool { - let kcs = match KtlsCipherSuite::try_from(suite) { - Ok(kcs) => kcs, - Err(_) => return false, - }; - - let fields = match kcs.version { - KtlsVersion::TLS12 => &self.tls12, - KtlsVersion::TLS13 => &self.tls13, - }; - - match kcs.typ { - KtlsCipherType::AesGcm128 => fields.aes_gcm_128, - KtlsCipherType::AesGcm256 => fields.aes_gcm_256, - KtlsCipherType::Chacha20Poly1305 => fields.chacha20_poly1305, - } - } -} - -fn sample_cipher_setup(sock: &TcpStream, cipher_suite: SupportedCipherSuite) -> Result<(), Error> { - let kcs = match KtlsCipherSuite::try_from(cipher_suite) { - Ok(kcs) => kcs, - Err(_) => panic!("unsupported cipher suite"), - }; - - let ffi_version = match kcs.version { - KtlsVersion::TLS12 => ffi::TLS_1_2_VERSION_NUMBER, - KtlsVersion::TLS13 => ffi::TLS_1_3_VERSION_NUMBER, - }; - - let crypto_info = match kcs.typ { - KtlsCipherType::AesGcm128 => CryptoInfo::AesGcm128(sys::tls12_crypto_info_aes_gcm_128 { - info: sys::tls_crypto_info { - version: ffi_version, - cipher_type: sys::TLS_CIPHER_AES_GCM_128 as _, - }, - iv: Default::default(), - key: Default::default(), - salt: Default::default(), - rec_seq: Default::default(), - }), - KtlsCipherType::AesGcm256 => CryptoInfo::AesGcm256(sys::tls12_crypto_info_aes_gcm_256 { - info: sys::tls_crypto_info { - version: ffi_version, - cipher_type: sys::TLS_CIPHER_AES_GCM_256 as _, - }, - iv: Default::default(), - key: Default::default(), - salt: Default::default(), - rec_seq: Default::default(), - }), - KtlsCipherType::Chacha20Poly1305 => { - CryptoInfo::Chacha20Poly1305(sys::tls12_crypto_info_chacha20_poly1305 { - info: sys::tls_crypto_info { - version: ffi_version, - cipher_type: sys::TLS_CIPHER_CHACHA20_POLY1305 as _, - }, - iv: Default::default(), - key: Default::default(), - salt: Default::default(), - rec_seq: Default::default(), - }) - } - }; - let fd = sock.as_raw_fd(); - - setup_ulp(fd).map_err(Error::UlpError)?; - - setup_tls_info(fd, ffi::Direction::Tx, crypto_info).map_err(Error::TlsCryptoInfoError)?; - - Ok(()) -} - -#[derive(thiserror::Error, Debug)] -pub enum Error { - #[error("failed to enable TLS ULP (upper level protocol): {0}")] - UlpError(#[source] std::io::Error), - - #[error("kTLS compatibility error: {0}")] - KtlsCompatibility(#[from] KtlsCompatibilityError), - - #[error("failed to export secrets")] - ExportSecrets(#[source] rustls::Error), - - #[error("failed to configure tx/rx (unsupported cipher?): {0}")] - TlsCryptoInfoError(#[source] std::io::Error), - - #[error("an I/O occured while draining the rustls stream: {0}")] - DrainError(#[source] std::io::Error), - - #[error("no negotiated cipher suite: call config_ktls_* only /after/ the handshake")] - NoNegotiatedCipherSuite, -} - -/// Configure kTLS for this socket. If this call succeeds, data can be written -/// and read from this socket, and the kernel takes care of encryption -/// transparently. I'm not clear how rekeying is handled (probably via control -/// messages, but can't find a code sample for it). -/// -/// The inner IO type must be wrapped in [CorkStream] since it's the only way -/// to drain a rustls stream cleanly. See its documentation for details. -pub async fn config_ktls_server( - mut stream: tokio_rustls::server::TlsStream>, -) -> Result, Error> -where - IO: AsRawFd + AsyncRead + AsyncReadReady + AsyncWrite + Unpin, -{ - stream.get_mut().0.corked = true; - let drained = drain(&mut stream).await.map_err(Error::DrainError)?; - let (io, conn) = stream.into_inner(); - let io = io.io; - - setup_inner(io.as_raw_fd(), Connection::Server(conn))?; - Ok(KtlsStream::new(io, drained)) -} - -/// Configure kTLS for this socket. If this call succeeds, data can be -/// written and read from this socket, and the kernel takes care of encryption -/// (and key updates, etc.) transparently. -/// -/// The inner IO type must be wrapped in [CorkStream] since it's the only way -/// to drain a rustls stream cleanly. See its documentation for details. -pub async fn config_ktls_client( - mut stream: tokio_rustls::client::TlsStream>, -) -> Result, Error> -where - IO: AsRawFd + AsyncRead + AsyncWrite + Unpin, -{ - stream.get_mut().0.corked = true; - let drained = drain(&mut stream).await.map_err(Error::DrainError)?; - let (io, conn) = stream.into_inner(); - let io = io.io; - - setup_inner(io.as_raw_fd(), Connection::Client(conn))?; - Ok(KtlsStream::new(io, drained)) -} - -/// Read all the bytes we can read without blocking. This is used to drained the -/// already-decrypted buffer from a tokio-rustls I/O type -async fn drain(stream: &mut (impl AsyncRead + Unpin)) -> std::io::Result>> { - tracing::trace!("Draining rustls stream"); - let mut drained = vec![0u8; 128 * 1024]; - let mut filled = 0; - - loop { - tracing::trace!("stream.read called"); - let n = match stream.read(&mut drained[filled..]).await { - Ok(n) => n, - Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - // actually this is expected for us! - tracing::trace!("stream.read returned UnexpectedEof, that's expected for us"); - break; - } - Err(e) => { - tracing::trace!("stream.read returned error: {e}"); - return Err(e); - } - }; - tracing::trace!("stream.read returned {n}"); - if n == 0 { - // that's what CorkStream returns when it's at a message boundary - break; - } - filled += n; - } - - let maybe_drained = if filled == 0 { - None - } else { - tracing::trace!("Draining rustls stream done: drained {filled} bytes"); - drained.resize(filled, 0); - Some(drained) - }; - Ok(maybe_drained) -} - -fn setup_inner(fd: RawFd, conn: Connection) -> Result<(), Error> { - let cipher_suite = match conn.negotiated_cipher_suite() { - Some(cipher_suite) => cipher_suite, - None => { - return Err(Error::NoNegotiatedCipherSuite); - } - }; - - let secrets = match conn.dangerous_extract_secrets() { - Ok(secrets) => secrets, - Err(err) => return Err(Error::ExportSecrets(err)), - }; - - ffi::setup_ulp(fd).map_err(Error::UlpError)?; - - let tx = CryptoInfo::from_rustls(cipher_suite, secrets.tx)?; - setup_tls_info(fd, ffi::Direction::Tx, tx).map_err(Error::TlsCryptoInfoError)?; - - let rx = CryptoInfo::from_rustls(cipher_suite, secrets.rx)?; - setup_tls_info(fd, ffi::Direction::Rx, rx).map_err(Error::TlsCryptoInfoError)?; - - Ok(()) -} +pub use crate::suite::{CipherProbeError, CompatibleCiphers, CompatibleCiphersForVersion}; /// TLS versions supported by this crate #[non_exhaustive] diff --git a/ktls/src/server.rs b/ktls/src/server.rs index dbf32da..963f67f 100644 --- a/ktls/src/server.rs +++ b/ktls/src/server.rs @@ -168,7 +168,7 @@ impl KTlsAcceptor { } } - KTlsServerStream::from_unbuffered_connnection_validate(socket, early, conn).await + KTlsServerStream::from_unbuffered_connection_validate(socket, early, conn).await } } @@ -252,9 +252,9 @@ where let suite = kconn.negotiated_cipher_suite(); let tx = CryptoInfo::from_rustls(suite, secrets.tx) - .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + .map_err(|_| ConnectError::UnsupportedCipherSuite(suite))?; let rx = CryptoInfo::from_rustls(suite, secrets.rx) - .map_err(|_| ConnectError::UnsupportedCipherSuite)?; + .map_err(|_| ConnectError::UnsupportedCipherSuite(suite))?; crate::ffi::setup_tls_info(socket.as_raw_fd(), Direction::Tx, tx) .map_err(ConnectError::IO)?; diff --git a/ktls/src/suite.rs b/ktls/src/suite.rs index f5cf6a0..5d4bfeb 100644 --- a/ktls/src/suite.rs +++ b/ktls/src/suite.rs @@ -1,38 +1,39 @@ use std::io; -use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::os::fd::AsRawFd; -use std::sync::atomic::AtomicU32; -use std::sync::OnceLock; use ktls_sys::bindings as sys; -use rustls::CipherSuite; +use rustls::SupportedCipherSuite; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::OnceCell; use crate::ffi::{CryptoInfo, Direction}; use crate::{KtlsCipherSuite, KtlsCipherType, KtlsVersion}; -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub struct CompatibleCiphers { pub tls12: CompatibleCiphersForVersion, pub tls13: CompatibleCiphersForVersion, } -#[derive(Debug, Default)] +#[derive(Clone, Debug, Default)] pub struct CompatibleCiphersForVersion { pub aes_gcm_128: bool, pub aes_gcm_256: bool, pub chacha20_poly1305: bool, } -static COMPATIBLE: OnceCell = OnceCell::new(); +static COMPATIBLE: OnceCell = OnceCell::const_new(); impl CompatibleCiphers { /// List compatible ciphers. This listens on a TCP socket and blocks for a /// little while. Do once at the very start of a program. Should probably be /// behind a lazy_static / once_cell pub async fn new() -> Result { - COMPATIBLE.get_or_try_init(Self::probe()).await + COMPATIBLE + .get_or_try_init(Self::probe) + .await + .map(Self::clone) } /// Returns true if we're reasonably confident that functions like @@ -131,32 +132,41 @@ impl CompatibleCiphers { KtlsVersion::TLS13 => crate::ffi::TLS_1_3_VERSION_NUMBER, }; - let crypto_info = match kcs.typ { + let crypto_info = match suite { KtlsCipherType::AesGcm128 => { CryptoInfo::AesGcm128(sys::tls12_crypto_info_aes_gcm_128 { info: sys::tls_crypto_info { - version: ffi_version, + version, cipher_type: sys::TLS_CIPHER_AES_GCM_128 as _, }, - ..Default::default() + iv: Default::default(), + key: Default::default(), + salt: Default::default(), + rec_seq: Default::default(), }) } KtlsCipherType::AesGcm256 => { CryptoInfo::AesGcm256(sys::tls12_crypto_info_aes_gcm_256 { info: sys::tls_crypto_info { - version: ffi_version, + version, cipher_type: sys::TLS_CIPHER_AES_GCM_256 as _, }, - ..Default::default() + iv: Default::default(), + key: Default::default(), + salt: Default::default(), + rec_seq: Default::default(), }) } KtlsCipherType::Chacha20Poly1305 => { CryptoInfo::Chacha20Poly1305(sys::tls12_crypto_info_chacha20_poly1305 { info: sys::tls_crypto_info { - version: ffi_version, + version, cipher_type: sys::TLS_CIPHER_CHACHA20_POLY1305 as _, }, - ..Default::default() + iv: Default::default(), + key: Default::default(), + salt: Default::default(), + rec_seq: Default::default(), }) } }; @@ -186,57 +196,3 @@ pub enum CipherProbeError { #[error("failed to set up the TLS upper-level-protocol on the socket: {0}")] Ulp(#[source] io::Error), } - -fn sample_cipher_setup(sock: &TcpStream, cipher_suite: SupportedCipherSuite) -> Result<(), Error> { - let kcs = match KtlsCipherSuite::try_from(cipher_suite) { - Ok(kcs) => kcs, - Err(_) => panic!("unsupported cipher suite"), - }; - - let ffi_version = match kcs.version { - KtlsVersion::TLS12 => ffi::TLS_1_2_VERSION_NUMBER, - KtlsVersion::TLS13 => ffi::TLS_1_3_VERSION_NUMBER, - }; - - let crypto_info = match kcs.typ { - KtlsCipherType::AesGcm128 => CryptoInfo::AesGcm128(sys::tls12_crypto_info_aes_gcm_128 { - info: sys::tls_crypto_info { - version: ffi_version, - cipher_type: sys::TLS_CIPHER_AES_GCM_128 as _, - }, - iv: Default::default(), - key: Default::default(), - salt: Default::default(), - rec_seq: Default::default(), - }), - KtlsCipherType::AesGcm256 => CryptoInfo::AesGcm256(sys::tls12_crypto_info_aes_gcm_256 { - info: sys::tls_crypto_info { - version: ffi_version, - cipher_type: sys::TLS_CIPHER_AES_GCM_256 as _, - }, - iv: Default::default(), - key: Default::default(), - salt: Default::default(), - rec_seq: Default::default(), - }), - KtlsCipherType::Chacha20Poly1305 => { - CryptoInfo::Chacha20Poly1305(sys::tls12_crypto_info_chacha20_poly1305 { - info: sys::tls_crypto_info { - version: ffi_version, - cipher_type: sys::TLS_CIPHER_CHACHA20_POLY1305 as _, - }, - iv: Default::default(), - key: Default::default(), - salt: Default::default(), - rec_seq: Default::default(), - }) - } - }; - let fd = sock.as_raw_fd(); - - setup_ulp(fd).map_err(Error::UlpError)?; - - setup_tls_info(fd, ffi::Direction::Tx, crypto_info).map_err(Error::TlsCryptoInfoError)?; - - Ok(()) -}