diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..d0a9774 --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,12 @@ +{ + "permissions": { + "allow": [ + "Bash(cargo test)", + "Bash(cargo test:*)", + "Bash(cargo clippy:*)", + "Bash(cargo fix:*)" + ], + "deny": [], + "ask": [] + } +} \ No newline at end of file diff --git a/src/cursor/mod.rs b/src/cursor/mod.rs index 2c58de9..d9719cf 100644 --- a/src/cursor/mod.rs +++ b/src/cursor/mod.rs @@ -9,10 +9,10 @@ mod tests; mod tokio_imp; use crate::{BufList, errors::ReadExactError}; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use std::{ cmp::Ordering, - io::{self, IoSliceMut, SeekFrom}, + io::{self, IoSlice, IoSliceMut, SeekFrom}, }; /// A `Cursor` wraps an in-memory `BufList` and provides it with a [`Seek`] implementation. @@ -195,6 +195,58 @@ impl> io::BufRead for Cursor { } } +impl> Buf for Cursor { + fn remaining(&self) -> usize { + let total = self.data.num_bytes(self.inner.as_ref()); + total.saturating_sub(self.data.pos) as usize + } + + fn chunk(&self) -> &[u8] { + self.data.fill_buf_impl(self.inner.as_ref()) + } + + fn advance(&mut self, amt: usize) { + self.data.consume_impl(self.inner.as_ref(), amt); + } + + fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize { + if iovs.is_empty() { + return 0; + } + + let list = self.inner.as_ref(); + let mut filled = 0; + let mut current_chunk = self.data.chunk; + let mut current_pos = self.data.pos; + + // Iterate through chunks starting from the current position + while filled < iovs.len() && current_chunk < list.num_chunks() { + if let Some(chunk) = list.get_chunk(current_chunk) { + let chunk_start_pos = list.get_start_pos()[current_chunk]; + let offset_in_chunk = (current_pos - chunk_start_pos) as usize; + + if offset_in_chunk < chunk.len() { + let chunk_slice = &chunk.as_ref()[offset_in_chunk..]; + iovs[filled] = IoSlice::new(chunk_slice); + filled += 1; + } + + current_chunk += 1; + // Move to the start of the next chunk + if let Some(&next_start_pos) = list.get_start_pos().get(current_chunk) { + current_pos = next_start_pos; + } else { + break; + } + } else { + break; + } + } + + filled + } +} + #[derive(Clone, Debug)] struct CursorData { /// The chunk number the cursor is pointing to. Kept in sync with pos. diff --git a/src/cursor/tests.rs b/src/cursor/tests.rs index 97c2096..2adf719 100644 --- a/src/cursor/tests.rs +++ b/src/cursor/tests.rs @@ -56,6 +56,15 @@ enum CursorOp { // fill_buf can't be tested here because oracle is a contiguous block. Instead, we check its // return value separately. Consume(prop::sample::Index), + // Buf trait operations + BufRemaining, + BufChunk, + BufAdvance(prop::sample::Index), + BufChunksVectored(prop::sample::Index), + BufCopyToBytes(prop::sample::Index), + BufGetU8, + BufGetU64, + BufGetU64Le, // No need to test futures03 imps since they're simple wrappers around the main imps. #[cfg(feature = "tokio1")] PollRead { @@ -183,6 +192,204 @@ impl CursorOp { buf_list.consume(amt); oracle.consume(amt); } + Self::BufRemaining => { + eprintln!("buf_remaining"); + + let buf_list_remaining = buf_list.remaining(); + let oracle_remaining = oracle.remaining(); + ensure!( + buf_list_remaining == oracle_remaining, + "remaining didn't match: buf_list {} == oracle {}", + buf_list_remaining, + oracle_remaining + ); + } + Self::BufChunk => { + eprintln!("buf_chunk"); + + let buf_list_chunk = buf_list.chunk(); + let oracle_chunk = oracle.chunk(); + + // We can't directly compare chunks because BufList returns one + // segment at a time while oracle returns the entire remaining + // buffer. Instead, verify that: + // + // 1. is_empty matches for both chunks. + // 2. Both start with the same data (buf_list's chunk is a prefix of oracle's) + ensure!( + buf_list_chunk.is_empty() == oracle_chunk.is_empty(), + "chunk emptiness didn't match: buf_list is_empty {} == oracle is_empty {}", + buf_list_chunk.is_empty(), + oracle_chunk.is_empty() + ); + + if !buf_list_chunk.is_empty() { + // Verify buf_list's chunk is a prefix of oracle's chunk + ensure!( + oracle_chunk.starts_with(buf_list_chunk), + "buf_list chunk is not a prefix of oracle chunk" + ); + } + } + Self::BufAdvance(index) => { + let amt = index.index(1 + num_bytes * 5 / 4); + eprintln!("buf_advance: {}", amt); + + // Skip if already past the end, as the oracle's Buf impl has a debug assertion + // that checks position even when advancing by 0 + if buf_list.remaining() > 0 || amt == 0 && oracle.remaining() > 0 { + // Cap the advance amount to the remaining bytes to avoid + // hitting the debug assertion in std::io::Cursor's Buf + // impl. While the Buf trait doesn't require this, the + // oracle has a debug_assert that panics if we try to + // advance past the end. + let amt = amt.min(buf_list.remaining()); + buf_list.advance(amt); + oracle.advance(amt); + } else { + eprintln!(" skipping: cursor past end"); + } + } + Self::BufChunksVectored(index) => { + let num_iovs = index.index(1 + num_bytes); + eprintln!("buf_chunks_vectored: {} iovs", num_iovs); + + // First verify remaining() matches + let buf_list_remaining = buf_list.remaining(); + let oracle_remaining = oracle.remaining(); + ensure!( + buf_list_remaining == oracle_remaining, + "chunks_vectored: remaining didn't match before \ + calling chunks_vectored: buf_list {} == oracle {}", + buf_list_remaining, + oracle_remaining + ); + + let mut buf_list_iovs = vec![io::IoSlice::new(&[]); num_iovs]; + let mut oracle_iovs = vec![io::IoSlice::new(&[]); num_iovs]; + + let buf_list_filled = buf_list.chunks_vectored(&mut buf_list_iovs); + let oracle_filled = oracle.chunks_vectored(&mut oracle_iovs); + + // We can't directly compare filled counts or total bytes + // because BufList may have multiple chunks while the oracle + // (std::io::Cursor) is contiguous. When there are fewer iovs + // than chunks, BufList will only fill what it can, while oracle + // fills everything into one iov. + // + // Instead, we verify that: + // 1. Both returned at least some data if there are bytes + // remaining + // 2. The data that was returned matches (buf_list's data is a + // prefix of oracle's data) + let buf_list_bytes: Vec = buf_list_iovs[..buf_list_filled] + .iter() + .flat_map(|iov| iov.as_ref().iter().copied()) + .collect(); + let oracle_bytes: Vec = oracle_iovs[..oracle_filled] + .iter() + .flat_map(|iov| iov.as_ref().iter().copied()) + .collect(); + + if buf_list_remaining > 0 && num_iovs > 0 { + // If there are bytes remaining and iovs available, should + // return some data. + ensure!( + !buf_list_bytes.is_empty(), + "chunks_vectored should return some data \ + when remaining > 0 and num_iovs > 0" + ); + ensure!( + !oracle_bytes.is_empty(), + "oracle chunks_vectored should return some data \ + when remaining > 0 and num_iovs > 0" + ); + + // Verify that buf_list's data matches the beginning of + // oracle's data. + ensure!( + oracle_bytes.starts_with(&buf_list_bytes), + "buf_list chunks_vectored data should match beginning \ + of oracle data" + ); + } else if buf_list_remaining == 0 { + // If no bytes remaining, should return no data + ensure!( + buf_list_bytes.is_empty() && oracle_bytes.is_empty(), + "chunks_vectored should return no data when \ + remaining == 0" + ); + } + // If num_iovs == 0, we can't check anything since no iovs were + // provided. All we're doing is ensuring that buf_list doesn't + // panic. + } + Self::BufCopyToBytes(index) => { + let len = index.index(1 + num_bytes * 5 / 4); + eprintln!("buf_copy_to_bytes: {}", len); + + // copy_to_bytes can panic if len > remaining, so check first + let buf_list_remaining = buf_list.remaining(); + let oracle_remaining = oracle.remaining(); + + if len <= buf_list_remaining && len <= oracle_remaining { + let buf_list_bytes = buf_list.copy_to_bytes(len); + let oracle_bytes = oracle.copy_to_bytes(len); + + ensure!(buf_list_bytes == oracle_bytes, "copy_to_bytes didn't match"); + } else { + // Both should panic, so just skip this operation + eprintln!(" skipping: len {} > remaining {}", len, buf_list_remaining); + } + } + Self::BufGetU8 => { + eprintln!("buf_get_u8"); + + if buf_list.remaining() >= 1 && oracle.remaining() >= 1 { + let buf_list_val = buf_list.get_u8(); + let oracle_val = oracle.get_u8(); + ensure!( + buf_list_val == oracle_val, + "get_u8 didn't match: buf_list {} == oracle {}", + buf_list_val, + oracle_val + ); + } else { + eprintln!(" skipping: not enough bytes remaining"); + } + } + Self::BufGetU64 => { + eprintln!("buf_get_u64"); + + if buf_list.remaining() >= 8 && oracle.remaining() >= 8 { + let buf_list_val = buf_list.get_u64(); + let oracle_val = oracle.get_u64(); + ensure!( + buf_list_val == oracle_val, + "get_u64 didn't match: buf_list {} == oracle {}", + buf_list_val, + oracle_val + ); + } else { + eprintln!(" skipping: not enough bytes remaining"); + } + } + Self::BufGetU64Le => { + eprintln!("buf_get_u64_le"); + + if buf_list.remaining() >= 8 && oracle.remaining() >= 8 { + let buf_list_val = buf_list.get_u64_le(); + let oracle_val = oracle.get_u64_le(); + ensure!( + buf_list_val == oracle_val, + "get_u64_le didn't match: buf_list {} == oracle {}", + buf_list_val, + oracle_val + ); + } else { + eprintln!(" skipping: not enough bytes remaining"); + } + } #[cfg(feature = "tokio1")] Self::PollRead { capacity, filled } => { use std::{mem::MaybeUninit, pin::Pin, task::Poll}; @@ -322,3 +529,65 @@ impl CursorOp { fn cursor_ops_strategy() -> impl Strategy> { prop::collection::vec(any::(), 0..256) } + +#[test] +fn test_cursor_buf_trait() { + // Create a BufList with multiple chunks + let mut buf_list = BufList::new(); + buf_list.push_chunk(&b"hello "[..]); + buf_list.push_chunk(&b"world"[..]); + buf_list.push_chunk(&b"!"[..]); + + let mut cursor = crate::Cursor::new(buf_list.clone()); + + // Test remaining() + assert_eq!(cursor.remaining(), 12); + + // Test chunk() + assert_eq!(cursor.chunk(), b"hello "); + + // Test advance() + cursor.advance(6); + assert_eq!(cursor.remaining(), 6); + assert_eq!(cursor.chunk(), b"world"); + + // Advance within the same chunk + cursor.advance(3); + assert_eq!(cursor.remaining(), 3); + assert_eq!(cursor.chunk(), b"ld"); + + // Advance to the next chunk + cursor.advance(2); + assert_eq!(cursor.remaining(), 1); + assert_eq!(cursor.chunk(), b"!"); + + // Advance to the end + cursor.advance(1); + assert_eq!(cursor.remaining(), 0); + assert_eq!(cursor.chunk(), b""); + + // Test chunks_vectored + let mut cursor = crate::Cursor::new(buf_list.clone()); + let mut iovs = [io::IoSlice::new(&[]); 3]; + let filled = cursor.chunks_vectored(&mut iovs); + assert_eq!(filled, 3); + assert_eq!(iovs[0].as_ref(), b"hello "); + assert_eq!(iovs[1].as_ref(), b"world"); + assert_eq!(iovs[2].as_ref(), b"!"); + + // Test chunks_vectored after advancing + cursor.advance(6); + let mut iovs = [io::IoSlice::new(&[]); 3]; + let filled = cursor.chunks_vectored(&mut iovs); + assert_eq!(filled, 2); + assert_eq!(iovs[0].as_ref(), b"world"); + assert_eq!(iovs[1].as_ref(), b"!"); + + // Test chunks_vectored with more iovs than remaining chunks + let cursor2 = crate::Cursor::new(&buf_list); + let mut iovs2 = [io::IoSlice::new(&[]); 10]; + let filled2 = cursor2.chunks_vectored(&mut iovs2); + assert_eq!(filled2, 3, "Should only fill 3 iovs for 3 chunks"); + let total_bytes: usize = iovs2[..filled2].iter().map(|iov| iov.len()).sum(); + assert_eq!(total_bytes, 12, "Total bytes should be 12"); +}