diff --git a/program-libs/batched-merkle-tree/src/batch.rs b/program-libs/batched-merkle-tree/src/batch.rs
index 093e0b7815..b2c5ee7e14 100644
--- a/program-libs/batched-merkle-tree/src/batch.rs
+++ b/program-libs/batched-merkle-tree/src/batch.rs
@@ -1,4 +1,4 @@
-use light_bloom_filter::BloomFilter;
+use light_bloom_filter::{BloomFilter, BloomFilterRef};
 use light_hasher::{Hasher, Poseidon};
 use light_zero_copy::vec::ZeroCopyVecU64;
 use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
@@ -396,6 +396,20 @@ impl Batch {
         Ok(())
     }
 
+    /// Immutable version of `check_non_inclusion` using `BloomFilterRef`.
+    pub fn check_non_inclusion_ref(
+        num_iters: usize,
+        bloom_filter_capacity: u64,
+        value: &[u8; 32],
+        store: &[u8],
+    ) -> Result<(), BatchedMerkleTreeError> {
+        let bloom_filter = BloomFilterRef::new(num_iters, bloom_filter_capacity, store)?;
+        if bloom_filter.contains(value) {
+            return Err(BatchedMerkleTreeError::NonInclusionCheckFailed);
+        }
+        Ok(())
+    }
+
     /// Marks the batch as inserted in the merkle tree.
     /// 1. Checks that the batch is ready.
     /// 2. increments the number of inserted zkps.
diff --git a/program-libs/batched-merkle-tree/src/lib.rs b/program-libs/batched-merkle-tree/src/lib.rs
index 50e4810920..f9558cdeca 100644
--- a/program-libs/batched-merkle-tree/src/lib.rs
+++ b/program-libs/batched-merkle-tree/src/lib.rs
@@ -178,6 +178,8 @@ pub mod initialize_state_tree;
 pub mod merkle_tree;
 pub mod merkle_tree_metadata;
 pub mod queue;
+pub mod merkle_tree_ref;
+pub mod queue_ref;
 pub mod queue_batch_metadata;
 pub mod rollover_address_tree;
 pub mod rollover_state_tree;
diff --git a/program-libs/batched-merkle-tree/src/merkle_tree.rs b/program-libs/batched-merkle-tree/src/merkle_tree.rs
index ca91e6ec9b..9197c73bcf 100644
--- a/program-libs/batched-merkle-tree/src/merkle_tree.rs
+++ b/program-libs/batched-merkle-tree/src/merkle_tree.rs
@@ -174,6 +174,7 @@ impl<'a> BatchedMerkleTreeAccount<'a> {
         account_data: &'a mut [u8],
         pubkey: &Pubkey,
     ) -> Result<BatchedMerkleTreeAccount<'a>, BatchedMerkleTreeError> {
+        light_account_checks::checks::check_discriminator::<BatchedMerkleTreeAccount>(account_data)?;
         Self::from_bytes::<STATE_MERKLE_TREE_TYPE_V2>(account_data, pubkey)
     }
 
diff --git a/program-libs/batched-merkle-tree/src/merkle_tree_ref.rs b/program-libs/batched-merkle-tree/src/merkle_tree_ref.rs
new file mode 100644
index 0000000000..0dc759d9e0
--- /dev/null
+++ b/program-libs/batched-merkle-tree/src/merkle_tree_ref.rs
@@ -0,0 +1,172 @@
+use std::ops::Deref;
+
+use light_account_checks::{
+    checks::check_account_info,
+    discriminator::{Discriminator, DISCRIMINATOR_LEN},
+    AccountInfoTrait,
+};
+use light_compressed_account::{
+    pubkey::Pubkey, ADDRESS_MERKLE_TREE_TYPE_V2, STATE_MERKLE_TREE_TYPE_V2,
+};
+use light_merkle_tree_metadata::errors::MerkleTreeMetadataError;
+use light_zero_copy::{cyclic_vec::ZeroCopyCyclicVecRefU64, errors::ZeroCopyError};
+use zerocopy::Ref;
+
+use crate::{
+    batch::Batch, constants::ACCOUNT_COMPRESSION_PROGRAM_ID, errors::BatchedMerkleTreeError,
+    merkle_tree::BatchedMerkleTreeAccount, merkle_tree_metadata::BatchedMerkleTreeMetadata,
+};
+
+/// Immutable batched Merkle tree reference.
+///
+/// Uses `try_borrow_data()` + `&'a [u8]` instead of
+/// `try_borrow_mut_data()` + `&'a mut [u8]`, avoiding UB from
+/// dropping a `RefMut` guard while a raw-pointer-based mutable
+/// reference continues to live.
+///
+/// Only contains the fields that external consumers actually read:
+/// metadata, root history, and bloom filter stores.
+/// Hash chain stores are not parsed (only needed inside account-compression).
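+///
+/// # Example
+///
+/// A minimal read-only sketch, assuming `account_data` and `pubkey` come
+/// from an initialized batched state tree account (names are illustrative):
+///
+/// ```ignore
+/// let tree = BatchedMerkleTreeRef::state_from_bytes(&account_data, &pubkey)?;
+/// // Read a root without taking a mutable borrow of the account data.
+/// let root = tree.get_root_by_index(0);
+/// ```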
+#[derive(Debug)]
+pub struct BatchedMerkleTreeRef<'a> {
+    pubkey: Pubkey,
+    metadata: Ref<&'a [u8], BatchedMerkleTreeMetadata>,
+    root_history: ZeroCopyCyclicVecRefU64<'a, [u8; 32]>,
+    pub bloom_filter_stores: [&'a [u8]; 2],
+}
+
+impl Discriminator for BatchedMerkleTreeRef<'_> {
+    const LIGHT_DISCRIMINATOR: [u8; 8] = *b"BatchMta";
+    const LIGHT_DISCRIMINATOR_SLICE: &'static [u8] = b"BatchMta";
+}
+
+impl<'a> BatchedMerkleTreeRef<'a> {
+    /// Deserialize a batched state Merkle tree (immutable) from account info.
+    pub fn state_from_account_info<A: AccountInfoTrait>(
+        account_info: &A,
+    ) -> Result<BatchedMerkleTreeRef<'a>, BatchedMerkleTreeError> {
+        Self::from_account_info::<STATE_MERKLE_TREE_TYPE_V2, A>(
+            &ACCOUNT_COMPRESSION_PROGRAM_ID,
+            account_info,
+        )
+    }
+
+    /// Deserialize an address tree (immutable) from account info.
+    pub fn address_from_account_info<A: AccountInfoTrait>(
+        account_info: &A,
+    ) -> Result<BatchedMerkleTreeRef<'a>, BatchedMerkleTreeError> {
+        Self::from_account_info::<ADDRESS_MERKLE_TREE_TYPE_V2, A>(
+            &ACCOUNT_COMPRESSION_PROGRAM_ID,
+            account_info,
+        )
+    }
+
+    pub(crate) fn from_account_info<const TREE_TYPE: u64, A: AccountInfoTrait>(
+        program_id: &[u8; 32],
+        account_info: &A,
+    ) -> Result<BatchedMerkleTreeRef<'a>, BatchedMerkleTreeError> {
+        check_account_info::<Self, A>(program_id, account_info)?;
+        let data = account_info.try_borrow_data()?;
+        // SAFETY: We extend the lifetime of the borrowed data to 'a.
+        // The borrow is shared (immutable), so dropping the Ref guard
+        // restores pinocchio's borrow state correctly for shared borrows.
+        let data_slice: &'a [u8] = unsafe { std::slice::from_raw_parts(data.as_ptr(), data.len()) };
+        Self::from_bytes::<TREE_TYPE>(data_slice, &account_info.key().into())
+    }
+
+    /// Deserialize a state tree (immutable) from bytes.
+    #[cfg(not(target_os = "solana"))]
+    pub fn state_from_bytes(
+        account_data: &'a [u8],
+        pubkey: &Pubkey,
+    ) -> Result<BatchedMerkleTreeRef<'a>, BatchedMerkleTreeError> {
+        light_account_checks::checks::check_discriminator::<BatchedMerkleTreeAccount>(
+            account_data,
+        )?;
+        Self::from_bytes::<STATE_MERKLE_TREE_TYPE_V2>(account_data, pubkey)
+    }
+
+    /// Deserialize an address tree (immutable) from bytes.
+    #[cfg(not(target_os = "solana"))]
+    pub fn address_from_bytes(
+        account_data: &'a [u8],
+        pubkey: &Pubkey,
+    ) -> Result<BatchedMerkleTreeRef<'a>, BatchedMerkleTreeError> {
+        light_account_checks::checks::check_discriminator::<BatchedMerkleTreeAccount>(
+            account_data,
+        )?;
+        Self::from_bytes::<ADDRESS_MERKLE_TREE_TYPE_V2>(account_data, pubkey)
+    }
+
+    pub(crate) fn from_bytes<const TREE_TYPE: u64>(
+        account_data: &'a [u8],
+        pubkey: &Pubkey,
+    ) -> Result<BatchedMerkleTreeRef<'a>, BatchedMerkleTreeError> {
+        // 1. Skip discriminator.
+        let (_discriminator, account_data) = account_data.split_at(DISCRIMINATOR_LEN);
+
+        // 2. Parse metadata.
+        let (metadata, account_data) =
+            Ref::<&'a [u8], BatchedMerkleTreeMetadata>::from_prefix(account_data)
+                .map_err(ZeroCopyError::from)?;
+        if metadata.tree_type != TREE_TYPE {
+            return Err(MerkleTreeMetadataError::InvalidTreeType.into());
+        }
+
+        // 3. Parse root history (cyclic vec).
+        let (root_history, account_data) =
+            ZeroCopyCyclicVecRefU64::<[u8; 32]>::from_bytes_at(account_data)?;
+
+        // 4. Parse bloom filter stores (immutable).
+        let bloom_filter_size = metadata.queue_batches.get_bloomfilter_size_bytes();
+        let (bf_store_0, account_data) = account_data.split_at(bloom_filter_size);
+        let (bf_store_1, _account_data) = account_data.split_at(bloom_filter_size);
+
+        // 5. Stop here -- hash_chain_stores are not needed for read-only access.
+
+        Ok(BatchedMerkleTreeRef {
+            pubkey: *pubkey,
+            metadata,
+            root_history,
+            bloom_filter_stores: [bf_store_0, bf_store_1],
+        })
+    }
+
+    /// Check non-inclusion of `value` in the bloom filters of all batches.
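+    ///
+    /// A hedged usage sketch (assumes `tree` was deserialized as above and
+    /// `hash` is a compressed account hash):
+    ///
+    /// ```ignore
+    /// // Err(NonInclusionCheckFailed) means the value was found in a
+    /// // bloom filter, i.e. it is (probably) already in the input queue.
+    /// tree.check_input_queue_non_inclusion(&hash)?;
+    /// ```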
+    pub fn check_input_queue_non_inclusion(
+        &self,
+        value: &[u8; 32],
+    ) -> Result<(), BatchedMerkleTreeError> {
+        for i in 0..self.queue_batches.num_batches as usize {
+            Batch::check_non_inclusion_ref(
+                self.queue_batches.batches[i].num_iters as usize,
+                self.queue_batches.batches[i].bloom_filter_capacity,
+                value,
+                self.bloom_filter_stores[i],
+            )?;
+        }
+        Ok(())
+    }
+
+    pub fn pubkey(&self) -> &Pubkey {
+        &self.pubkey
+    }
+}
+
+impl Deref for BatchedMerkleTreeRef<'_> {
+    type Target = BatchedMerkleTreeMetadata;
+
+    fn deref(&self) -> &Self::Target {
+        &self.metadata
+    }
+}
+
+impl<'a> BatchedMerkleTreeRef<'a> {
+    /// Return root from the root history by index.
+    pub fn get_root_by_index(&self, index: usize) -> Option<&[u8; 32]> {
+        self.root_history.get(index)
+    }
+}
diff --git a/program-libs/batched-merkle-tree/src/queue_ref.rs b/program-libs/batched-merkle-tree/src/queue_ref.rs
new file mode 100644
index 0000000000..c5f3431922
--- /dev/null
+++ b/program-libs/batched-merkle-tree/src/queue_ref.rs
@@ -0,0 +1,173 @@
+use std::ops::Deref;
+
+use light_account_checks::{
+    checks::check_account_info,
+    discriminator::{Discriminator, DISCRIMINATOR_LEN},
+    AccountInfoTrait,
+};
+use light_compressed_account::{pubkey::Pubkey, OUTPUT_STATE_QUEUE_TYPE_V2};
+use light_merkle_tree_metadata::errors::MerkleTreeMetadataError;
+use light_zero_copy::{errors::ZeroCopyError, vec::ZeroCopyVecU64};
+use zerocopy::Ref;
+
+use crate::{
+    constants::ACCOUNT_COMPRESSION_PROGRAM_ID,
+    errors::BatchedMerkleTreeError,
+    queue::{BatchedQueueAccount, BatchedQueueMetadata},
+};
+
+/// Immutable batched queue reference.
+///
+/// Uses `try_borrow_data()` + `&'a [u8]` instead of
+/// `try_borrow_mut_data()` + `&'a mut [u8]`.
+///
+/// Only contains the fields that external consumers actually read:
+/// metadata and value vecs. Hash chain stores are not parsed.
+#[derive(Debug)]
+pub struct BatchedQueueRef<'a> {
+    pubkey: Pubkey,
+    metadata: Ref<&'a [u8], BatchedQueueMetadata>,
+    /// Value vec metadata: [length, capacity] per batch, parsed inline.
+    _value_vec_metas: [Ref<&'a [u8], [u64; 2]>; 2],
+    value_vec_data: [Ref<&'a [u8], [[u8; 32]]>; 2],
+}
+
+impl Discriminator for BatchedQueueRef<'_> {
+    const LIGHT_DISCRIMINATOR: [u8; 8] = *b"queueacc";
+    const LIGHT_DISCRIMINATOR_SLICE: &'static [u8] = b"queueacc";
+}
+
+impl<'a> BatchedQueueRef<'a> {
+    /// Deserialize an output queue (immutable) from account info.
+    pub fn output_from_account_info<A: AccountInfoTrait>(
+        account_info: &A,
+    ) -> Result<BatchedQueueRef<'a>, BatchedMerkleTreeError> {
+        Self::from_account_info::<OUTPUT_STATE_QUEUE_TYPE_V2, A>(
+            &Pubkey::new_from_array(ACCOUNT_COMPRESSION_PROGRAM_ID),
+            account_info,
+        )
+    }
+
+    pub(crate) fn from_account_info<const QUEUE_TYPE: u64, A: AccountInfoTrait>(
+        program_id: &Pubkey,
+        account_info: &A,
+    ) -> Result<BatchedQueueRef<'a>, BatchedMerkleTreeError> {
+        check_account_info::<Self, A>(&program_id.to_bytes(), account_info)?;
+        let data = account_info.try_borrow_data()?;
+        // SAFETY: We extend the lifetime of the borrowed data to 'a.
+        // The borrow is shared (immutable), so dropping the Ref guard
+        // restores pinocchio's borrow state correctly for shared borrows.
+        let data_slice: &'a [u8] = unsafe { std::slice::from_raw_parts(data.as_ptr(), data.len()) };
+        Self::from_bytes::<QUEUE_TYPE>(data_slice, account_info.key().into())
+    }
+
+    /// Deserialize an output queue (immutable) from bytes.
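+    ///
+    /// A minimal sketch, assuming `account_data` holds an initialized output
+    /// queue account (the queue pubkey defaults to `Pubkey::default()`):
+    ///
+    /// ```ignore
+    /// let queue = BatchedQueueRef::output_from_bytes(&account_data)?;
+    /// let included = queue.prove_inclusion_by_index(leaf_index, &hash)?;
+    /// ```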
+    #[cfg(not(target_os = "solana"))]
+    pub fn output_from_bytes(
+        account_data: &'a [u8],
+    ) -> Result<BatchedQueueRef<'a>, BatchedMerkleTreeError> {
+        light_account_checks::checks::check_discriminator::<BatchedQueueAccount>(account_data)?;
+        Self::from_bytes::<OUTPUT_STATE_QUEUE_TYPE_V2>(account_data, Pubkey::default())
+    }
+
+    pub(crate) fn from_bytes<const QUEUE_TYPE: u64>(
+        account_data: &'a [u8],
+        pubkey: Pubkey,
+    ) -> Result<BatchedQueueRef<'a>, BatchedMerkleTreeError> {
+        // 1. Skip discriminator.
+        let (_discriminator, account_data) = account_data.split_at(DISCRIMINATOR_LEN);
+
+        // 2. Parse metadata.
+        let (metadata, account_data) =
+            Ref::<&'a [u8], BatchedQueueMetadata>::from_prefix(account_data)
+                .map_err(ZeroCopyError::from)?;
+
+        if metadata.metadata.queue_type != QUEUE_TYPE {
+            return Err(MerkleTreeMetadataError::InvalidQueueType.into());
+        }
+
+        // 3. Parse two value vecs inline.
+        // ZeroCopyVecU64 layout: [u64; 2] metadata (length, capacity), then [u8; 32] * capacity.
+        let metadata_size = ZeroCopyVecU64::<[u8; 32]>::metadata_size();
+
+        let (meta0_bytes, account_data) = account_data.split_at(metadata_size);
+        let (value_vec_meta0, _padding) =
+            Ref::<&'a [u8], [u64; 2]>::from_prefix(meta0_bytes).map_err(ZeroCopyError::from)?;
+        let capacity0 = value_vec_meta0[1] as usize; // CAPACITY_INDEX = 1
+        let (value_vec_data0, account_data) =
+            Ref::<&'a [u8], [[u8; 32]]>::from_prefix_with_elems(account_data, capacity0)
+                .map_err(ZeroCopyError::from)?;
+
+        let (meta1_bytes, account_data) = account_data.split_at(metadata_size);
+        let (value_vec_meta1, _padding) =
+            Ref::<&'a [u8], [u64; 2]>::from_prefix(meta1_bytes).map_err(ZeroCopyError::from)?;
+        let capacity1 = value_vec_meta1[1] as usize;
+        let (value_vec_data1, _account_data) =
+            Ref::<&'a [u8], [[u8; 32]]>::from_prefix_with_elems(account_data, capacity1)
+                .map_err(ZeroCopyError::from)?;
+
+        // 4. Stop here -- hash_chain_stores are not needed for read-only access.
+
+        Ok(BatchedQueueRef {
+            pubkey,
+            metadata,
+            _value_vec_metas: [value_vec_meta0, value_vec_meta1],
+            value_vec_data: [value_vec_data0, value_vec_data1],
+        })
+    }
+
+    /// Prove inclusion of a leaf by index.
+    /// Returns `Ok(true)` if the leaf index exists in one of the batches and
+    /// the stored value matches `hash_chain_value`, `Ok(false)` if the index
+    /// is not in any batch, and an error on a value mismatch.
+    pub fn prove_inclusion_by_index(
+        &self,
+        leaf_index: u64,
+        hash_chain_value: &[u8; 32],
+    ) -> Result<bool, BatchedMerkleTreeError> {
+        if leaf_index >= self.batch_metadata.next_index {
+            return Err(BatchedMerkleTreeError::InvalidIndex);
+        }
+        for (batch_index, batch) in self.batch_metadata.batches.iter().enumerate() {
+            if batch.leaf_index_exists(leaf_index) {
+                let index = batch.get_value_index_in_batch(leaf_index)?;
+                let element = self.value_vec_data[batch_index]
+                    .get(index as usize)
+                    .ok_or(BatchedMerkleTreeError::InclusionProofByIndexFailed)?;
+
+                if *element == *hash_chain_value {
+                    return Ok(true);
+                } else {
+                    #[cfg(target_os = "solana")]
+                    {
+                        solana_msg::msg!(
+                            "Index found but value doesn't match leaf_index {} compressed account hash: {:?} expected compressed account hash {:?}. (If the expected element is [0u8;32] it was already spent. Other possible causes: data hash, discriminator, leaf index, or Merkle tree mismatch.)",
+                            leaf_index,
+                            hash_chain_value, *element
+                        );
+                    }
+                    return Err(BatchedMerkleTreeError::InclusionProofByIndexFailed);
+                }
+            }
+        }
+        Ok(false)
+    }
+
+    /// Check if the pubkey is the associated Merkle tree of the queue.
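+    ///
+    /// A hedged sketch (assumes `queue` from above and the tree's pubkey):
+    ///
+    /// ```ignore
+    /// queue.check_is_associated(tree.pubkey())?;
+    /// ```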
+    pub fn check_is_associated(&self, pubkey: &Pubkey) -> Result<(), BatchedMerkleTreeError> {
+        if self.metadata.metadata.associated_merkle_tree != *pubkey {
+            return Err(MerkleTreeMetadataError::MerkleTreeAndQueueNotAssociated.into());
+        }
+        Ok(())
+    }
+
+    pub fn pubkey(&self) -> &Pubkey {
+        &self.pubkey
+    }
+}
+
+impl Deref for BatchedQueueRef<'_> {
+    type Target = BatchedQueueMetadata;
+
+    fn deref(&self) -> &Self::Target {
+        &self.metadata
+    }
+}
diff --git a/program-libs/batched-merkle-tree/tests/merkle_tree_ref.rs b/program-libs/batched-merkle-tree/tests/merkle_tree_ref.rs
new file mode 100644
index 0000000000..17159a1876
--- /dev/null
+++ b/program-libs/batched-merkle-tree/tests/merkle_tree_ref.rs
@@ -0,0 +1,283 @@
+mod test_helpers;
+
+use light_batched_merkle_tree::{
+    merkle_tree::BatchedMerkleTreeAccount, merkle_tree_ref::BatchedMerkleTreeRef,
+};
+use light_compressed_account::{pubkey::Pubkey, TreeType};
+use light_merkle_tree_metadata::errors::MerkleTreeMetadataError;
+use test_helpers::{account_builders::MerkleTreeAccountBuilder, assertions::*};
+
+#[test]
+fn test_merkle_tree_ref_deserialization_matrix() {
+    // Test matrix: tree type x API method (table-driven test).
+    let test_cases = vec![
+        ("State tree with state API", TreeType::StateV2, "state", true),
+        ("Address tree with address API", TreeType::AddressV2, "address", true),
+        ("State tree with address API", TreeType::StateV2, "address", false),
+        ("Address tree with state API", TreeType::AddressV2, "state", false),
+    ];
+
+    for (description, tree_type, api, should_succeed) in test_cases {
+        let (data, pubkey) = MerkleTreeAccountBuilder::state_tree()
+            .with_tree_type(tree_type)
+            .build();
+
+        let result = if api == "state" {
+            BatchedMerkleTreeRef::state_from_bytes(&data, &pubkey)
+        } else {
+            BatchedMerkleTreeRef::address_from_bytes(&data, &pubkey)
+        };
+
+        if should_succeed {
+            assert!(
+                result.is_ok(),
+                "{}: Expected success but got error {:?}",
+                description,
+                result.as_ref().err()
+            );
+            let tree_ref = result.unwrap();
+            assert_eq!(
+                *tree_ref.pubkey(),
+                pubkey,
+                "{}: Pubkey mismatch",
+                description
+            );
+        } else {
+            assert_metadata_error(
+                result,
+                MerkleTreeMetadataError::InvalidTreeType,
+                description,
+            );
+        }
+    }
+}
+
+#[test]
+fn test_merkle_tree_ref_from_bytes_errors() {
+    // Test 1: Bad discriminator
+    let (data, pubkey) = MerkleTreeAccountBuilder::state_tree().build_with_bad_discriminator();
+    let result = BatchedMerkleTreeRef::state_from_bytes(&data, &pubkey);
+    assert_account_error(result, "Bad discriminator should fail");
+
+    // Test 2: Insufficient size - truncate to just past discriminator so metadata parse fails
+    let (data, pubkey) = MerkleTreeAccountBuilder::state_tree().build();
+    let truncated = &data[..16]; // 8 bytes discriminator + 8 bytes (not enough for metadata)
+    let result = BatchedMerkleTreeRef::state_from_bytes(truncated, &pubkey);
+    assert_zerocopy_error(result, "Insufficient size should fail");
+
+    // Test 3: Empty data (too small even for discriminator)
+    let empty_data: &[u8] = &[0u8; 4];
+    let result = BatchedMerkleTreeRef::state_from_bytes(empty_data, &pubkey);
+    assert_account_error(result, "Empty data should fail discriminator check");
+
+    // Test 4: Wrong tree type
+    let (data, pubkey) = MerkleTreeAccountBuilder::state_tree().build_with_wrong_tree_type(999);
+    let result = BatchedMerkleTreeRef::state_from_bytes(&data, &pubkey);
+    assert_metadata_error(
+        result,
+        MerkleTreeMetadataError::InvalidTreeType,
+        "Wrong tree type should fail",
+    );
+
+    // Test 5: Root history out-of-bounds returns None
+    let (data, pubkey) = MerkleTreeAccountBuilder::state_tree()
+        .with_root_history_capacity(5)
+        .build();
+    let tree_ref = BatchedMerkleTreeRef::state_from_bytes(&data, &pubkey).unwrap();
+    assert!(tree_ref.get_root_by_index(10).is_none());
+}
+
+#[test]
+fn test_merkle_tree_ref_different_configurations() {
+    // Table-driven test: different tree configurations.
+    struct TestConfig {
+        name: &'static str,
+        batch_size: u64,
+        zkp_batch_size: u64,
+        root_history_capacity: u32,
+        height: u32,
+        bloom_filter_capacity: u64,
+    }
+
+    let configs = vec![
+        TestConfig {
+            name: "Minimal config",
+            batch_size: 2,
+            zkp_batch_size: 1,
+            root_history_capacity: 2,
+            height: 10,
+            bloom_filter_capacity: 1024, // Must be a multiple of 64 for alignment.
+        },
+        TestConfig {
+            name: "Default config",
+            batch_size: 5,
+            zkp_batch_size: 1,
+            root_history_capacity: 10,
+            height: 40,
+            bloom_filter_capacity: 8000,
+        },
+        TestConfig {
+            name: "Large config",
+            batch_size: 100,
+            zkp_batch_size: 10,
+            root_history_capacity: 100,
+            height: 26,
+            bloom_filter_capacity: 16000, // Must be a multiple of 64 for alignment.
+        },
+    ];
+
+    for config in configs {
+        let (data, pubkey) = MerkleTreeAccountBuilder::state_tree()
+            .with_batch_size(config.batch_size)
+            .with_zkp_batch_size(config.zkp_batch_size)
+            .with_root_history_capacity(config.root_history_capacity)
+            .with_height(config.height)
+            .with_bloom_filter_capacity(config.bloom_filter_capacity)
+            .build();
+
+        let tree_ref = BatchedMerkleTreeRef::state_from_bytes(&data, &pubkey)
+            .unwrap_or_else(|_| panic!("{}: Failed to deserialize", config.name));
+
+        // Verify configuration is preserved.
+        assert_eq!(
+            tree_ref.height, config.height,
+            "{}: Height mismatch",
+            config.name
+        );
+        assert_eq!(
+            tree_ref.queue_batches.batch_size, config.batch_size,
+            "{}: Batch size mismatch",
+            config.name
+        );
+        assert_eq!(
+            tree_ref.queue_batches.zkp_batch_size, config.zkp_batch_size,
+            "{}: ZKP batch size mismatch",
+            config.name
+        );
+        assert_eq!(
+            tree_ref.queue_batches.bloom_filter_capacity, config.bloom_filter_capacity,
+            "{}: Bloom filter capacity mismatch",
+            config.name
+        );
+    }
+}
+
+#[test]
+fn test_merkle_tree_ref_randomized_equivalence() {
+    use light_bloom_filter::BloomFilter;
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let mut rng = StdRng::seed_from_u64(0xDEAD_BEEF);
+    let root_history_capacity: u32 = 10;
+    let bloom_filter_capacity: u64 = 100_000;
+
+    let (mut account_data, pubkey) = MerkleTreeAccountBuilder::state_tree()
+        .with_root_history_capacity(root_history_capacity)
+        .with_bloom_filter_capacity(bloom_filter_capacity)
+        .with_num_iters(1)
+        .build();
+
+    for _ in 0..1000 {
+        let action = rng.gen_range(0..3u8);
+        match action {
+            0 => {
+                // Push random root.
+                let mut tree_mut =
+                    BatchedMerkleTreeAccount::state_from_bytes(&mut account_data, &pubkey)
+                        .unwrap();
+                tree_mut.root_history.push(rng.gen());
+            }
+            1 => {
+                // Insert into bloom filter of a random batch.
+                let mut tree_mut =
+                    BatchedMerkleTreeAccount::state_from_bytes(&mut account_data, &pubkey)
+                        .unwrap();
+                let batch_idx = rng.gen_range(0..2usize);
+                let num_iters = tree_mut.queue_batches.batches[batch_idx].num_iters as usize;
+                let capacity = tree_mut.queue_batches.batches[batch_idx].bloom_filter_capacity;
+                let value: [u8; 32] = rng.gen();
+                let mut bf = BloomFilter::new(
+                    num_iters,
+                    capacity,
+                    &mut tree_mut.bloom_filter_stores[batch_idx],
+                )
+                .unwrap();
+                bf.insert(&value).unwrap();
+            }
+            2 => {
+                // Increment sequence number.
+                let mut tree_mut =
+                    BatchedMerkleTreeAccount::state_from_bytes(&mut account_data, &pubkey)
+                        .unwrap();
+                tree_mut.sequence_number += 1;
+            }
+            _ => unreachable!(),
+        }
+
+        // Clone data so we can deserialize both paths independently.
+        let mut account_data_clone = account_data.clone();
+
+        let tree_ref = BatchedMerkleTreeRef::state_from_bytes(&account_data, &pubkey).unwrap();
+        let tree_mut =
+            BatchedMerkleTreeAccount::state_from_bytes(&mut account_data_clone, &pubkey)
+                .unwrap();
+
+        // Metadata via Deref.
+        assert_eq!(*tree_ref, *tree_mut.get_metadata());
+
+        // Root history.
+        for i in 0..root_history_capacity as usize {
+            assert_eq!(
+                tree_ref.get_root_by_index(i).copied(),
+                tree_mut.get_root_by_index(i).copied(),
+                "Root mismatch at index {}",
+                i
+            );
+        }
+
+        // Bloom filter stores byte-equal.
+        for j in 0..2 {
+            assert_eq!(
+                tree_ref.bloom_filter_stores[j],
+                tree_mut.bloom_filter_stores[j].as_ref(),
+                "Bloom filter store {} mismatch",
+                j
+            );
+        }
+
+        // Pubkey.
+        assert_eq!(tree_ref.pubkey(), tree_mut.pubkey());
+    }
+
+    // Non-inclusion coverage: insert a known value and verify it fails non-inclusion,
+    // while a non-inserted value passes.
+    let inserted_value = [0x42; 32];
+    let non_inserted_value = [0x99; 32];
+    {
+        let mut tree_mut =
+            BatchedMerkleTreeAccount::state_from_bytes(&mut account_data, &pubkey).unwrap();
+        let num_iters = tree_mut.queue_batches.batches[0].num_iters as usize;
+        let capacity = tree_mut.queue_batches.batches[0].bloom_filter_capacity;
+        let mut bf =
+            BloomFilter::new(num_iters, capacity, &mut tree_mut.bloom_filter_stores[0]).unwrap();
+        bf.insert(&inserted_value).unwrap();
+    }
+    let tree_ref = BatchedMerkleTreeRef::state_from_bytes(&account_data, &pubkey).unwrap();
+    assert!(
+        tree_ref
+            .check_input_queue_non_inclusion(&inserted_value)
+            .is_err(),
+        "Inserted value should fail non-inclusion check"
+    );
+    assert!(
+        tree_ref
+            .check_input_queue_non_inclusion(&non_inserted_value)
+            .is_ok(),
+        "Non-inserted value should pass non-inclusion check"
+    );
+}
diff --git a/program-libs/batched-merkle-tree/tests/queue_ref.rs b/program-libs/batched-merkle-tree/tests/queue_ref.rs
new file mode 100644
index 0000000000..c13cb35c55
--- /dev/null
+++ b/program-libs/batched-merkle-tree/tests/queue_ref.rs
@@ -0,0 +1,216 @@
+mod test_helpers;
+
+use light_batched_merkle_tree::{queue::BatchedQueueAccount, queue_ref::BatchedQueueRef};
+use light_compressed_account::pubkey::Pubkey;
+use light_merkle_tree_metadata::errors::MerkleTreeMetadataError;
+use test_helpers::{account_builders::QueueAccountBuilder, assertions::*};
+
+#[test]
+fn test_queue_ref_deserialization_errors() {
+    // Test 1: Bad discriminator
+    let (data, _pubkey) = QueueAccountBuilder::output_queue().build_with_bad_discriminator();
+    let result = BatchedQueueRef::output_from_bytes(&data);
+    assert_account_error(result, "Bad discriminator should fail");
+
+    // Test 2: Insufficient size
+    let (data, _pubkey) = QueueAccountBuilder::output_queue().build_too_small();
+    let result = BatchedQueueRef::output_from_bytes(&data);
+    assert_zerocopy_error(result, "Insufficient size should fail");
+
+    // Test 3: Wrong queue type
+    let (data, _pubkey) = QueueAccountBuilder::output_queue().build_with_wrong_queue_type(999);
+    let result = BatchedQueueRef::output_from_bytes(&data);
+    assert_metadata_error(
+        result,
+        MerkleTreeMetadataError::InvalidQueueType,
+        "Wrong queue type should fail",
+    );
+
+    // Test 4: Wrong pubkey association
+    let associated_tree = Pubkey::new_unique();
+    let wrong_tree = Pubkey::new_unique();
+    let (data, _pubkey) = QueueAccountBuilder::output_queue()
+        .with_associated_tree(associated_tree)
+        .build();
+    let queue_ref = BatchedQueueRef::output_from_bytes(&data).unwrap();
+    assert_metadata_error(
+        queue_ref.check_is_associated(&wrong_tree),
+        MerkleTreeMetadataError::MerkleTreeAndQueueNotAssociated,
+        "Association check should fail with wrong pubkey",
+    );
+
+    // Test 5: Empty queue prove_inclusion returns InvalidIndex
+    let (data, _pubkey) = QueueAccountBuilder::output_queue().build();
+    let queue_ref = BatchedQueueRef::output_from_bytes(&data).unwrap();
+    assert_error(
+        queue_ref.prove_inclusion_by_index(0, &[0u8; 32]),
+        light_batched_merkle_tree::errors::BatchedMerkleTreeError::InvalidIndex,
+        "Empty queue should return InvalidIndex",
+    );
+}
+
+#[test]
+fn test_queue_ref_prove_inclusion_by_index() {
+    let (mut account_data, _pubkey) = QueueAccountBuilder::output_queue()
+        .with_batch_size(10)
+        .with_zkp_batch_size(2)
+        .build();
+
+    // Insert test values via the proper insertion API.
+    let test_hash_1 = [0x11; 32];
+    let test_hash_2 = [0x22; 32];
+    {
+        let mut queue_mut = BatchedQueueAccount::output_from_bytes(&mut account_data).unwrap();
+        queue_mut.insert_into_current_batch(&test_hash_1, &0).unwrap();
+        queue_mut.insert_into_current_batch(&test_hash_2, &0).unwrap();
+    }
+
+    // Read via immutable ref.
+    let queue_ref = BatchedQueueRef::output_from_bytes(&account_data).unwrap();
+
+    // Valid index with matching hash.
+    assert!(
+        queue_ref.prove_inclusion_by_index(0, &test_hash_1).unwrap(),
+        "Valid index with matching hash should return true"
+    );
+    assert!(
+        queue_ref.prove_inclusion_by_index(1, &test_hash_2).unwrap(),
+        "Second element with matching hash should return true"
+    );
+
+    // Valid index with wrong hash returns error.
+    assert!(
+        queue_ref.prove_inclusion_by_index(0, &[0xFF; 32]).is_err(),
+        "Wrong hash should return error"
+    );
+}
+
+#[test]
+fn test_queue_ref_different_batch_configurations() {
+    // Table-driven test: different batch configurations.
+    struct TestConfig {
+        name: &'static str,
+        batch_size: u64,
+        zkp_batch_size: u64,
+        tree_capacity: u64,
+    }
+
+    let configs = vec![
+        TestConfig {
+            name: "Small batches",
+            batch_size: 2,
+            zkp_batch_size: 1,
+            tree_capacity: 8,
+        },
+        TestConfig {
+            name: "Medium batches",
+            batch_size: 10,
+            zkp_batch_size: 5,
+            tree_capacity: 64,
+        },
+        TestConfig {
+            name: "Large batches",
+            batch_size: 100,
+            zkp_batch_size: 10,
+            tree_capacity: 1024,
+        },
+    ];
+
+    for config in configs {
+        let (data, _pubkey) = QueueAccountBuilder::output_queue()
+            .with_batch_size(config.batch_size)
+            .with_zkp_batch_size(config.zkp_batch_size)
+            .with_tree_capacity(config.tree_capacity)
+            .build();
+
+        let queue_ref = BatchedQueueRef::output_from_bytes(&data)
+            .unwrap_or_else(|_| panic!("{}: Failed to deserialize", config.name));
+
+        // Verify configuration is preserved.
+        assert_eq!(
+            queue_ref.batch_metadata.batch_size, config.batch_size,
+            "{}: Batch size mismatch",
+            config.name
+        );
+        assert_eq!(
+            queue_ref.batch_metadata.zkp_batch_size, config.zkp_batch_size,
+            "{}: ZKP batch size mismatch",
+            config.name
+        );
+        assert_eq!(
+            queue_ref.tree_capacity, config.tree_capacity,
+            "{}: Tree capacity mismatch",
+            config.name
+        );
+    }
+}
+
+#[test]
+fn test_queue_ref_randomized_equivalence() {
+    use rand::{rngs::StdRng, Rng, SeedableRng};
+
+    let mut rng = StdRng::seed_from_u64(0xCAFE_BABE);
+    let batch_size = 1000u64;
+    let associated_tree = Pubkey::new_unique();
+
+    let (mut account_data, _pubkey) = QueueAccountBuilder::output_queue()
+        .with_batch_size(batch_size)
+        .with_zkp_batch_size(1)
+        .with_associated_tree(associated_tree)
+        .build();
+
+    let mut inserted: Vec<(u64, [u8; 32])> = Vec::new();
+
+    for _ in 0..1000 {
+        // Insert a value into the current batch (stop when the batch is full).
+        let value: [u8; 32] = rng.gen();
+        let slot = 0u64;
+        {
+            let mut queue_mut =
+                BatchedQueueAccount::output_from_bytes(&mut account_data).unwrap();
+            let result = queue_mut.insert_into_current_batch(&value, &slot);
+            if result.is_ok() {
+                inserted.push((inserted.len() as u64, value));
+            } else {
+                // Batch is full, skip further inserts.
+                continue;
+            }
+        }
+
+        // Clone data so we can deserialize both paths independently.
+        let mut account_data_clone = account_data.clone();
+
+        let queue_ref = BatchedQueueRef::output_from_bytes(&account_data).unwrap();
+        let queue_mut =
+            BatchedQueueAccount::output_from_bytes(&mut account_data_clone).unwrap();
+
+        // Metadata via Deref.
+        assert_eq!(*queue_ref, *queue_mut.get_metadata());
+
+        // next_index.
+        assert_eq!(
+            queue_ref.batch_metadata.next_index,
+            queue_mut.get_metadata().batch_metadata.next_index,
+        );
+
+        // Prove inclusion for all inserted values.
+        for &(leaf_index, ref val) in &inserted {
+            assert!(
+                queue_ref.prove_inclusion_by_index(leaf_index, val).unwrap(),
+                "Inclusion failed at leaf_index {}",
+                leaf_index
+            );
+        }
+
+        // Association check.
+        queue_ref.check_is_associated(&associated_tree).unwrap();
+
+        // Pubkey accessor.
+        assert_eq!(*queue_ref.pubkey(), *queue_mut.pubkey());
+    }
+}
diff --git a/program-libs/batched-merkle-tree/tests/test_helpers/account_builders.rs b/program-libs/batched-merkle-tree/tests/test_helpers/account_builders.rs
new file mode 100644
index 0000000000..730c461d4c
--- /dev/null
+++ b/program-libs/batched-merkle-tree/tests/test_helpers/account_builders.rs
@@ -0,0 +1,226 @@
+use light_batched_merkle_tree::{
+    merkle_tree::BatchedMerkleTreeAccount,
+    merkle_tree_metadata::BatchedMerkleTreeMetadata,
+    queue::BatchedQueueAccount,
+};
+use light_compressed_account::{pubkey::Pubkey, QueueType, TreeType};
+use light_merkle_tree_metadata::{merkle_tree::MerkleTreeMetadata, queue::QueueMetadata};
+
+/// Builder for creating valid and invalid BatchedMerkleTreeAccount test data.
+pub struct MerkleTreeAccountBuilder {
+    tree_type: TreeType,
+    batch_size: u64,
+    zkp_batch_size: u64,
+    root_history_capacity: u32,
+    height: u32,
+    num_iters: u64,
+    bloom_filter_capacity: u64,
+}
+
+impl MerkleTreeAccountBuilder {
+    /// Create a state tree builder with default test parameters.
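+    ///
+    /// A hedged usage sketch of the builder chain (parameter values are
+    /// illustrative):
+    ///
+    /// ```ignore
+    /// let (data, pubkey) = MerkleTreeAccountBuilder::state_tree()
+    ///     .with_batch_size(10)
+    ///     .with_root_history_capacity(20)
+    ///     .build();
+    /// ```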
+    pub fn state_tree() -> Self {
+        Self {
+            tree_type: TreeType::StateV2,
+            batch_size: 5,
+            zkp_batch_size: 1,
+            root_history_capacity: 10,
+            height: 40,
+            num_iters: 1,
+            bloom_filter_capacity: 8000,
+        }
+    }
+
+    pub fn with_tree_type(mut self, tree_type: TreeType) -> Self {
+        self.tree_type = tree_type;
+        self
+    }
+
+    pub fn with_batch_size(mut self, batch_size: u64) -> Self {
+        self.batch_size = batch_size;
+        self
+    }
+
+    pub fn with_zkp_batch_size(mut self, zkp_batch_size: u64) -> Self {
+        self.zkp_batch_size = zkp_batch_size;
+        self
+    }
+
+    pub fn with_root_history_capacity(mut self, capacity: u32) -> Self {
+        self.root_history_capacity = capacity;
+        self
+    }
+
+    pub fn with_height(mut self, height: u32) -> Self {
+        self.height = height;
+        self
+    }
+
+    pub fn with_bloom_filter_capacity(mut self, capacity: u64) -> Self {
+        self.bloom_filter_capacity = capacity;
+        self
+    }
+
+    pub fn with_num_iters(mut self, num_iters: u64) -> Self {
+        self.num_iters = num_iters;
+        self
+    }
+
+    /// Pre-calculate the exact account size needed for these parameters.
+    fn calculate_size(&self) -> usize {
+        let mut temp_metadata = BatchedMerkleTreeMetadata::default();
+        temp_metadata.root_history_capacity = self.root_history_capacity;
+        temp_metadata.height = self.height;
+        temp_metadata.tree_type = self.tree_type as u64;
+        temp_metadata.capacity = 2u64.pow(self.height);
+        temp_metadata
+            .queue_batches
+            .init(self.batch_size, self.zkp_batch_size)
+            .unwrap();
+        temp_metadata.queue_batches.bloom_filter_capacity = self.bloom_filter_capacity;
+        temp_metadata.get_account_size().unwrap()
+    }
+
+    /// Build a valid account with correctly initialized data.
+    pub fn build(self) -> (Vec<u8>, Pubkey) {
+        let pubkey = Pubkey::new_unique();
+        let size = self.calculate_size();
+        let mut data = vec![0u8; size];
+        BatchedMerkleTreeAccount::init(
+            &mut data,
+            &pubkey,
+            MerkleTreeMetadata::default(),
+            self.root_history_capacity,
+            self.batch_size,
+            self.zkp_batch_size,
+            self.height,
+            self.num_iters,
+            self.bloom_filter_capacity,
+            self.tree_type,
+        )
+        .unwrap();
+        (data, pubkey)
+    }
+
+    /// Build account with corrupted discriminator.
+    pub fn build_with_bad_discriminator(self) -> (Vec<u8>, Pubkey) {
+        let (mut data, pubkey) = self.build();
+        data[0..8].copy_from_slice(b"BadDiscr");
+        (data, pubkey)
+    }
+
+    /// Build account with wrong tree type field (but correct discriminator).
+    pub fn build_with_wrong_tree_type(self, wrong_type: u64) -> (Vec<u8>, Pubkey) {
+        let (mut data, pubkey) = self.build();
+        // tree_type is the first field of BatchedMerkleTreeMetadata, right after the discriminator.
+        let tree_type_offset = 8; // 8 bytes discriminator
+        data[tree_type_offset..tree_type_offset + 8].copy_from_slice(&wrong_type.to_le_bytes());
+        (data, pubkey)
+    }
+}
+
+/// Builder for creating valid and invalid BatchedQueueAccount test data.
+pub struct QueueAccountBuilder {
+    associated_merkle_tree: Pubkey,
+    batch_size: u64,
+    zkp_batch_size: u64,
+    tree_capacity: u64,
+}
+
+impl QueueAccountBuilder {
+    /// Create an output queue builder with default test parameters.
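+    ///
+    /// A hedged usage sketch of the builder chain (values are illustrative):
+    ///
+    /// ```ignore
+    /// let (data, _pubkey) = QueueAccountBuilder::output_queue()
+    ///     .with_batch_size(10)
+    ///     .with_zkp_batch_size(2)
+    ///     .build();
+    /// ```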
+    pub fn output_queue() -> Self {
+        Self {
+            associated_merkle_tree: Pubkey::new_unique(),
+            batch_size: 4,
+            zkp_batch_size: 2,
+            tree_capacity: 16,
+        }
+    }
+
+    pub fn with_associated_tree(mut self, tree_pubkey: Pubkey) -> Self {
+        self.associated_merkle_tree = tree_pubkey;
+        self
+    }
+
+    pub fn with_batch_size(mut self, batch_size: u64) -> Self {
+        self.batch_size = batch_size;
+        self
+    }
+
+    pub fn with_zkp_batch_size(mut self, zkp_batch_size: u64) -> Self {
+        self.zkp_batch_size = zkp_batch_size;
+        self
+    }
+
+    pub fn with_tree_capacity(mut self, tree_capacity: u64) -> Self {
+        self.tree_capacity = tree_capacity;
+        self
+    }
+
+    /// Pre-calculate exact account size using a temporary metadata struct.
+    fn calculate_size(&self) -> usize {
+        use light_batched_merkle_tree::queue_batch_metadata::QueueBatches;
+        let mut temp_batches = QueueBatches::default();
+        temp_batches
+            .init(self.batch_size, self.zkp_batch_size)
+            .unwrap();
+        // queue_account_size already includes BatchedQueueMetadata::LEN,
+        // which contains the discriminator via aligned_sized(anchor).
+        temp_batches
+            .queue_account_size(QueueType::OutputStateV2 as u64)
+            .unwrap()
+    }
+
+    /// Build a valid queue account with correctly initialized data.
+    pub fn build(self) -> (Vec<u8>, Pubkey) {
+        let pubkey = Pubkey::new_unique();
+        let queue_metadata = QueueMetadata {
+            associated_merkle_tree: self.associated_merkle_tree,
+            queue_type: QueueType::OutputStateV2 as u64,
+            ..Default::default()
+        };
+
+        let size = self.calculate_size();
+        let mut data = vec![0u8; size];
+        BatchedQueueAccount::init(
+            &mut data,
+            queue_metadata,
+            self.batch_size,
+            self.zkp_batch_size,
+            0, // num_iters (output queues don't use bloom filters)
+            0, // bloom_filter_capacity
+            pubkey,
+            self.tree_capacity,
+        )
+        .unwrap();
+        (data, pubkey)
+    }
+
+    /// Build account with corrupted discriminator.
+    pub fn build_with_bad_discriminator(self) -> (Vec<u8>, Pubkey) {
+        let (mut data, pubkey) = self.build();
+        data[0..8].copy_from_slice(b"BadQueue");
+        (data, pubkey)
+    }
+
+    /// Build account with wrong queue type field (but correct discriminator).
+    pub fn build_with_wrong_queue_type(self, wrong_type: u64) -> (Vec<u8>, Pubkey) {
+        let (mut data, pubkey) = self.build();
+        // In BatchedQueueMetadata, metadata is QueueMetadata, which has:
+        // AccessMetadata (3 pubkeys = 96 bytes) + RolloverMetadata (7 * u64 = 56 bytes) +
+        // associated_merkle_tree (32 bytes) + next_queue (32 bytes) + queue_type (8 bytes).
+        // Offset of queue_type from the start of metadata = 96 + 56 + 32 + 32 = 216;
+        // plus 8 bytes of discriminator = 224.
+        let queue_type_offset = 8 + 96 + 56 + 32 + 32;
+        data[queue_type_offset..queue_type_offset + 8].copy_from_slice(&wrong_type.to_le_bytes());
+        (data, pubkey)
+    }
+
+    /// Build account with insufficient size (truncated).
+    pub fn build_too_small(self) -> (Vec<u8>, Pubkey) {
+        let (data, pubkey) = self.build();
+        (data[..data.len() / 2].to_vec(), pubkey)
+    }
+}
diff --git a/program-libs/batched-merkle-tree/tests/test_helpers/assertions.rs b/program-libs/batched-merkle-tree/tests/test_helpers/assertions.rs
new file mode 100644
index 0000000000..ca6a1ffb4e
--- /dev/null
+++ b/program-libs/batched-merkle-tree/tests/test_helpers/assertions.rs
@@ -0,0 +1,125 @@
+use light_batched_merkle_tree::errors::BatchedMerkleTreeError;
+use light_merkle_tree_metadata::errors::MerkleTreeMetadataError;
+use light_zero_copy::errors::ZeroCopyError;
+use std::fmt::Debug;
+
+/// Assert that a result is an error and matches the expected error.
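+///
+/// A hedged sketch (mirrors the unit tests at the bottom of this file):
+///
+/// ```ignore
+/// let result: Result<(), BatchedMerkleTreeError> = Err(BatchedMerkleTreeError::InvalidIndex);
+/// assert_error(result, BatchedMerkleTreeError::InvalidIndex, "context");
+/// ```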
+pub fn assert_error<T, E>(result: Result<T, E>, expected: E, context: &str)
+where
+    T: Debug,
+    E: Debug + PartialEq,
+{
+    match result {
+        Ok(val) => panic!(
+            "{}: Expected error {:?}, but got Ok({:?})",
+            context, expected, val
+        ),
+        Err(actual) => assert_eq!(
+            actual, expected,
+            "{}: Error mismatch. Expected {:?}, got {:?}",
+            context, expected, actual
+        ),
+    }
+}
+
+/// Assert that a result is a ZeroCopyError.
+pub fn assert_zerocopy_error<T>(result: Result<T, BatchedMerkleTreeError>, context: &str)
+where
+    T: Debug,
+{
+    match result {
+        Ok(val) => panic!("{}: Expected ZeroCopyError, but got Ok({:?})", context, val),
+        Err(BatchedMerkleTreeError::ZeroCopy(_)) => {
+            // Success - it's a ZeroCopy error.
+        }
+        Err(other) => panic!("{}: Expected ZeroCopyError, but got {:?}", context, other),
+    }
+}
+
+/// Assert that a result is a MerkleTreeMetadataError with the specific type.
+pub fn assert_metadata_error<T>(
+    result: Result<T, BatchedMerkleTreeError>,
+    expected: MerkleTreeMetadataError,
+    context: &str,
+) where
+    T: Debug,
+{
+    match result {
+        Ok(val) => panic!(
+            "{}: Expected MerkleTreeMetadataError::{:?}, but got Ok({:?})",
+            context, expected, val
+        ),
+        Err(BatchedMerkleTreeError::MerkleTreeMetadata(actual)) => {
+            assert_eq!(
+                actual, expected,
+                "{}: MerkleTreeMetadataError mismatch. Expected {:?}, got {:?}",
+                context, expected, actual
+            );
+        }
+        Err(other) => panic!(
+            "{}: Expected MerkleTreeMetadataError::{:?}, but got {:?}",
+            context, expected, other
+        ),
+    }
+}
+
+/// Assert that a result is an AccountError.
+pub fn assert_account_error<T>(result: Result<T, BatchedMerkleTreeError>, context: &str)
+where
+    T: Debug,
+{
+    match result {
+        Ok(val) => panic!("{}: Expected AccountError, but got Ok({:?})", context, val),
+        Err(BatchedMerkleTreeError::AccountError(_)) => {
+            // Success - it's an AccountError.
+        }
+        Err(other) => panic!("{}: Expected AccountError, but got {:?}", context, other),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use light_account_checks::error::AccountError;
+
+    #[test]
+    fn test_assert_error_catches_mismatch() {
+        let result: Result<(), BatchedMerkleTreeError> =
+            Err(BatchedMerkleTreeError::InvalidIndex);
+        assert_error(
+            result,
+            BatchedMerkleTreeError::InvalidIndex,
+            "Should match",
+        );
+    }
+
+    #[test]
+    fn test_assert_zerocopy_error() {
+        let result: Result<(), BatchedMerkleTreeError> =
+            Err(BatchedMerkleTreeError::ZeroCopy(ZeroCopyError::Size));
+        assert_zerocopy_error(result, "Should be ZeroCopy error");
+    }
+
+    #[test]
+    fn test_assert_metadata_error() {
+        let result: Result<(), BatchedMerkleTreeError> = Err(
+            BatchedMerkleTreeError::MerkleTreeMetadata(MerkleTreeMetadataError::InvalidTreeType),
+        );
+        assert_metadata_error(
+            result,
+            MerkleTreeMetadataError::InvalidTreeType,
+            "Should match InvalidTreeType",
+        );
+    }
+
+    #[test]
+    fn test_assert_account_error() {
+        let result: Result<(), BatchedMerkleTreeError> = Err(
+            BatchedMerkleTreeError::AccountError(AccountError::InvalidDiscriminator),
+        );
+        assert_account_error(result, "Should be AccountError");
+    }
+}
diff --git a/program-libs/batched-merkle-tree/tests/test_helpers/mod.rs b/program-libs/batched-merkle-tree/tests/test_helpers/mod.rs
new file mode 100644
index 0000000000..d2dd1554be
--- /dev/null
+++ b/program-libs/batched-merkle-tree/tests/test_helpers/mod.rs
@@ -0,0 +1,2 @@
+pub mod account_builders;
+pub mod assertions;
diff --git a/program-libs/bloom-filter/src/lib.rs b/program-libs/bloom-filter/src/lib.rs
index 389c70086e..e3363843fb 100644
--- a/program-libs/bloom-filter/src/lib.rs
+++ b/program-libs/bloom-filter/src/lib.rs
@@ -129,6 +129,45 @@ impl<'a> BloomFilter<'a> {
     }
 }
 
+/// Immutable bloom filter reference for read-only access.
+///
+/// Uses `&'a [u8]` instead of `&'a mut [u8]` for the store,
+/// enabling shared borrows of account data.
+pub struct BloomFilterRef<'a> {
+    pub num_iters: usize,
+    pub capacity: u64,
+    pub store: &'a [u8],
+}
+
+impl<'a> BloomFilterRef<'a> {
+    pub fn new(
+        num_iters: usize,
+        capacity: u64,
+        store: &'a [u8],
+    ) -> Result<Self, BloomFilterError> {
+        if store.len() * 8 != capacity as usize {
+            return Err(BloomFilterError::InvalidStoreCapacity);
+        }
+        Ok(Self {
+            num_iters,
+            capacity,
+            store,
+        })
+    }
+
+    pub fn contains(&self, value: &[u8; 32]) -> bool {
+        use bitvec::prelude::*;
+        let bits = BitSlice::<u8, Msb0>::from_slice(self.store);
+        for i in 0..self.num_iters {
+            let probe_index = BloomFilter::probe_index_keccak(value, i, &self.capacity);
+            if !bits[probe_index] {
+                return false;
+            }
+        }
+        true
+    }
+}
+
 #[cfg(test)]
 mod test {
     use light_hasher::bigint::bigint_to_be_bytes_array;
@@ -280,4 +319,23 @@ mod test {
             }
         }
     }
+
+    #[test]
+    fn test_bloom_filter_ref() {
+        let capacity = 128_000u64 * 8;
+        let mut store = vec![0u8; 128_000];
+        let value1 = [1u8; 32];
+        let value2 = [2u8; 32];
+
+        // Insert via mutable BloomFilter.
+        {
+            let mut bf = BloomFilter::new(3, capacity, &mut store).unwrap();
+            bf.insert(&value1).unwrap();
+        }
+
+        // Read via immutable BloomFilterRef.
+        let bf_ref = BloomFilterRef::new(3, capacity, &store).unwrap();
+        assert!(bf_ref.contains(&value1));
+        assert!(!bf_ref.contains(&value2));
+    }
 }
diff --git a/program-libs/zero-copy/src/cyclic_vec.rs b/program-libs/zero-copy/src/cyclic_vec.rs
index 80815ffe01..aa2108bb8b 100644
--- a/program-libs/zero-copy/src/cyclic_vec.rs
+++ b/program-libs/zero-copy/src/cyclic_vec.rs
@@ -7,76 +7,51 @@ use core::{
 };
 #[cfg(feature = "std")]
 use std::vec::Vec;
-use zerocopy::{little_endian::U32, Ref};
+use zerocopy::{
+    byte_slice::{ByteSliceMut, SplitByteSlice, SplitByteSliceMut},
+    little_endian::U32,
+    Ref,
+};
 
 use crate::{add_padding, errors::ZeroCopyError, ZeroCopyTraits};
 
-pub type ZeroCopyCyclicVecU32<'a, T> = ZeroCopyCyclicVec<'a, u32, T>;
-pub type ZeroCopyCyclicVecU64<'a, T> = ZeroCopyCyclicVec<'a, u64, T>;
-pub type ZeroCopyCyclicVecU16<'a, T> = ZeroCopyCyclicVec<'a, u16, T>;
-pub type ZeroCopyCyclicVecU8<'a, T> = ZeroCopyCyclicVec<'a, u8, T>;
-pub type ZeroCopyCyclicVecBorsh<'a, T> = ZeroCopyCyclicVec<'a, U32, T>;
+/// Mutable aliases (existing API).
+pub type ZeroCopyCyclicVecU32<'a, T> = ZeroCopyCyclicVec<&'a mut [u8], u32, T>;
+pub type ZeroCopyCyclicVecU64<'a, T> = ZeroCopyCyclicVec<&'a mut [u8], u64, T>;
+pub type ZeroCopyCyclicVecU16<'a, T> = ZeroCopyCyclicVec<&'a mut [u8], u16, T>;
+pub type ZeroCopyCyclicVecU8<'a, T> = ZeroCopyCyclicVec<&'a mut [u8], u8, T>;
+pub type ZeroCopyCyclicVecBorsh<'a, T> = ZeroCopyCyclicVec<&'a mut [u8], U32, T>;
+
+/// Immutable aliases.
+pub type ZeroCopyCyclicVecRefU64<'a, T> = ZeroCopyCyclicVec<&'a [u8], u64, T>;
 
-pub struct ZeroCopyCyclicVec<'a, L, T, const PAD: bool = true>
+pub struct ZeroCopyCyclicVec<B, L, T, const PAD: bool = true>
 where
     L: ZeroCopyTraits,
     T: ZeroCopyTraits,
     u64: From<L> + TryInto<L>,
 {
     /// [current_index, length, capacity]
-    metadata: Ref<&'a mut [u8], [L; 3]>,
-    slice: Ref<&'a mut [u8], [T]>,
+    metadata: Ref<B, [L; 3]>,
+    slice: Ref<B, [T]>,
 }
 
 const CURRENT_INDEX_INDEX: usize = 0;
 const LENGTH_INDEX: usize = 1;
 const CAPACITY_INDEX: usize = 2;
 
-impl<'a, L, T, const PAD: bool> ZeroCopyCyclicVec<'a, L, T, PAD>
+// ---------------------------------------------------------------------------
+// Read-only methods (available for both &[u8] and &mut [u8]).
+// ---------------------------------------------------------------------------
+impl<B, L, T, const PAD: bool> ZeroCopyCyclicVec<B, L, T, PAD>
 where
+    B: SplitByteSlice,
     L: ZeroCopyTraits,
     T: ZeroCopyTraits,
     u64: From<L> + TryInto<L>,
 {
-    pub fn new(capacity: L, bytes: &'a mut [u8]) -> Result<Self, ZeroCopyError> {
-        Ok(Self::new_at(capacity, bytes)?.0)
-    }
-
-    pub fn new_at(capacity: L, bytes: &'a mut [u8]) -> Result<(Self, &'a mut [u8]), ZeroCopyError> {
-        if u64::from(capacity) == 0 {
-            return Err(ZeroCopyError::InvalidCapacity);
-        }
-        let metadata_size = Self::metadata_size();
-        if bytes.len() < metadata_size {
-            return Err(ZeroCopyError::InsufficientMemoryAllocated(
-                bytes.len(),
-                metadata_size,
-            ));
-        }
-        let (meta_data, bytes) = bytes.split_at_mut(metadata_size);
-
-        let (mut metadata, _padding) = Ref::<&mut [u8], [L; 3]>::from_prefix(meta_data)?;
-
-        if u64::from(metadata[LENGTH_INDEX]) != 0
-            || u64::from(metadata[CURRENT_INDEX_INDEX]) != 0
-            || u64::from(metadata[CAPACITY_INDEX]) != 0
-        {
-            return Err(ZeroCopyError::MemoryNotZeroed);
-        }
-        metadata[CAPACITY_INDEX] = capacity;
-        let capacity_usize: usize = u64::from(metadata[CAPACITY_INDEX]) as usize;
-
-        let (slice, remaining_bytes) =
-            Ref::<&mut [u8], [T]>::from_prefix_with_elems(bytes, capacity_usize)?;
-        Ok((Self { metadata, slice }, remaining_bytes))
-    }
-
-    pub fn from_bytes(bytes: &'a mut [u8]) -> Result<Self, ZeroCopyError> {
-        Ok(Self::from_bytes_at(bytes)?.0)
-    }
-
     #[inline]
-    pub fn from_bytes_at(bytes: &'a mut [u8]) -> Result<(Self, &'a mut [u8]), ZeroCopyError> {
+    pub fn from_bytes_at(bytes: B) -> Result<(Self, B), ZeroCopyError> {
         let metadata_size = Self::metadata_size();
         if bytes.len() < metadata_size {
             return Err(ZeroCopyError::InsufficientMemoryAllocated(
@@ -85,8 +60,10 @@ where
             ));
         }
 
-        let (meta_data, bytes) = bytes.split_at_mut(metadata_size);
-        let (metadata, _padding) = Ref::<&mut [u8], [L; 3]>::from_prefix(meta_data)?;
+        let (meta_data, bytes) = bytes.split_at(metadata_size).map_err(|_| {
+            ZeroCopyError::InsufficientMemoryAllocated(0, metadata_size)
+        })?;
+        let (metadata, _padding) = Ref::<B, [L; 3]>::from_prefix(meta_data)?;
         let usize_capacity: usize = u64::from(metadata[CAPACITY_INDEX]) as usize;
         let usize_len: usize = u64::from(metadata[LENGTH_INDEX]) as usize;
         let usize_current_index: usize = u64::from(metadata[CURRENT_INDEX_INDEX]) as usize;
@@ -107,67 +84,33 @@ where
             ));
         }
         let (slice, remaining_bytes) =
-            Ref::<&mut [u8], [T]>::from_prefix_with_elems(bytes, usize_capacity)?;
+            Ref::<B, [T]>::from_prefix_with_elems(bytes, usize_capacity)?;
         Ok((Self { metadata, slice }, remaining_bytes))
     }
 
-    /// Convenience method to get the current index of the vector.
     #[inline]
-    fn get_current_index(&self) -> L {
-        self.metadata[CURRENT_INDEX_INDEX]
+    pub fn from_bytes(bytes: B) -> Result<Self, ZeroCopyError> {
+        Ok(Self::from_bytes_at(bytes)?.0)
     }
 
-    /// Convenience method to get the current index of the vector.
     #[inline]
-    fn get_current_index_mut(&mut self) -> &mut L {
-        &mut self.metadata[CURRENT_INDEX_INDEX]
+    fn get_current_index(&self) -> L {
+        self.metadata[CURRENT_INDEX_INDEX]
     }
 
-    /// Convenience method to get the length of the vector.
     #[inline]
     fn get_len(&self) -> L {
         self.metadata[LENGTH_INDEX]
     }
 
-    /// Convenience method to get the length of the vector.
-    #[inline]
-    fn get_len_mut(&mut self) -> &mut L {
-        &mut self.metadata[LENGTH_INDEX]
-    }
-
-    /// Convenience method to get the capacity of the vector.
#[inline] fn get_capacity(&self) -> L { self.metadata[CAPACITY_INDEX] } #[inline] - pub fn push(&mut self, value: T) { - if self.len() < self.capacity() { - let len = self.len(); - self.slice[len] = value; - *self.get_len_mut() = (len as u64 + 1u64) - .try_into() - .map_err(|_| ZeroCopyError::InvalidConversion) - .unwrap(); - } else { - let current_index = self.current_index(); - self.slice[current_index] = value; - } - let new_index = (self.current_index() + 1) % self.capacity(); - *self.get_current_index_mut() = (new_index as u64) - .try_into() - .map_err(|_| ZeroCopyError::InvalidConversion) - .unwrap(); - } - - #[inline] - pub fn clear(&mut self) { - *self.get_current_index_mut() = 0 - .try_into() - .map_err(|_| ZeroCopyError::InvalidConversion) - .unwrap(); - *self.get_len_mut() = self.get_current_index(); + fn current_index(&self) -> usize { + u64::from(self.get_current_index()) as usize } #[inline] @@ -175,26 +118,11 @@ where self.get(self.first_index()) } - #[inline] - pub fn first_mut(&mut self) -> Option<&mut T> { - self.get_mut(self.first_index()) - } - #[inline] pub fn last(&self) -> Option<&T> { self.get(self.last_index()) } - #[inline] - pub fn last_mut(&mut self) -> Option<&mut T> { - self.get_mut(self.last_index()) - } - - #[inline] - fn current_index(&self) -> usize { - u64::from(self.get_current_index()) as usize - } - /// First index is the next index after the last index mod capacity. #[inline] pub fn first_index(&self) -> usize { @@ -214,32 +142,6 @@ where } } - #[inline] - pub fn iter(&self) -> ZeroCopyCyclicVecIterator<'_, L, T, PAD> { - ZeroCopyCyclicVecIterator { - vec: self, - current: self.first_index(), - is_finished: false, - _marker: PhantomData, - } - } - - #[inline] - pub fn iter_from( - &self, - start: usize, - ) -> Result, ZeroCopyError> { - if start >= self.len() { - return Err(ZeroCopyError::IterFromOutOfBounds); - } - Ok(ZeroCopyCyclicVecIterator { - vec: self, - current: start, - is_finished: false, - _marker: PhantomData, - }) - } - #[inline] pub fn metadata_size() -> usize { let mut size = size_of::<[L; 3]>(); @@ -282,6 +184,131 @@ where Some(&self.slice[index]) } + #[inline] + pub fn as_slice(&self) -> &[T] { + &self.slice[..self.len()] + } + + #[cfg(feature = "std")] + pub fn try_into_array(&self) -> Result<[T; N], ZeroCopyError> { + if self.len() != N { + return Err(ZeroCopyError::ArraySize(N, self.len())); + } + Ok(core::array::from_fn(|i| *self.get(i).unwrap())) + } + + #[cfg(feature = "std")] + #[inline] + pub fn to_vec(&self) -> Vec { + self.as_slice().to_vec() + } +} + +// --------------------------------------------------------------------------- +// Mutable construction (only &mut [u8]). 
+// --------------------------------------------------------------------------- +impl ZeroCopyCyclicVec +where + B: SplitByteSliceMut, + L: ZeroCopyTraits, + T: ZeroCopyTraits, + u64: From + TryInto, +{ + pub fn new(capacity: L, bytes: B) -> Result { + Ok(Self::new_at(capacity, bytes)?.0) + } + + pub fn new_at(capacity: L, bytes: B) -> Result<(Self, B), ZeroCopyError> { + if u64::from(capacity) == 0 { + return Err(ZeroCopyError::InvalidCapacity); + } + let metadata_size = Self::metadata_size(); + if bytes.len() < metadata_size { + return Err(ZeroCopyError::InsufficientMemoryAllocated( + bytes.len(), + metadata_size, + )); + } + let (meta_data, bytes) = bytes.split_at(metadata_size).map_err(|_| { + ZeroCopyError::InsufficientMemoryAllocated(0, metadata_size) + })?; + + let (mut metadata, _padding) = Ref::::from_prefix(meta_data)?; + + if u64::from(metadata[LENGTH_INDEX]) != 0 + || u64::from(metadata[CURRENT_INDEX_INDEX]) != 0 + || u64::from(metadata[CAPACITY_INDEX]) != 0 + { + return Err(ZeroCopyError::MemoryNotZeroed); + } + metadata[CAPACITY_INDEX] = capacity; + let capacity_usize: usize = u64::from(metadata[CAPACITY_INDEX]) as usize; + + let (slice, remaining_bytes) = + Ref::::from_prefix_with_elems(bytes, capacity_usize)?; + Ok((Self { metadata, slice }, remaining_bytes)) + } +} + +// --------------------------------------------------------------------------- +// Mutable access methods (only &mut [u8]). +// --------------------------------------------------------------------------- +impl ZeroCopyCyclicVec +where + B: ByteSliceMut + SplitByteSlice, + L: ZeroCopyTraits, + T: ZeroCopyTraits, + u64: From + TryInto, +{ + #[inline] + fn get_current_index_mut(&mut self) -> &mut L { + &mut self.metadata[CURRENT_INDEX_INDEX] + } + + #[inline] + fn get_len_mut(&mut self) -> &mut L { + &mut self.metadata[LENGTH_INDEX] + } + + #[inline] + pub fn push(&mut self, value: T) { + if self.len() < self.capacity() { + let len = self.len(); + self.slice[len] = value; + *self.get_len_mut() = (len as u64 + 1u64) + .try_into() + .map_err(|_| ZeroCopyError::InvalidConversion) + .unwrap(); + } else { + let current_index = self.current_index(); + self.slice[current_index] = value; + } + let new_index = (self.current_index() + 1) % self.capacity(); + *self.get_current_index_mut() = (new_index as u64) + .try_into() + .map_err(|_| ZeroCopyError::InvalidConversion) + .unwrap(); + } + + #[inline] + pub fn clear(&mut self) { + *self.get_current_index_mut() = 0 + .try_into() + .map_err(|_| ZeroCopyError::InvalidConversion) + .unwrap(); + *self.get_len_mut() = self.get_current_index(); + } + + #[inline] + pub fn first_mut(&mut self) -> Option<&mut T> { + self.get_mut(self.first_index()) + } + + #[inline] + pub fn last_mut(&mut self) -> Option<&mut T> { + self.get_mut(self.last_index()) + } + #[inline] pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { if index >= self.len() { @@ -290,46 +317,66 @@ where Some(&mut self.slice[index]) } - #[inline] - pub fn as_slice(&self) -> &[T] { - &self.slice[..self.len()] - } - #[inline] pub fn as_mut_slice(&mut self) -> &mut [T] { let len = self.len(); &mut self.slice[..len] } +} - #[cfg(feature = "std")] - pub fn try_into_array(&self) -> Result<[T; N], ZeroCopyError> { - if self.len() != N { - return Err(ZeroCopyError::ArraySize(N, self.len())); +// --------------------------------------------------------------------------- +// Iterator (read-only). 
+// --------------------------------------------------------------------------- +impl ZeroCopyCyclicVec +where + B: SplitByteSlice, + L: ZeroCopyTraits, + T: ZeroCopyTraits, + u64: From + TryInto, +{ + #[inline] + pub fn iter(&self) -> ZeroCopyCyclicVecIterator<'_, B, L, T, PAD> { + ZeroCopyCyclicVecIterator { + vec: self, + current: self.first_index(), + is_finished: false, + _marker: PhantomData, } - Ok(core::array::from_fn(|i| *self.get(i).unwrap())) } - #[cfg(feature = "std")] #[inline] - pub fn to_vec(&self) -> Vec { - self.as_slice().to_vec() + pub fn iter_from( + &self, + start: usize, + ) -> Result, ZeroCopyError> { + if start >= self.len() { + return Err(ZeroCopyError::IterFromOutOfBounds); + } + Ok(ZeroCopyCyclicVecIterator { + vec: self, + current: start, + is_finished: false, + _marker: PhantomData, + }) } } -pub struct ZeroCopyCyclicVecIterator<'a, L, T, const PAD: bool> +pub struct ZeroCopyCyclicVecIterator<'a, B, L, T, const PAD: bool> where + B: SplitByteSlice, L: ZeroCopyTraits, T: ZeroCopyTraits, u64: From + TryInto, { - vec: &'a ZeroCopyCyclicVec<'a, L, T, PAD>, + vec: &'a ZeroCopyCyclicVec, current: usize, is_finished: bool, _marker: PhantomData, } -impl<'a, L, T, const PAD: bool> Iterator for ZeroCopyCyclicVecIterator<'a, L, T, PAD> +impl<'a, B, L, T, const PAD: bool> Iterator for ZeroCopyCyclicVecIterator<'a, B, L, T, PAD> where + B: SplitByteSlice, L: ZeroCopyTraits, T: ZeroCopyTraits, u64: From + TryInto, @@ -353,21 +400,25 @@ where } } -impl IndexMut for ZeroCopyCyclicVec<'_, L, T, PAD> +// --------------------------------------------------------------------------- +// Index / IndexMut / trait impls. +// --------------------------------------------------------------------------- +impl IndexMut for ZeroCopyCyclicVec where + B: ByteSliceMut + SplitByteSlice, L: ZeroCopyTraits, T: ZeroCopyTraits, u64: From + TryInto, { #[inline] fn index_mut(&mut self, index: usize) -> &mut Self::Output { - // Access the underlying mutable slice using as_mut_slice() and index it &mut self.as_mut_slice()[index] } } -impl Index for ZeroCopyCyclicVec<'_, L, T, PAD> +impl Index for ZeroCopyCyclicVec where + B: SplitByteSlice, L: ZeroCopyTraits, T: ZeroCopyTraits, u64: From + TryInto, @@ -376,13 +427,13 @@ where #[inline] fn index(&self, index: usize) -> &Self::Output { - // Access the underlying slice using as_slice() and index it &self.as_slice()[index] } } -impl PartialEq for ZeroCopyCyclicVec<'_, L, T, PAD> +impl PartialEq for ZeroCopyCyclicVec where + B: SplitByteSlice, L: ZeroCopyTraits + PartialEq, T: ZeroCopyTraits + PartialEq, u64: From + TryInto, @@ -393,8 +444,9 @@ where } } -impl fmt::Debug for ZeroCopyCyclicVec<'_, L, T, PAD> +impl fmt::Debug for ZeroCopyCyclicVec where + B: SplitByteSlice, L: ZeroCopyTraits, T: ZeroCopyTraits + Debug, u64: From + TryInto, @@ -408,7 +460,7 @@ where #[test] fn test_private_getters() { let mut backing_store = [0u8; 64]; - let mut zcv = ZeroCopyCyclicVecU16::::new(5, &mut backing_store).unwrap(); + let mut zcv = ZeroCopyCyclicVecU16::::new(5, &mut backing_store[..]).unwrap(); assert_eq!(zcv.get_len(), 0); assert_eq!(zcv.get_capacity(), 5); for i in 0..5 { diff --git a/program-libs/zero-copy/tests/cyclic_vec_tests.rs b/program-libs/zero-copy/tests/cyclic_vec_tests.rs index 6fbee20880..a28785a998 100644 --- a/program-libs/zero-copy/tests/cyclic_vec_tests.rs +++ b/program-libs/zero-copy/tests/cyclic_vec_tests.rs @@ -581,42 +581,42 @@ fn test_init_pass() { #[test] fn test_metadata_size() { - assert_eq!(ZeroCopyCyclicVec::::metadata_size(), 3); - 
diff --git a/program-libs/zero-copy/tests/cyclic_vec_tests.rs b/program-libs/zero-copy/tests/cyclic_vec_tests.rs
index 6fbee20880..a28785a998 100644
--- a/program-libs/zero-copy/tests/cyclic_vec_tests.rs
+++ b/program-libs/zero-copy/tests/cyclic_vec_tests.rs
@@ -581,42 +581,42 @@ fn test_init_pass() {
 #[test]
 fn test_metadata_size() {
-    assert_eq!(ZeroCopyCyclicVec::<u8, u8>::metadata_size(), 3);
-    assert_eq!(ZeroCopyCyclicVec::<u16, u8>::metadata_size(), 6);
-    assert_eq!(ZeroCopyCyclicVec::<u32, u8>::metadata_size(), 12);
-    assert_eq!(ZeroCopyCyclicVec::<u64, u8>::metadata_size(), 24);
-
-    assert_eq!(ZeroCopyCyclicVec::<u8, u16>::metadata_size(), 3);
-    assert_eq!(ZeroCopyCyclicVec::<u16, u16>::metadata_size(), 6);
-    assert_eq!(ZeroCopyCyclicVec::<u32, u16>::metadata_size(), 12);
-    assert_eq!(ZeroCopyCyclicVec::<u64, u16>::metadata_size(), 24);
-
-    assert_eq!(ZeroCopyCyclicVec::<u8, u32>::metadata_size(), 4);
-    assert_eq!(ZeroCopyCyclicVec::<u16, u32>::metadata_size(), 6);
-    assert_eq!(ZeroCopyCyclicVec::<u32, u32>::metadata_size(), 12);
-    assert_eq!(ZeroCopyCyclicVec::<u64, u32>::metadata_size(), 24);
-
-    assert_eq!(ZeroCopyCyclicVec::<u8, u64>::metadata_size(), 8);
-    assert_eq!(ZeroCopyCyclicVec::<u16, u64>::metadata_size(), 8);
-    assert_eq!(ZeroCopyCyclicVec::<u32, u64>::metadata_size(), 12);
-    assert_eq!(ZeroCopyCyclicVec::<u64, u64>::metadata_size(), 24);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u8, u8>::metadata_size(), 3);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u16, u8>::metadata_size(), 6);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u32, u8>::metadata_size(), 12);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u64, u8>::metadata_size(), 24);
+
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u8, u16>::metadata_size(), 3);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u16, u16>::metadata_size(), 6);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u32, u16>::metadata_size(), 12);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u64, u16>::metadata_size(), 24);
+
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u8, u32>::metadata_size(), 4);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u16, u32>::metadata_size(), 6);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u32, u32>::metadata_size(), 12);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u64, u32>::metadata_size(), 24);
+
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u8, u64>::metadata_size(), 8);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u16, u64>::metadata_size(), 8);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u32, u64>::metadata_size(), 12);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u64, u64>::metadata_size(), 24);
 }
 
 #[test]
 fn test_data_size() {
-    assert_eq!(ZeroCopyCyclicVec::<u8, u8>::data_size(64), 64);
+    assert_eq!(ZeroCopyCyclicVec::<&mut [u8], u8, u8>::data_size(64), 64);
 }
 
 #[test]
 fn test_required_size() {
     // current index + length + capacity + data
     assert_eq!(
-        ZeroCopyCyclicVec::<u8, u8>::required_size_for_capacity(64),
+        ZeroCopyCyclicVec::<&mut [u8], u8, u8>::required_size_for_capacity(64),
         1 + 1 + 1 + 64
     );
     // current index + length + capacity + data
     assert_eq!(
-        ZeroCopyCyclicVec::<u64, u64>::required_size_for_capacity(64),
+        ZeroCopyCyclicVec::<&mut [u8], u64, u64>::required_size_for_capacity(64),
         8 + 8 + 8 + 8 * 64
     );
 }
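Since callers pre-allocate account space, the sizing rule these tests pin down deserves a worked instance. The helper below is hypothetical; it only mirrors the arithmetic of `required_size_for_capacity` for L = u64 and T = [u8; 32], the layout the batched tree's root history uses.

use core::mem::size_of;

// Hypothetical mirror of required_size_for_capacity with L = u64, T = [u8; 32]:
// 3 metadata fields (current_index, length, capacity) plus element storage.
fn root_history_required_size(capacity: usize) -> usize {
    3 * size_of::<u64>() + capacity * size_of::<[u8; 32]>()
}

#[test]
fn root_history_sizing() {
    // A 64-root history needs 24 + 64 * 32 = 2072 bytes of account data.
    assert_eq!(root_history_required_size(64), 2072);
}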
diff --git a/programs/registry/src/lib.rs b/programs/registry/src/lib.rs
index 201d78ff16..a1a351d205 100644
--- a/programs/registry/src/lib.rs
+++ b/programs/registry/src/lib.rs
@@ -37,7 +37,7 @@ use errors::RegistryError;
 use light_batched_merkle_tree::{
     initialize_address_tree::InitAddressTreeAccountsInstructionData,
     initialize_state_tree::InitStateTreeAccountsInstructionData,
-    merkle_tree::BatchedMerkleTreeAccount, queue::BatchedQueueAccount,
+    merkle_tree_ref::BatchedMerkleTreeRef, queue_ref::BatchedQueueRef,
 };
 use light_compressible::registry_instructions::CreateCompressibleConfig as CreateCompressibleConfigData;
 use protocol_config::state::ProtocolConfig;
@@ -536,7 +536,7 @@ pub mod light_registry {
         data: Vec<u8>,
     ) -> Result<()> {
         let merkle_tree =
-            BatchedMerkleTreeAccount::state_from_account_info(&ctx.accounts.merkle_tree)
+            BatchedMerkleTreeRef::state_from_account_info(&ctx.accounts.merkle_tree)
                 .map_err(ProgramError::from)?;
         check_forester(
             &merkle_tree.metadata,
@@ -556,10 +556,10 @@ pub mod light_registry {
         data: Vec<u8>,
     ) -> Result<()> {
         let queue_account =
-            BatchedQueueAccount::output_from_account_info(&ctx.accounts.output_queue)
+            BatchedQueueRef::output_from_account_info(&ctx.accounts.output_queue)
                 .map_err(ProgramError::from)?;
         let merkle_tree =
-            BatchedMerkleTreeAccount::state_from_account_info(&ctx.accounts.merkle_tree)
+            BatchedMerkleTreeRef::state_from_account_info(&ctx.accounts.merkle_tree)
                 .map_err(ProgramError::from)?;
         // Eligibility is checked for the Merkle tree,
         // so that the same forester is eligible to
@@ -611,7 +611,7 @@ pub mod light_registry {
         data: Vec<u8>,
     ) -> Result<()> {
         let account =
-            BatchedMerkleTreeAccount::address_from_account_info(&ctx.accounts.merkle_tree)
+            BatchedMerkleTreeRef::address_from_account_info(&ctx.accounts.merkle_tree)
                 .map_err(ProgramError::from)?;
         check_forester(
             &account.metadata,
@@ -628,7 +628,7 @@ pub mod light_registry {
         ctx: Context<'_, '_, '_, 'info, RolloverBatchedAddressMerkleTree<'info>>,
         bump: u8,
     ) -> Result<()> {
-        let account = BatchedMerkleTreeAccount::address_from_account_info(
+        let account = BatchedMerkleTreeRef::address_from_account_info(
             &ctx.accounts.old_address_merkle_tree,
         )
         .map_err(ProgramError::from)?;
@@ -647,7 +647,7 @@ pub mod light_registry {
         bump: u8,
     ) -> Result<()> {
         let account =
-            BatchedMerkleTreeAccount::state_from_account_info(&ctx.accounts.old_state_merkle_tree)
+            BatchedMerkleTreeRef::state_from_account_info(&ctx.accounts.old_state_merkle_tree)
                 .map_err(ProgramError::from)?;
         check_forester(
             &account.metadata,
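One consequence of switching these handlers to the ref types is worth spelling out: because `BatchedMerkleTreeRef` and `BatchedQueueRef` take shared borrows of the account data, a handler can hold several views at once without borrow conflicts. A hedged sketch (the generic signature, function name, and shared error type are illustrative; the accessors and metadata fields are the ones this diff uses):

use light_account_checks::AccountInfoTrait;
use light_batched_merkle_tree::{
    errors::BatchedMerkleTreeError, merkle_tree_ref::BatchedMerkleTreeRef,
    queue_ref::BatchedQueueRef,
};

// Both views borrow their accounts immutably, leaving the underlying
// data free for any other reader in the same instruction.
fn read_tree_and_queue<A: AccountInfoTrait>(
    merkle_tree: &A,
    output_queue: &A,
) -> Result<(), BatchedMerkleTreeError> {
    let tree = BatchedMerkleTreeRef::state_from_account_info(merkle_tree)?;
    let queue = BatchedQueueRef::output_from_account_info(output_queue)?;
    // Read-only metadata access, e.g. for forester eligibility checks.
    let _program_owner = tree.metadata.access_metadata.program_owner;
    let _associated_tree = queue.metadata.associated_merkle_tree;
    Ok(())
}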
diff --git a/programs/system/src/accounts/remaining_account_checks.rs b/programs/system/src/accounts/remaining_account_checks.rs
index e5d1b1165c..58842a0c43 100644
--- a/programs/system/src/accounts/remaining_account_checks.rs
+++ b/programs/system/src/accounts/remaining_account_checks.rs
@@ -1,6 +1,9 @@
 use light_account_checks::{checks::check_owner, discriminator::Discriminator};
 use light_batched_merkle_tree::{
-    merkle_tree::BatchedMerkleTreeAccount, queue::BatchedQueueAccount,
+    merkle_tree::BatchedMerkleTreeAccount,
+    merkle_tree_ref::BatchedMerkleTreeRef,
+    queue::BatchedQueueAccount,
+    queue_ref::BatchedQueueRef,
 };
 use light_compressed_account::{
     constants::{
@@ -32,9 +35,9 @@ pub enum AcpAccount<'info> {
     Authority(&'info AccountInfo),
     RegisteredProgramPda(&'info AccountInfo),
     SystemProgram(&'info AccountInfo),
-    OutputQueue(BatchedQueueAccount<'info>),
-    BatchedStateTree(BatchedMerkleTreeAccount<'info>),
-    BatchedAddressTree(BatchedMerkleTreeAccount<'info>),
+    OutputQueue(BatchedQueueRef<'info>),
+    BatchedStateTree(BatchedMerkleTreeRef<'info>),
+    BatchedAddressTree(BatchedMerkleTreeRef<'info>),
     StateTree((Pubkey, ConcurrentMerkleTreeZeroCopyMut<'info, Poseidon, 26>)),
     AddressTree(
         (
@@ -90,13 +93,13 @@ pub(crate) fn try_from_account_info<'a, 'info: 'a>(
     let tree_type = TreeType::from(u64::from_le_bytes(tree_type));
     match tree_type {
         TreeType::AddressV2 => {
-            let tree = BatchedMerkleTreeAccount::address_from_account_info(account_info)?;
+            let tree = BatchedMerkleTreeRef::address_from_account_info(account_info)?;
             let program_owner = tree.metadata.access_metadata.program_owner;
             // for batched trees we set the fee when setting the rollover fee.
             Ok((AcpAccount::BatchedAddressTree(tree), program_owner))
         }
         TreeType::StateV2 => {
-            let tree = BatchedMerkleTreeAccount::state_from_account_info(account_info)?;
+            let tree = BatchedMerkleTreeRef::state_from_account_info(account_info)?;
             let program_owner = tree.metadata.access_metadata.program_owner;
             Ok((AcpAccount::BatchedStateTree(tree), program_owner))
         }
@@ -112,7 +115,7 @@ pub(crate) fn try_from_account_info<'a, 'info: 'a>(
         }
         BatchedQueueAccount::LIGHT_DISCRIMINATOR => {
-            let queue = BatchedQueueAccount::output_from_account_info(account_info)?;
+            let queue = BatchedQueueRef::output_from_account_info(account_info)?;
             let program_owner = queue.metadata.access_metadata.program_owner;
             Ok((AcpAccount::OutputQueue(queue), program_owner))
         }
diff --git a/programs/system/src/cpi_context/process_cpi_context.rs b/programs/system/src/cpi_context/process_cpi_context.rs
index 42503c9ac7..fa9f36fe86 100644
--- a/programs/system/src/cpi_context/process_cpi_context.rs
+++ b/programs/system/src/cpi_context/process_cpi_context.rs
@@ -1,5 +1,5 @@
 use light_account_checks::discriminator::Discriminator;
-use light_batched_merkle_tree::queue::BatchedQueueAccount;
+use light_batched_merkle_tree::{queue::BatchedQueueAccount, queue_ref::BatchedQueueRef};
 use light_compressed_account::{
     compressed_account::{CompressedAccountConfig, CompressedAccountDataConfig},
     instruction_data::{
@@ -218,7 +218,7 @@ fn validate_cpi_context_associated_with_merkle_tree<'a, 'info, T: InstructionDat
         == BatchedQueueAccount::LIGHT_DISCRIMINATOR_SLICE
     {
         let queue_account =
-            BatchedQueueAccount::output_from_account_info(&remaining_accounts[index as usize])?;
+            BatchedQueueRef::output_from_account_info(&remaining_accounts[index as usize])?;
         queue_account.metadata.associated_merkle_tree.to_bytes()
     } else {
         *remaining_accounts[index as usize].key()
diff --git a/programs/system/src/processor/process.rs b/programs/system/src/processor/process.rs
index 039382b109..9cd134bafd 100644
--- a/programs/system/src/processor/process.rs
+++ b/programs/system/src/processor/process.rs
@@ -116,7 +116,7 @@ pub fn process<
     )?;
 
     // 2. Deserialize and check all Merkle tree and queue accounts.
-    let mut accounts = try_from_account_infos(remaining_accounts, &mut context)?;
+    let accounts = try_from_account_infos(remaining_accounts, &mut context)?;
     // 3. Deserialize cpi instruction data as zero copy to fill it.
     let (mut cpi_ix_data, bytes) = InsertIntoQueuesInstructionDataMut::new_at(
         &mut cpi_ix_bytes[12..], // 8 bytes instruction discriminator + 4 bytes vector length
@@ -173,7 +173,7 @@ pub fn process<
 
     // 7. Verify read only address non-inclusion in bloom filters
     verify_read_only_address_queue_non_inclusion(
-        accounts.as_mut_slice(),
+        accounts.as_slice(),
         inputs.read_only_addresses().unwrap_or_default(),
     )?;
 
@@ -228,7 +228,7 @@ pub fn process<
 
     // 14. Verify read-only account inclusion by index ---------------------------------------------------
     let num_read_only_accounts_by_index =
-        verify_read_only_account_inclusion_by_index(accounts.as_mut_slice(), read_only_accounts)?;
+        verify_read_only_account_inclusion_by_index(accounts.as_slice(), read_only_accounts)?;
 
     // Get num of elements proven by zkp, for inclusion and non-inclusion.
     let num_inclusion_proof_inputs = {
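`AcpAccount` now hands out read-only tree and queue views, so downstream consumers can extract metadata from a shared borrow. A small illustrative match (the helper name is hypothetical; the variants and fields are the ones this diff reads):

use crate::accounts::remaining_account_checks::AcpAccount;

// Hypothetical helper: no &mut is needed anywhere to read access metadata.
fn program_owner_of(account: &AcpAccount<'_>) -> Option<[u8; 32]> {
    match account {
        AcpAccount::BatchedStateTree(tree) | AcpAccount::BatchedAddressTree(tree) => {
            Some(tree.metadata.access_metadata.program_owner.to_bytes())
        }
        AcpAccount::OutputQueue(queue) => {
            Some(queue.metadata.access_metadata.program_owner.to_bytes())
        }
        _ => None,
    }
}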
diff --git a/programs/system/src/processor/read_only_account.rs b/programs/system/src/processor/read_only_account.rs
index 7ff91cdecd..f10b10ca71 100644
--- a/programs/system/src/processor/read_only_account.rs
+++ b/programs/system/src/processor/read_only_account.rs
@@ -16,7 +16,7 @@ use crate::{
 #[inline(always)]
 #[profile]
 pub fn verify_read_only_account_inclusion_by_index(
-    accounts: &mut [AcpAccount<'_>],
+    accounts: &[AcpAccount<'_>],
     read_only_accounts: &[ZPackedReadOnlyCompressedAccount],
 ) -> Result<usize> {
     let mut num_prove_read_only_accounts_prove_by_index = 0;
diff --git a/programs/system/src/processor/read_only_address.rs b/programs/system/src/processor/read_only_address.rs
index 81812d4a30..c8465bf4ce 100644
--- a/programs/system/src/processor/read_only_address.rs
+++ b/programs/system/src/processor/read_only_address.rs
@@ -7,7 +7,7 @@ use crate::{accounts::remaining_account_checks::AcpAccount, errors::SystemProgra
 #[inline(always)]
 #[profile]
 pub fn verify_read_only_address_queue_non_inclusion(
-    remaining_accounts: &mut [AcpAccount<'_>],
+    remaining_accounts: &[AcpAccount<'_>],
     read_only_addresses: &[ZPackedReadOnlyAddress],
 ) -> Result<()> {
     if read_only_addresses.is_empty() {
@@ -15,7 +15,7 @@ pub fn verify_read_only_address_queue_non_inclusion(
     }
     for read_only_address in read_only_addresses.iter() {
         let merkle_tree = if let AcpAccount::BatchedAddressTree(tree) =
-            &mut remaining_accounts[read_only_address.address_merkle_tree_account_index as usize]
+            &remaining_accounts[read_only_address.address_merkle_tree_account_index as usize]
         {
             tree
         } else {
diff --git a/programs/system/src/processor/verify_proof.rs b/programs/system/src/processor/verify_proof.rs
index 07b19ae6bb..f33a35f2db 100644
--- a/programs/system/src/processor/verify_proof.rs
+++ b/programs/system/src/processor/verify_proof.rs
@@ -1,5 +1,6 @@
-use light_batched_merkle_tree::constants::{
-    DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, DEFAULT_BATCH_STATE_TREE_HEIGHT,
+use light_batched_merkle_tree::{
+    constants::{DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, DEFAULT_BATCH_STATE_TREE_HEIGHT},
+    errors::BatchedMerkleTreeError,
 };
 use light_compressed_account::{
     hash_chain::{create_hash_chain_from_slice, create_two_inputs_hash_chain},
@@ -130,12 +131,20 @@ fn read_root(
             (*roots).push(merkle_tree.roots[root_index as usize]);
         }
         AcpAccount::BatchedStateTree(merkle_tree) => {
-            (*roots).push(merkle_tree.root_history[root_index as usize]);
+            (*roots).push(
+                *merkle_tree
+                    .get_root_by_index(root_index as usize)
+                    .ok_or(BatchedMerkleTreeError::InvalidIndex)?,
+            );
             height = merkle_tree.height as u8;
         }
         AcpAccount::BatchedAddressTree(merkle_tree) => {
             height = merkle_tree.height as u8;
-            (*roots).push(merkle_tree.root_history[root_index as usize]);
+            (*roots).push(
+                *merkle_tree
+                    .get_root_by_index(root_index as usize)
+                    .ok_or(BatchedMerkleTreeError::InvalidIndex)?,
+            );
         }
         AcpAccount::StateTree((_, merkle_tree)) => {
             if IS_READ_ONLY {
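`get_root_by_index` turns an out-of-range proof root index into a handled error instead of an indexing panic. The method itself is introduced on the ref type by this patch; the standalone version below is only a sketch of its intended semantics over the root history slice:

use light_batched_merkle_tree::errors::BatchedMerkleTreeError;

// Illustrative stand-in: bounds-check the root history instead of
// panicking the way `root_history[root_index as usize]` did.
fn get_root_by_index(root_history: &[[u8; 32]], index: usize) -> Option<&[u8; 32]> {
    root_history.get(index)
}

fn read_root_checked(
    root_history: &[[u8; 32]],
    root_index: u16,
) -> Result<[u8; 32], BatchedMerkleTreeError> {
    Ok(*get_root_by_index(root_history, root_index as usize)
        .ok_or(BatchedMerkleTreeError::InvalidIndex)?)
}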
diff --git a/programs/system/src/utils.rs b/programs/system/src/utils.rs
index c1fa2f458d..81a5d83fc3 100644
--- a/programs/system/src/utils.rs
+++ b/programs/system/src/utils.rs
@@ -26,24 +26,11 @@ pub fn get_sol_pool_pda() -> Pubkey {
 #[profile]
 pub fn get_queue_and_tree_accounts<'b, 'info>(
-    accounts: &'b mut [AcpAccount<'info>],
+    accounts: &'b [AcpAccount<'info>],
     queue_index: usize,
     tree_index: usize,
-) -> std::result::Result<(&'b mut AcpAccount<'info>, &'b mut AcpAccount<'info>), SystemProgramError>
-{
-    let (smaller, bigger) = if queue_index < tree_index {
-        (queue_index, tree_index)
-    } else {
-        (tree_index, queue_index)
-    };
-    let (left, right) = accounts.split_at_mut(bigger);
-    let smaller_ref = &mut left[smaller];
-    let bigger_ref = &mut right[0];
-    Ok(if queue_index < tree_index {
-        (smaller_ref, bigger_ref)
-    } else {
-        (bigger_ref, smaller_ref)
-    })
+) -> std::result::Result<(&'b AcpAccount<'info>, &'b AcpAccount<'info>), SystemProgramError> {
+    Ok((&accounts[queue_index], &accounts[tree_index]))
 }
 
 pub fn transfer_lamports_invoke(
diff --git a/sdk-libs/token-sdk/src/instruction/create_mints.rs b/sdk-libs/token-sdk/src/instruction/create_mints.rs
index d1a9629cdf..43b6556a51 100644
--- a/sdk-libs/token-sdk/src/instruction/create_mints.rs
+++ b/sdk-libs/token-sdk/src/instruction/create_mints.rs
@@ -9,7 +9,7 @@
 //! - N=1: Single CPI (create + decompress)
 //! - N>1: 2N-1 CPIs (N-1 writes + 1 execute with decompress + N-1 decompress)
 
-use light_batched_merkle_tree::queue::BatchedQueueAccount;
+use light_batched_merkle_tree::queue_ref::BatchedQueueRef;
 use light_compressed_account::instruction_data::traits::LightInstructionData;
 use light_compressed_token_sdk::compressed_token::mint_action::{
     get_mint_action_instruction_account_metas_cpi_write, MintActionMetaConfig,
@@ -607,7 +607,7 @@ fn build_mint_instruction_data(
 /// Get base leaf index from output queue account.
 #[inline(never)]
 fn get_base_leaf_index(output_queue: &AccountInfo) -> Result<u32, ProgramError> {
-    let queue = BatchedQueueAccount::output_from_account_info(output_queue)
+    let queue = BatchedQueueRef::output_from_account_info(output_queue)
        .map_err(|_| ProgramError::InvalidAccountData)?;
     Ok(queue.batch_metadata.next_index as u32)
 }
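A closing remark on the rewritten `get_queue_and_tree_accounts`: with shared references there is nothing to split, so the old `split_at_mut` bookkeeping reduces to two plain index operations, and the equal-index case, which the split-based version could not express, now simply aliases. A hedged usage sketch (caller name and indices are illustrative; out-of-bounds indices still panic, exactly as before):

use crate::{
    accounts::remaining_account_checks::AcpAccount, errors::SystemProgramError,
    utils::get_queue_and_tree_accounts,
};

fn inspect_pair(accounts: &[AcpAccount<'_>]) -> Result<(), SystemProgramError> {
    // Shared borrows may even point at the same element, which the
    // previous &mut version could not hand out.
    let (queue, tree) = get_queue_and_tree_accounts(accounts, 0, 1)?;
    if let (AcpAccount::OutputQueue(q), AcpAccount::BatchedStateTree(t)) = (queue, tree) {
        let _ = q.metadata.associated_merkle_tree;
        let _ = t.metadata.access_metadata.program_owner;
    }
    Ok(())
}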