From 723c091b4ff786fda0de1605f9ddfbd14e0a4ed0 Mon Sep 17 00:00:00 2001 From: vanhger Date: Wed, 17 Sep 2025 21:17:52 +0700 Subject: [PATCH 01/12] feat: add base structure for aes128 precompile --- Cargo.lock | 1 + crates/core/executor/Cargo.toml | 1 + crates/core/executor/src/air.rs | 3 + .../executor/src/events/precompiles/aes128.rs | 45 ++++ .../executor/src/events/precompiles/mod.rs | 7 + crates/core/executor/src/syscalls/code.rs | 4 + crates/core/executor/src/syscalls/mod.rs | 3 + .../syscalls/precompiles/aes128/encrypt.rs | 219 ++++++++++++++++++ .../src/syscalls/precompiles/aes128/mod.rs | 2 + .../src/syscalls/precompiles/aes128/utils.rs | 18 ++ .../executor/src/syscalls/precompiles/mod.rs | 1 + crates/core/machine/src/mips/mod.rs | 8 + .../machine/src/operations/aes/aes_mul2.rs | 91 ++++++++ .../machine/src/operations/aes/aes_mul3.rs | 57 +++++ .../machine/src/operations/aes/mix_column.rs | 181 +++++++++++++++ crates/core/machine/src/operations/aes/mod.rs | 5 + .../machine/src/operations/aes/round_key.rs | 198 ++++++++++++++++ .../machine/src/operations/aes/xor_byte_4.rs | 91 ++++++++ crates/core/machine/src/operations/mod.rs | 3 +- crates/core/machine/src/runtime/syscall.rs | 4 + .../syscall/precompiles/aes128_encrypt/air.rs | 20 ++ .../precompiles/aes128_encrypt/columns.rs | 30 +++ .../syscall/precompiles/aes128_encrypt/mod.rs | 45 ++++ .../precompiles/aes128_encrypt/trace.rs | 96 ++++++++ .../machine/src/syscall/precompiles/mod.rs | 1 + crates/test-artifacts/guests/Cargo.toml | 1 + .../test-artifacts/guests/aes128/Cargo.toml | 8 + .../test-artifacts/guests/aes128/src/main.rs | 13 ++ crates/test-artifacts/src/lib.rs | 1 + crates/zkvm/entrypoint/src/syscalls/aes128.rs | 26 +++ crates/zkvm/entrypoint/src/syscalls/mod.rs | 5 + crates/zkvm/lib/src/aes128.rs | 47 ++++ crates/zkvm/lib/src/lib.rs | 4 + examples/Cargo.lock | 19 ++ examples/Cargo.toml | 2 + examples/aes128/guest/Cargo.toml | 9 + examples/aes128/guest/src/main.rs | 40 ++++ examples/aes128/host/Cargo.toml | 14 ++ examples/aes128/host/build.rs | 3 + examples/aes128/host/src/main.rs | 60 +++++ 40 files changed, 1385 insertions(+), 1 deletion(-) create mode 100644 crates/core/executor/src/events/precompiles/aes128.rs create mode 100644 crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs create mode 100644 crates/core/executor/src/syscalls/precompiles/aes128/mod.rs create mode 100644 crates/core/executor/src/syscalls/precompiles/aes128/utils.rs create mode 100644 crates/core/machine/src/operations/aes/aes_mul2.rs create mode 100644 crates/core/machine/src/operations/aes/aes_mul3.rs create mode 100644 crates/core/machine/src/operations/aes/mix_column.rs create mode 100644 crates/core/machine/src/operations/aes/mod.rs create mode 100644 crates/core/machine/src/operations/aes/round_key.rs create mode 100644 crates/core/machine/src/operations/aes/xor_byte_4.rs create mode 100644 crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs create mode 100644 crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs create mode 100644 crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs create mode 100644 crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs create mode 100644 crates/test-artifacts/guests/aes128/Cargo.toml create mode 100644 crates/test-artifacts/guests/aes128/src/main.rs create mode 100644 crates/zkvm/entrypoint/src/syscalls/aes128.rs create mode 100644 crates/zkvm/lib/src/aes128.rs create mode 100644 examples/aes128/guest/Cargo.toml create mode 100644 examples/aes128/guest/src/main.rs create mode 100644 examples/aes128/host/Cargo.toml create mode 100644 examples/aes128/host/build.rs create mode 100644 examples/aes128/host/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 06677f0a9..063318356 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8356,6 +8356,7 @@ dependencies = [ name = "zkm-core-executor" version = "1.2.0" dependencies = [ + "aes", "anyhow", "bincode", "bytemuck", diff --git a/crates/core/executor/Cargo.toml b/crates/core/executor/Cargo.toml index d03c7637c..95820350c 100644 --- a/crates/core/executor/Cargo.toml +++ b/crates/core/executor/Cargo.toml @@ -45,6 +45,7 @@ bytemuck = "1.16.3" tiny-keccak = { version = "2.0.2", features = ["keccak"] } vec_map = { version = "0.8.2", features = ["serde"] } enum-map = { version = "2.7.3", features = ["serde"] } +aes = "0.8.4" sha2 = { workspace = true } anyhow = { workspace = true } tracing-subscriber = "0.3.19" diff --git a/crates/core/executor/src/air.rs b/crates/core/executor/src/air.rs index 6af0b3bda..3b89a6836 100644 --- a/crates/core/executor/src/air.rs +++ b/crates/core/executor/src/air.rs @@ -41,6 +41,8 @@ pub enum MipsAirId { Secp256r1DoubleAssign = 11, /// The Poseidon2 Permute chip Poseidon2Permute = 46, + /// The AES-128 Encrypt chip + Aes128Encrypt = 47, /// The Keccak sponge chip. KeccakSponge = 48, /// The bn254 add assign chip. @@ -153,6 +155,7 @@ impl MipsAirId { Self::Secp256r1AddAssign => "Secp256r1AddAssign", Self::Secp256r1DoubleAssign => "Secp256r1DoubleAssign", Self::Poseidon2Permute => "Poseidon2Permute", + Self::Aes128Encrypt => "Aes128Encrypt", Self::KeccakSponge => "KeccakSponge", Self::Bn254AddAssign => "Bn254AddAssign", Self::Bn254DoubleAssign => "Bn254DoubleAssign", diff --git a/crates/core/executor/src/events/precompiles/aes128.rs b/crates/core/executor/src/events/precompiles/aes128.rs new file mode 100644 index 000000000..aecc727f6 --- /dev/null +++ b/crates/core/executor/src/events/precompiles/aes128.rs @@ -0,0 +1,45 @@ +use serde::{Deserialize, Serialize}; + +use crate::events::{ + memory::{MemoryReadRecord, MemoryWriteRecord}, + MemoryLocalEvent, +}; + +pub(crate) const AES_128_BLOCK_U32S: usize = 4; + +/// AES128 Encrypt Event +/// +/// This event is emitted when a AES128 encrypt operation is performed. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct AES128EncryptEvent { + /// The shard number. + pub shard: u32, + /// The clock cycle. + pub clk: u32, + /// The address of the block + pub block_addr: u32, + /// The address of the key + pub key_addr: u32, + /// The address of sbox + pub sbox_addr: u32, + /// The memory records for sbox address + pub sbox_addr_memory: MemoryReadRecord, + /// The input block as a [u32; AES_128_BLOCK_U32S] words. + pub input: [u32; AES_128_BLOCK_U32S], + /// The key as a [u32; AES_128_BLOCK_U32S] words. + pub key: [u32; AES_128_BLOCK_U32S], + /// The output block as a [u32; AES_128_BLOCK_U32S] words. + pub output: [u32; AES_128_BLOCK_U32S], + /// The sbox reads + pub sbox_reads: Vec, + /// The memory records for the input + pub input_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], + /// The memory records for the key + pub key_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], + /// The memory records for reading sbox + pub sbox_read_records: Vec, + /// The memory records for the output + pub output_write_records: [MemoryWriteRecord; AES_128_BLOCK_U32S], + /// The local memory access records. + pub local_mem_access: Vec, +} diff --git a/crates/core/executor/src/events/precompiles/mod.rs b/crates/core/executor/src/events/precompiles/mod.rs index e7ae5e239..1747b0bcc 100644 --- a/crates/core/executor/src/events/precompiles/mod.rs +++ b/crates/core/executor/src/events/precompiles/mod.rs @@ -8,9 +8,11 @@ mod sha256_compress; mod sha256_extend; mod u256x2048_mul; mod uint256; +mod aes128; use super::{MemoryLocalEvent, SyscallEvent}; use crate::syscalls::SyscallCode; +pub use aes128::*; pub use ec::*; pub use edwards::*; pub use fptower::*; @@ -34,6 +36,8 @@ pub enum PrecompileEvent { ShaCompress(ShaCompressEvent), /// Keccak sponge precompile event. KeccakSponge(KeccakSpongeEvent), + /// AES-128 encrypt precompile event. + Aes128Encrypt(AES128EncryptEvent), /// Edwards curve add precompile event. EdAdd(EllipticCurveAddEvent), /// Edwards curve decompress precompile event. @@ -105,6 +109,9 @@ impl PrecompileLocalMemory for Vec<(SyscallEvent, PrecompileEvent)> { PrecompileEvent::KeccakSponge(e) => { iterators.push(e.local_mem_access.iter()); } + PrecompileEvent::Aes128Encrypt(e) => { + iterators.push(e.local_mem_access.iter()); + } PrecompileEvent::EdDecompress(e) => { iterators.push(e.local_mem_access.iter()); } diff --git a/crates/core/executor/src/syscalls/code.rs b/crates/core/executor/src/syscalls/code.rs index b2bbe30d8..059e0fd55 100644 --- a/crates/core/executor/src/syscalls/code.rs +++ b/crates/core/executor/src/syscalls/code.rs @@ -169,6 +169,9 @@ pub enum SyscallCode { POSEIDON2_PERMUTE = 0x00_01_00_30, SYS_LINUX = 5000, + /// Executes the `AES128_ENCRYPT` precompile. + AES128_ENCRYPT = 0x01_01_00_31, + UNIMPLEMENTED = 0xFF_FF_FF_FF, } @@ -190,6 +193,7 @@ impl SyscallCode { 0x01_01_00_07 => SyscallCode::ED_ADD, 0x00_01_00_08 => SyscallCode::ED_DECOMPRESS, 0x01_01_00_09 => SyscallCode::KECCAK_SPONGE, + 0x01_01_00_31 => SyscallCode::AES128_ENCRYPT, 0x01_01_00_0A => SyscallCode::SECP256K1_ADD, 0x00_01_00_0B => SyscallCode::SECP256K1_DOUBLE, 0x00_01_00_0C => SyscallCode::SECP256K1_DECOMPRESS, diff --git a/crates/core/executor/src/syscalls/mod.rs b/crates/core/executor/src/syscalls/mod.rs index b36f74ba7..cbc48849f 100644 --- a/crates/core/executor/src/syscalls/mod.rs +++ b/crates/core/executor/src/syscalls/mod.rs @@ -22,6 +22,7 @@ pub use code::*; pub use context::*; use hint::{HintLenSyscall, HintReadSyscall}; use precompiles::{ + aes128::encrypt::AES128EncryptSyscall, edwards::{add::EdwardsAddAssignSyscall, decompress::EdwardsDecompressSyscall}, fptower::{Fp2AddSubSyscall, Fp2MulSyscall, FpOpSyscall}, keccak::sponge::KeccakSpongeSyscall, @@ -101,6 +102,8 @@ pub fn default_syscall_map() -> HashMap> { syscall_map.insert(SyscallCode::HALT, Arc::new(HaltSyscall)); syscall_map.insert(SyscallCode::POSEIDON2_PERMUTE, Arc::new(Poseidon2PermuteSyscall)); + + syscall_map.insert(SyscallCode::AES128_ENCRYPT, Arc::new(AES128EncryptSyscall)); syscall_map.insert(SyscallCode::KECCAK_SPONGE, Arc::new(KeccakSpongeSyscall)); diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs new file mode 100644 index 000000000..5c30a9c7a --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -0,0 +1,219 @@ +use log::info; +use crate::events::{AES128EncryptEvent, MemoryReadRecord, PrecompileEvent, AES_128_BLOCK_U32S}; +use crate::Register::A2; +use crate::syscalls::{Syscall, SyscallCode, SyscallContext}; +use crate::syscalls::precompiles::aes128::utils::mul_md5; + +pub(crate) struct AES128EncryptSyscall; + +pub const AES128_RCON: [[u8; 4]; 10] = [ + [0x01, 0x00, 0x00, 0x00], + [0x02, 0x00, 0x00, 0x00], + [0x04, 0x00, 0x00, 0x00], + [0x08, 0x00, 0x00, 0x00], + [0x10, 0x00, 0x00, 0x00], + [0x20, 0x00, 0x00, 0x00], + [0x40, 0x00, 0x00, 0x00], + [0x80, 0x00, 0x00, 0x00], + [0x1B, 0x00, 0x00, 0x00], + [0x36, 0x00, 0x00, 0x00], +]; + +impl Syscall for AES128EncryptSyscall { + fn num_extra_cycles(&self) -> u32 { + 1 + } + fn execute( + &self, + rt: &mut SyscallContext, + syscall_code: SyscallCode, + arg1: u32, + arg2: u32, + ) -> Option { + let start_clk = rt.clk; + let block_ptr = arg1; + let key_ptr = arg2; + + let mut input_read_records = Vec::new(); + let mut key_read_records = Vec::new(); + let mut sbox_read_records = Vec::new(); + let mut output_write_records = Vec::new(); + + let mut input = Vec::new(); + let mut key_u32s = Vec::new(); + let mut state = Vec::new(); + let mut key = Vec::new(); + let mut output = Vec::new(); + let mut sbox = Vec::new(); + + // read sbox ptr + let (sbox_ptr_memory, sbox_ptr) = rt.mr(A2 as u32); + + // read block input + for i in 0..AES_128_BLOCK_U32S { + let (record, value) = rt.mr(block_ptr + i as u32 * 4); + input_read_records.push(record); + input.push(value); + state.extend(value.to_le_bytes()); + } + + // read key + for i in 0..AES_128_BLOCK_U32S { + let (record, value) = rt.mr(key_ptr + i as u32 * 4); + key_read_records.push(record); + key_u32s.push(value); + key.extend(value.to_le_bytes()); + } + + // Add Roundkey, Round 0 + for i in 0..state.len() { + state[i] = state[i] ^ key[i]; + } + + // Read first 24 sbox elements, Round 0 + for i in 0..24 { + let (record, value) = rt.mr(sbox_ptr + i as u32 * 4); + sbox_read_records.push(record); + sbox.push(value); + } + + // perform AES + let mut round_key = key; + for i in 1..11 { + // compute round key + Self::compute_round_key(rt, &mut round_key, &mut sbox_read_records, sbox_ptr, i - 1); + + // Subs_bytes + for i in 0..state.len() { + let (record, value) = rt.mr(sbox_ptr + state[i] as u32 * 4); + sbox_read_records.push(record); + state[i] = value as u8; + } + + // Shift row + let shift_row = [ + state[0], state[5], state[10], state[15], + state[4], state[9], state[14], state[3], + state[8], state[13], state[2], state[7], + state[12], state[1], state[6], state[11], + ].to_vec(); + + // Mix column + let mix_columns = if i != 10 { + let mut mixed = shift_row.clone(); + for col in 0..4 { + let col_start = col * 4; + let s0 = shift_row[col_start]; + let s1 = shift_row[col_start + 1]; + let s2 = shift_row[col_start + 2]; + let s3 = shift_row[col_start + 3]; + mixed[col_start] = mul_md5(s0, 2) ^ mul_md5(s1, 3) ^ mul_md5(s2, 1) ^ mul_md5(s3, 1); + mixed[col_start + 1] = mul_md5(s0, 1) ^ mul_md5(s1, 2) ^ mul_md5(s2, 3) ^ mul_md5(s3, 1); + mixed[col_start + 2] = mul_md5(s0, 1) ^ mul_md5(s1, 1) ^ mul_md5(s2, 2) ^ mul_md5(s3, 3); + mixed[col_start + 3] = mul_md5(s0, 3) ^ mul_md5(s1, 1) ^ mul_md5(s2, 1) ^ mul_md5(s3, 2); + } + mixed + } else { + shift_row + }; + + // Add round key + for i in 0..state.len() { + state[i] = mix_columns[i] ^ round_key[i]; + } + + // Read 24 sbox elements + if i != 10 { + for j in i * 24..i * 24 + 24 { + let (record, value) = rt.mr(sbox_ptr as u32 + j as u32 * 4); + sbox_read_records.push(record); + sbox.push(value); + } + } else { + for j in i * 24..256 { + let (record, value) = rt.mr(sbox_ptr as u32 + j as u32 * 4); + sbox_read_records.push(record); + sbox.push(value); + } + } + } + + + // write output + // Increment the clk by 1 before writing because we read from memory at start_clk. + rt.clk += 1; + assert_eq!(state.len(), 16); + log::info!("AES128 Encrypt output: {:?}", state); + for chunk in state.chunks(4) { + let value = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); + output.push(value); + } + let write_records = rt.mw_slice(block_ptr, output.as_slice()); + output_write_records.extend_from_slice(&write_records); + + // Push the AES128 encrypt event. + let shard = rt.current_shard(); + let aes128_event = PrecompileEvent::Aes128Encrypt(AES128EncryptEvent { + shard, + clk: start_clk, + block_addr: block_ptr, + key_addr: key_ptr, + sbox_addr: sbox_ptr, + sbox_addr_memory: sbox_ptr_memory, + input: input.as_slice().try_into().unwrap(), + key: key_u32s.as_slice().try_into().unwrap(), + output: output.as_slice().try_into().unwrap(), + sbox_reads: sbox, + input_read_records: input_read_records.as_slice().try_into().unwrap(), + key_read_records: key_read_records.as_slice().try_into().unwrap(), + sbox_read_records, + output_write_records: output_write_records.as_slice().try_into().unwrap(), + local_mem_access: rt.postprocess(), + }); + let aes128_syscall_event = + rt.rt.syscall_event(start_clk, None, rt.next_pc, syscall_code.syscall_id(), arg1, arg2); + rt.add_precompile_event(syscall_code, aes128_syscall_event, aes128_event); + None + } +} + +impl AES128EncryptSyscall { + fn compute_round_key( + rt: &mut SyscallContext, + previous_key: &mut [u8], + sbox_records: &mut Vec, + sbox_ptr: u32, + round: usize + ) { + if previous_key.len() != 16 { + panic!("AES128: wrong previous key length"); + } + // First 4 bytes + let g_w3 = { + let mut result = [previous_key[13], previous_key[14], previous_key[15], previous_key[12]]; + for (i, rcon) in AES128_RCON[round].iter().enumerate() { + let (record, value) = rt.mr(sbox_ptr + result[i] as u32 * 4); + sbox_records.push(record); + result[i] = (value as u8) ^ rcon; + } + result + }; + let w0 = [previous_key[0], previous_key[1], previous_key[2], previous_key[3]]; + let w1 = [previous_key[4], previous_key[5], previous_key[6], previous_key[7]]; + let w2 = [previous_key[8], previous_key[9], previous_key[10], previous_key[11]]; + let w3 = [previous_key[12], previous_key[13], previous_key[14], previous_key[15]]; + let w4: [u8; 4] = w0.iter().zip(g_w3.iter()).map(|(&a, &b)| a ^ b) + .collect::>().try_into().unwrap(); + let w5: [u8; 4] = w4.iter().zip(w1.iter()).map(|(&a, &b)| a ^ b) + .collect::>().try_into().unwrap(); + let w6: [u8; 4] = w5.iter().zip(w2.iter()).map(|(&a, &b)| a ^ b) + .collect::>().try_into().unwrap(); + let w7: [u8; 4] = w6.iter().zip(w3.iter()).map(|(&a, &b)| a ^ b) + .collect::>().try_into().unwrap(); + + previous_key[0..4].copy_from_slice(&w4); + previous_key[4..8].copy_from_slice(&w5); + previous_key[8..12].copy_from_slice(&w6); + previous_key[12..16].copy_from_slice(&w7); + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs new file mode 100644 index 000000000..e4fbb9d57 --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs @@ -0,0 +1,2 @@ +pub mod encrypt; +mod utils; \ No newline at end of file diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs b/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs new file mode 100644 index 000000000..d77e6ad32 --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs @@ -0,0 +1,18 @@ +/// Multiply a byte by 2 in GF(2^8) with AES polynomial 0x9B +fn xtime(x: u8) -> u8 { + if x & 0x80 != 0 { + (x << 1) ^ 0x1b + } else { + x << 1 + } +} + +/// Multiply a byte by 1, 2, or 3 in GF(2^8) +pub fn mul_md5(x: u8, by: u8) -> u8 { + match by { + 1 => x, + 2 => xtime(x), + 3 => xtime(x) ^ x, // 3*x = (2*x) ⊕ x + _ => panic!("Only supports multipliers 1, 2, or 3"), + } +} \ No newline at end of file diff --git a/crates/core/executor/src/syscalls/precompiles/mod.rs b/crates/core/executor/src/syscalls/precompiles/mod.rs index cc50a28f4..06eea78ab 100644 --- a/crates/core/executor/src/syscalls/precompiles/mod.rs +++ b/crates/core/executor/src/syscalls/precompiles/mod.rs @@ -1,3 +1,4 @@ +pub mod aes128; pub mod edwards; pub mod fptower; pub mod keccak; diff --git a/crates/core/machine/src/mips/mod.rs b/crates/core/machine/src/mips/mod.rs index bd99f2f0b..6a594996f 100644 --- a/crates/core/machine/src/mips/mod.rs +++ b/crates/core/machine/src/mips/mod.rs @@ -4,6 +4,7 @@ use crate::{ syscall::precompiles::{ fptower::{Fp2AddSubAssignChip, Fp2MulAssignChip, FpOpChip}, poseidon2::Poseidon2PermuteChip, + aes128_encrypt::AES128EncryptChip, }, }; use core::fmt; @@ -141,6 +142,8 @@ pub enum MipsAir { Secp256r1Double(WeierstrassDoubleAssignChip>), /// A precompile for the Poseidon2 permutation Poseidon2Permute(Poseidon2PermuteChip), + /// A precompile for AES-128 encryption + Aes128Encrypt(AES128EncryptChip), /// A precompile for the Keccak Sponge KeccakSponge(KeccakSpongeChip), /// A precompile for addition on the Elliptic curve bn254. @@ -271,6 +274,10 @@ impl MipsAir { let poseidon2_permute = Chip::new(MipsAir::Poseidon2Permute(Poseidon2PermuteChip::new())); costs.insert(poseidon2_permute.name(), poseidon2_permute.cost()); chips.push(poseidon2_permute); + + let aes128_encrypt = Chip::new(MipsAir::Aes128Encrypt(AES128EncryptChip::new())); + costs.insert(aes128_encrypt.name(), 11 * aes128_encrypt.cost()); + chips.push(aes128_encrypt); let keccak_sponge = Chip::new(MipsAir::KeccakSponge(KeccakSpongeChip::new())); costs.insert(keccak_sponge.name(), 24 * keccak_sponge.cost()); @@ -623,6 +630,7 @@ impl MipsAir { Self::Bls12381Fp2Mul(_) => SyscallCode::BLS12381_FP2_MUL, Self::Bls12381Fp2AddSub(_) => SyscallCode::BLS12381_FP2_ADD, Self::Poseidon2Permute(_) => SyscallCode::POSEIDON2_PERMUTE, + Self::Aes128Encrypt(_) => SyscallCode::AES128_ENCRYPT, Self::KeccakSponge(_) => SyscallCode::KECCAK_SPONGE, Self::SysLinux(_) => SyscallCode::SYS_LINUX, Self::Add(_) => unreachable!("Invalid for core chip"), diff --git a/crates/core/machine/src/operations/aes/aes_mul2.rs b/crates/core/machine/src/operations/aes/aes_mul2.rs new file mode 100644 index 000000000..925b4c9bc --- /dev/null +++ b/crates/core/machine/src/operations/aes/aes_mul2.rs @@ -0,0 +1,91 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord}, + ByteOpcode, +}; +use zkm_derive::AlignedBorrow; +use zkm_stark::{air::ZKMAirBuilder, Word}; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct MulBy2InAES { + pub and_0x80: T, + pub left_shift_1: T, + pub xor_0x1b: T // also the result +} + +impl MulBy2InAES { + pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8) -> u8 { + let and_0x80 = x & 0x80; + let left_shift_1 = x << 1; + let xor_0x1b = if and_0x80 != 0 { left_shift_1 ^ 0x1b } else { left_shift_1 }; + + self.and_0x80 = F::from_canonical_u8(and_0x80); + self.left_shift_1 = F::from_canonical_u8(left_shift_1); + self.xor_0x1b = F::from_canonical_u8(xor_0x1b); + + // Byte lookup events + let byte_event_and = ByteLookupEvent { + opcode: ByteOpcode::AND, + a1: and_0x80 as u16, + a2: 0, + b: x, + c: 0x80, + }; + record.add_byte_lookup_event(byte_event_and); + + let byte_event_ssl = ByteLookupEvent { + opcode: ByteOpcode::SLL, + a1: left_shift_1 as u16, + a2: 0, + b: x, + c: 1, + }; + record.add_byte_lookup_event(byte_event_ssl); + + if and_0x80 != 0 { + let byte_event_xor = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: xor_0x1b as u16, + a2: 0, + b: left_shift_1, + c: 0x1b, + }; + record.add_byte_lookup_event(byte_event_xor); + } + + xor_0x1b + } + + #[allow(unused_variables)] + pub fn eval( + builder: &mut AB, + x: AB::Var, + cols: MulBy2InAES, + is_real: AB::Var, + ) { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::AND as u32), + cols.and_0x80, + x, + AB::F::from_canonical_u32(0x80), + is_real, + ); + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::SLL as u32), + cols.left_shift_1, + x, + AB::F::from_canonical_u32(1), + is_real, + ); + builder + .when(cols.and_0x80).inner + .send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.xor_0x1b, + cols.left_shift_1, + AB::F::from_canonical_u32(0x1b), + is_real, + ); + } +} diff --git a/crates/core/machine/src/operations/aes/aes_mul3.rs b/crates/core/machine/src/operations/aes/aes_mul3.rs new file mode 100644 index 000000000..389efcb53 --- /dev/null +++ b/crates/core/machine/src/operations/aes/aes_mul3.rs @@ -0,0 +1,57 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord}, + ByteOpcode, +}; +use zkm_derive::AlignedBorrow; +use zkm_stark::{air::ZKMAirBuilder, Word}; +use crate::operations::aes_mul2::MulBy2InAES; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct MulBy3InAES { + pub mul_by_2: MulBy2InAES, + pub xor_x: T // also the result +} + +impl MulBy3InAES { + pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8) -> u8 { + let x2 = self.mul_by_2.populate(record, x); + let xor_x = x2 ^ x; + self.xor_x = F::from_canonical_u8(xor_x); + + // Byte lookup event for the final XOR + let byte_event_xor = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: xor_x as u16, + a2: 0, + b: x2, + c: x, + }; + record.add_byte_lookup_event(byte_event_xor); + xor_x + } + + #[allow(unused_variables)] + pub fn eval( + builder: &mut AB, + x: AB::Var, + cols: MulBy3InAES, + is_real: AB::Var, + ) { + MulBy2InAES::::eval( + builder, + x, + cols.mul_by_2, + is_real, + ); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.xor_x, + cols.mul_by_2.xor_0x1b, + x, + is_real, + ); + } +} diff --git a/crates/core/machine/src/operations/aes/mix_column.rs b/crates/core/machine/src/operations/aes/mix_column.rs new file mode 100644 index 000000000..9d5ac46eb --- /dev/null +++ b/crates/core/machine/src/operations/aes/mix_column.rs @@ -0,0 +1,181 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord}, + ByteOpcode, +}; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; +use crate::operations::aes::xor_byte_4::XorByte4; +use crate::operations::aes_mul2::MulBy2InAES; +use crate::operations::aes_mul3::MulBy3InAES; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct MixColumn { + pub mul_by_2s: [MulBy2InAES; 16], + pub mul_by_3s: [MulBy3InAES; 16], + pub xor_byte4s: [XorByte4; 16], +} + +impl MixColumn { + pub fn populate( + &mut self, + record: &mut impl ByteRecord, + shifted_state: &[u8; 16], + ) -> [u8; 16] { + let mut mixed = [0u8; 16]; + for col in 0..4 { + let col_start = col * 4; + let s0 = shifted_state[col_start]; + let s1 = shifted_state[col_start + 1]; + let s2 = shifted_state[col_start + 2]; + let s3 = shifted_state[col_start + 3]; + + // mixed[col_start] = mul_md5(s0, 2) ^ mul_md5(s1, 3) ^ mul_md5(s2, 1) ^ mul_md5(s3, 1); + mixed[col_start] = { + let s0x2 = self.mul_by_2s[col_start].populate(record, s0); + let s1x3 = self.mul_by_3s[col_start].populate(record, s1); + self.xor_byte4s[col_start].populate(record, s0x2, s1x3, s2, s3) + }; + + // mixed[col_start + 1] = mul_md5(s0, 1) ^ mul_md5(s1, 2) ^ mul_md5(s2, 3) ^ mul_md5(s3, 1); + mixed[col_start + 1] = { + let s1x2 = self.mul_by_2s[col_start + 1].populate(record, s1); + let s2x3 = self.mul_by_3s[col_start + 1].populate(record, s2); + self.xor_byte4s[col_start + 1].populate(record, s0, s1x2, s2x3, s3) + }; + + // mixed[col_start + 2] = mul_md5(s0, 1) ^ mul_md5(s1, 1) ^ mul_md5(s2, 2) ^ mul_md5(s3, 3); + mixed[col_start + 2] = { + let s2x2 = self.mul_by_2s[col_start + 2].populate(record, s2); + let s3x3 = self.mul_by_3s[col_start + 2].populate(record, s3); + self.xor_byte4s[col_start + 2].populate(record, s0, s1, s2x2, s3x3) + }; + + // mixed[col_start + 3] = mul_md5(s0, 3) ^ mul_md5(s1, 1) ^ mul_md5(s2, 1) ^ mul_md5(s3, 2); + mixed[col_start + 3] = { + let s0x3 = self.mul_by_3s[col_start + 3].populate(record, s0); + let s3x2 = self.mul_by_2s[col_start + 3].populate(record, s3); + self.xor_byte4s[col_start + 3].populate(record, s0x3, s1, s2, s3x2) + } + } + mixed + } + + pub fn eval( + builder: &mut AB, + shifted_state: [AB::Var; 16], + cols: MixColumn, + is_real: AB::Var, + ) { + for col in 0..4 { + let col_start = col * 4; + let s0 = shifted_state[col_start]; + let s1 = shifted_state[col_start + 1]; + let s2 = shifted_state[col_start + 2]; + let s3 = shifted_state[col_start + 3]; + + // col_start + { + MulBy2InAES::::eval( + builder, + s0, + cols.mul_by_2s[col_start], + is_real, + ); + MulBy3InAES::::eval( + builder, + s1, + cols.mul_by_3s[col_start], + is_real, + ); + XorByte4::::eval( + builder, + cols.mul_by_2s[col_start].xor_0x1b, + cols.mul_by_3s[col_start].xor_x, + s2, + s3, + cols.xor_byte4s[col_start], + is_real, + ); + } + + // col_start + 1 + { + MulBy2InAES::::eval( + builder, + s1, + cols.mul_by_2s[col_start + 1], + is_real, + ); + MulBy3InAES::::eval( + builder, + s2, + cols.mul_by_3s[col_start + 1], + is_real, + ); + XorByte4::::eval( + builder, + s0, + cols.mul_by_2s[col_start + 1].xor_0x1b, + cols.mul_by_3s[col_start + 1].xor_x, + s3, + cols.xor_byte4s[col_start + 1], + is_real, + ) + } + + // col_start + 2 + { + MulBy2InAES::::eval( + builder, + s2, + cols.mul_by_2s[col_start + 2], + is_real, + ); + MulBy3InAES::::eval( + builder, + s3, + cols.mul_by_3s[col_start + 2], + is_real, + ); + XorByte4::::eval( + builder, + s0, + s1, + cols.mul_by_2s[col_start + 2].xor_0x1b, + cols.mul_by_3s[col_start + 2].xor_x, + cols.xor_byte4s[col_start + 2], + is_real, + ) + } + + // col_start + 3 + { + MulBy3InAES::::eval( + builder, + s0, + cols.mul_by_3s[col_start + 3], + is_real, + ); + MulBy2InAES::::eval( + builder, + s3, + cols.mul_by_2s[col_start + 3], + is_real, + ); + XorByte4::::eval( + builder, + cols.mul_by_3s[col_start + 3].xor_x, + s1, + s2, + cols.mul_by_2s[col_start + 3].xor_0x1b, + cols.xor_byte4s[col_start + 3], + is_real, + ) + } + } + + } + +} \ No newline at end of file diff --git a/crates/core/machine/src/operations/aes/mod.rs b/crates/core/machine/src/operations/aes/mod.rs new file mode 100644 index 000000000..36c0b8253 --- /dev/null +++ b/crates/core/machine/src/operations/aes/mod.rs @@ -0,0 +1,5 @@ +pub mod aes_mul2; +pub mod aes_mul3; +pub mod mix_column; +pub mod xor_byte_4; +pub mod round_key; \ No newline at end of file diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs new file mode 100644 index 000000000..fd42ef69a --- /dev/null +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -0,0 +1,198 @@ +use p3_air::AirBuilder; +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::ByteOpcode; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord, MemoryReadRecord}; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; +use crate::memory::{MemoryCols, MemoryReadCols}; + +pub const ROUND_CONST: [u8; 10] = [ + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, +]; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct NextRoundKey { + pub round_const: T, + pub add_round_const: T, // XOR + // new round key + pub w4: [T; 4], + pub w5: [T; 4], + pub w6: [T; 4], + pub w7: [T; 4], +} + +impl NextRoundKey { + pub fn populate( + &mut self, + records: &mut impl ByteRecord, + prev_round_key: &[u8; 16], + byte_subs_records: &[MemoryReadRecord; 4], + round: u8, + ) -> [u8; 16] { + // check sbox values + let sbox_values: [u32; 4] = byte_subs_records.map(|m| m.value); + let all_in_u8 = sbox_values.iter().all(|&v| v <= 0xFF); + if !all_in_u8 { + panic!("Not all sbox_values fit in u8"); + } + + // previous round key + let w0 = &prev_round_key[0..4]; + let w1 = &prev_round_key[4..8]; + let w2 = &prev_round_key[8..12]; + let w3 = &prev_round_key[12..16]; + + let mut sub_rot_w3 = sbox_values.map(|u| u as u8); + let rcon = ROUND_CONST[round as usize]; + self.round_const = F::from_canonical_u8(rcon); + let first_byte = sub_rot_w3[0] ^ rcon; + self.add_round_const = F::from_canonical_u8(first_byte); + let first_byte_xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: first_byte as u16, + a2: 0, + b: sub_rot_w3[0], + c: rcon, + }; + records.add_byte_lookup_event(first_byte_xor_event); + // add constant + sub_rot_w3[0] = first_byte; + + // Compute new words + let mut new_key = [0u8; 16]; + // w4 = w0 ^ SubWord(RotWord(w3)) + for i in 0..4 { + new_key[i] = w0[i] ^ sub_rot_w3[i]; + self.w4[i] = F::from_canonical_u8(new_key[i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[i] as u16, + a2: 0, + b: w0[i], + c: sub_rot_w3[i], + }; + records.add_byte_lookup_event(xor_event); + } + + // w5 = w4 ^ w1 + for i in 0..4 { + new_key[4 + i] = new_key[i] ^ w1[i]; + self.w5[i] = F::from_canonical_u8(new_key[4 + i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[4 + i] as u16, + a2: 0, + b: new_key[i], + c: w1[i], + }; + records.add_byte_lookup_event(xor_event); + } + + // w6 = w5 ^ w2 + for i in 0..4 { + new_key[8 + i] = new_key[4 + i] ^ w2[i]; + self.w6[i] = F::from_canonical_u8(new_key[8 + i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[8 + i] as u16, + a2: 0, + b: new_key[4 + i], + c: w2[i], + }; + records.add_byte_lookup_event(xor_event); + } + + // w7 = w6 ^ w3 + for i in 0..4 { + new_key[12 + i] = new_key[8 + i] ^ w3[i]; + self.w7[i] = F::from_canonical_u8(new_key[12 + i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[12 + i] as u16, + a2: 0, + b: new_key[8 + i], + c: w3[i], + }; + records.add_byte_lookup_event(xor_event); + } + + new_key + } + + pub fn eval( + builder: &mut AB, + cols: NextRoundKey, + prev_round_key: [AB::Var; 16], + sbox_read: &[MemoryReadCols; 4], + round: usize, + is_real: AB::Var, + ) { + let w0 = &prev_round_key[0..4]; + let w1 = &prev_round_key[4..8]; + let w2 = &prev_round_key[8..12]; + let w3 = &prev_round_key[12..16]; + + // round const + let rcon = AB::F::from_canonical_u32(ROUND_CONST[round] as u32); + builder.when(is_real).assert_eq(cols.round_const, rcon); + + let sbox_values: [AB::Var; 4] = sbox_read.map(|m| m.value().0[0]); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.add_round_const, + sbox_values[0], + cols.round_const, + is_real, + ); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w4[0], + w0[0], + cols.add_round_const, + is_real, + ); + + for i in 1..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w4[i], + w0[i], + sbox_values[i], + is_real, + ) + } + + for i in 0..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w5[i], + cols.w4[i], + w1[i], + is_real, + ) + } + + for i in 0..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w6[i], + cols.w5[i], + w2[i], + is_real, + ) + } + + for i in 0..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w7[i], + cols.w6[i], + w3[i], + is_real, + ) + } + } +} diff --git a/crates/core/machine/src/operations/aes/xor_byte_4.rs b/crates/core/machine/src/operations/aes/xor_byte_4.rs new file mode 100644 index 000000000..c13e5b6a1 --- /dev/null +++ b/crates/core/machine/src/operations/aes/xor_byte_4.rs @@ -0,0 +1,91 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::ByteOpcode; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord}; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +// x ^ y ^ z ^ w +pub struct XorByte4 { + pub interm1: T, + pub interm2: T, + pub value: T +} + +impl XorByte4 { + pub fn populate( + &mut self, + record: &mut impl ByteRecord, + x: u8, + y: u8, + z: u8, + w: u8, + ) -> u8 { + let xor_inter1 = x ^ y; + self.interm1 = F::from_canonical_u8(xor_inter1); + let byte_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: xor_inter1 as u16, + a2: 0, + b: x, + c: y, + }; + record.add_byte_lookup_event(byte_event); + + let xor_inter2 = xor_inter1 ^ z; + self.interm2 = F::from_canonical_u8(xor_inter2); + let byte_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: xor_inter2 as u16, + a2: 0, + b: xor_inter1, + c: z, + }; + record.add_byte_lookup_event(byte_event); + + let result = xor_inter2 ^ w; + self.value = F::from_canonical_u8(result); + let byte_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: result as u16, + a2: 0, + b: xor_inter2, + c: w, + }; + record.add_byte_lookup_event(byte_event); + result + } + + pub fn eval( + builder: &mut AB, + x: AB::Var, + y: AB::Var, + z: AB::Var, + w: AB::Var, + cols: XorByte4, + is_real: AB::Var, + ) { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.interm1, + x, + y, + is_real, + ); + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.interm2, + cols.interm1, + z, + is_real, + ); + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.value, + cols.interm2, + w, + is_real, + ); + } +} \ No newline at end of file diff --git a/crates/core/machine/src/operations/mod.rs b/crates/core/machine/src/operations/mod.rs index 31a9537c7..667aed4de 100644 --- a/crates/core/machine/src/operations/mod.rs +++ b/crates/core/machine/src/operations/mod.rs @@ -8,6 +8,7 @@ mod add; mod add4; mod add5; mod adddouble; +mod aes; mod and; mod cmp; pub mod field; @@ -24,11 +25,11 @@ mod not; mod or; pub mod poseidon2; mod xor; - pub use add::*; pub use add4::*; pub use add5::*; pub use adddouble::*; +pub use aes::*; pub use and::*; pub use cmp::*; pub use fixed_rotate_right::*; diff --git a/crates/core/machine/src/runtime/syscall.rs b/crates/core/machine/src/runtime/syscall.rs index 902697912..8ac1b85ea 100644 --- a/crates/core/machine/src/runtime/syscall.rs +++ b/crates/core/machine/src/runtime/syscall.rs @@ -157,6 +157,9 @@ pub enum SyscallCode { /// Executes the `POSEIDON2_PERMUTE` precompile. POSEIDON2_PERMUTE = 0x00_00_01_30, + + /// Executes the `AES128_ENCRYPT` precompile. + AES128_ENCRYPT = 0x00_01_01_31, } impl SyscallCode { @@ -201,6 +204,7 @@ impl SyscallCode { 0x00_01_01_29 => SyscallCode::BN254_FP2_ADD, 0x00_01_01_2A => SyscallCode::BN254_FP2_SUB, 0x00_01_01_2B => SyscallCode::BN254_FP2_MUL, + 0x00_01_01_31 => SyscallCode::AES128_ENCRYPT, 0x00_00_01_1C => SyscallCode::BLS12381_DECOMPRESS, 0x00_00_01_30 => SyscallCode::POSEIDON2_PERMUTE, _ => panic!("invalid syscall number: {}", value), diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs new file mode 100644 index 000000000..54cfc11b4 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -0,0 +1,20 @@ +use p3_air::{Air, BaseAir}; +use tempfile::Builder; +use zkm_stark::ZKMAirBuilder; +use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; +use crate::syscall::precompiles::aes128_encrypt::columns::NUM_AES128_ENCRYPTION_COLS; + +impl BaseAir for AES128EncryptChip { + fn width(&self) -> usize { + NUM_AES128_ENCRYPTION_COLS + } +} + +impl Air for AES128EncryptChip +where + AB: ZKMAirBuilder, +{ + fn eval(&self, builder: &mut AB) { + todo!() + } +} \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs new file mode 100644 index 000000000..0adbbb797 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -0,0 +1,30 @@ +use zkm_derive::AlignedBorrow; +use crate::memory::{MemoryReadCols, MemoryReadWriteCols}; +use crate::operations::mix_column::MixColumn; +use crate::operations::round_key::NextRoundKey; + +/// AES128EncryptCols is the column layout for the AES128 encryption. +/// The number of rows equal to the number of block. +#[derive(AlignedBorrow)] +#[repr(C)] +pub struct AES128EncryptionCols { + pub shard: T, + pub clk: T, + pub is_real: T, + pub key_address: T, + pub block_address: T, + pub sbox_address: T, + pub key: [MemoryReadCols; 4], + pub block: [MemoryReadWriteCols; 4], + pub sbox: [MemoryReadCols; 24], //24 * 11 = 264 > 256 Sbox elements. + pub round: [T; 11], // [0,..10] + pub state_matrix: [T; 16], + pub round_key_matrix: [T; 16], + pub next_round_key: NextRoundKey, + pub roundkey_subs_bytes: [MemoryReadCols; 4], // byte subs for round key + pub state_subs_bytes: [MemoryReadCols; 16], // byte subs for state + pub mix_column: MixColumn, + pub add_round_key: [T; 16], // result of this round +} + +pub const NUM_AES128_ENCRYPTION_COLS: usize = core::mem::size_of::>(); diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs new file mode 100644 index 000000000..6d09b3642 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs @@ -0,0 +1,45 @@ +mod air; +mod columns; +mod trace; + +pub const AES_SBOX: [u32; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + +#[derive(Default)] +pub struct AES128EncryptChip; + +impl AES128EncryptChip { + pub const fn new() -> Self { + Self {} + } +} + +#[cfg(test)] +pub mod tests { + use crate::utils::{self, run_test}; + use test_artifacts::AES128_ENCRYPT_ELF; + use zkm_core_executor::Program; + use zkm_stark::CpuProver; + #[test] + fn test_aes128_encrypt_program_prove() { + utils::setup_logger(); + let program = Program::from(AES128_ENCRYPT_ELF).unwrap(); + run_test::>(program).unwrap(); + } +} \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs new file mode 100644 index 000000000..b4cc48c52 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -0,0 +1,96 @@ +use std::borrow::BorrowMut; + +use hashbrown::HashMap; +use itertools::Itertools; +use p3_air::BaseAir; +use p3_field::PrimeField32; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord, PrecompileEvent, ShaCompressEvent}, + syscalls::SyscallCode, + ExecutionRecord, Program, +}; +use zkm_core_executor::events::{AES128EncryptEvent, KeccakSpongeEvent}; +use zkm_stark::{air::MachineAir, Word}; + +use super::{columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}, AES128EncryptChip, AES_SBOX}; +use crate::utils::pad_rows_fixed; + +impl MachineAir for AES128EncryptChip { + type Record = ExecutionRecord; + type Program = Program; + + fn name(&self) -> String { + "AES128Encryption".to_string() + } + + fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { + let events = input.get_precompile_events(SyscallCode::AES128_ENCRYPT); + let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1); + + let blu_batches = events + .par_chunks(chunk_size) + .map(|events| { + let mut blu: HashMap = HashMap::new(); + events.iter().for_each(|(_, event)| { + let event = if let PrecompileEvent::Aes128Encrypt(event) = event { + event + } else { + unreachable!(); + }; + + self.event_to_rows::(event, &mut None, &mut blu); + }); + blu + }) + .collect::>(); + + output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec()); + } + + fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + let rows = Vec::new(); + + let mut wrapped_rows = Some(rows); + for (_, event) in input.get_precompile_events(SyscallCode::AES128_ENCRYPT) { + let event = if let PrecompileEvent::Aes128Encrypt(event) = event { + event + } else { + unreachable!(); + }; + self.event_to_rows(event, &mut wrapped_rows, &mut Vec::new()); + } + let mut rows = wrapped_rows.unwrap(); + let num_real_rows = rows.len(); + let mut padded_num_rows = num_real_rows.next_power_of_two(); + for i in num_real_rows..padded_num_rows { + let mut row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; + // let cols: &mut AES128EncryptionCols = row.as_mut_slice().borrow_mut(); + // todo!() + rows.push(row); + } + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_AES128_ENCRYPTION_COLS) + } + + + + fn included(&self, shard: &Self::Record) -> bool { + if let Some(shape) = shard.shape.as_ref() { + shape.included::(self) + } else { + !shard.get_precompile_events(SyscallCode::AES128_ENCRYPT).is_empty() + } + } +} + +impl AES128EncryptChip { + pub fn event_to_rows( + &self, + event: &AES128EncryptEvent, + rows: &mut Option>, + blu: &mut impl ByteRecord, + ) { + todo!() + } +} \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/mod.rs b/crates/core/machine/src/syscall/precompiles/mod.rs index 6a5c37835..e011c11fe 100644 --- a/crates/core/machine/src/syscall/precompiles/mod.rs +++ b/crates/core/machine/src/syscall/precompiles/mod.rs @@ -1,3 +1,4 @@ +pub mod aes128_encrypt; pub mod edwards; pub mod fptower; pub mod keccak_sponge; diff --git a/crates/test-artifacts/guests/Cargo.toml b/crates/test-artifacts/guests/Cargo.toml index 79b8ba081..e2868abc6 100644 --- a/crates/test-artifacts/guests/Cargo.toml +++ b/crates/test-artifacts/guests/Cargo.toml @@ -44,6 +44,7 @@ members = [ "verify-proof", "u256x2048-mul", "unconstrained", + "aes128", ] resolver = "2" diff --git a/crates/test-artifacts/guests/aes128/Cargo.toml b/crates/test-artifacts/guests/aes128/Cargo.toml new file mode 100644 index 000000000..335358d5c --- /dev/null +++ b/crates/test-artifacts/guests/aes128/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "aes128" +version = "1.1.0" +edition = "2021" +publish = false + +[dependencies] +zkm-zkvm = { path = "../../../../crates/zkvm/entrypoint" } diff --git a/crates/test-artifacts/guests/aes128/src/main.rs b/crates/test-artifacts/guests/aes128/src/main.rs new file mode 100644 index 000000000..941eccb56 --- /dev/null +++ b/crates/test-artifacts/guests/aes128/src/main.rs @@ -0,0 +1,13 @@ +#![no_std] +#![no_main] +zkm_zkvm::entrypoint!(main); + +use zkm_zkvm::lib::aes128::aes128_encrypt; + +pub fn main() { + for _ in 0..25 { + let mut state = [1u8; 16]; + let key = [0u8; 16]; + aes128_encrypt(&mut state, &key); + } +} diff --git a/crates/test-artifacts/src/lib.rs b/crates/test-artifacts/src/lib.rs index 2f8a21c9d..309202176 100644 --- a/crates/test-artifacts/src/lib.rs +++ b/crates/test-artifacts/src/lib.rs @@ -8,6 +8,7 @@ pub const HELLO_WORLD_ELF: &[u8] = include_elf!("hello-world"); pub const POSEIDON2_PERMUTE_ELF: &[u8] = include_elf!("poseidon2-permute-test"); +pub const AES128_ENCRYPT_ELF: &[u8] = include_elf!("aes128"); pub const SHA2_ELF: &[u8] = include_elf!("sha2-test"); pub const SHA_EXTEND_ELF: &[u8] = include_elf!("sha-extend-test"); pub const SHA_COMPRESS_ELF: &[u8] = include_elf!("sha-compress-test"); diff --git a/crates/zkvm/entrypoint/src/syscalls/aes128.rs b/crates/zkvm/entrypoint/src/syscalls/aes128.rs new file mode 100644 index 000000000..35a34f81e --- /dev/null +++ b/crates/zkvm/entrypoint/src/syscalls/aes128.rs @@ -0,0 +1,26 @@ +#[cfg(target_os = "zkvm")] +use core::arch::asm; + +/// Executes the AES-128 encryption on the given block with the given key. +/// +/// ### Safety +/// +/// The caller must ensure that `state` and `key` are valid pointers to data that are aligned along +/// a four byte boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4], sbox: *const [u32; 256]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "syscall", + in("$2") crate::syscalls::AES128_ENCRYPT, + in("$4") state, + in("$5") key, + in("$6") sbox, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/crates/zkvm/entrypoint/src/syscalls/mod.rs b/crates/zkvm/entrypoint/src/syscalls/mod.rs index b1aa1d61a..46838f440 100644 --- a/crates/zkvm/entrypoint/src/syscalls/mod.rs +++ b/crates/zkvm/entrypoint/src/syscalls/mod.rs @@ -1,3 +1,4 @@ +mod aes128; mod bigint; mod bls12381; mod bn254; @@ -19,6 +20,7 @@ mod unconstrained; #[cfg(feature = "verify")] mod verify; +pub use aes128::*; pub use bigint::*; pub use bls12381::*; pub use bn254::*; @@ -162,3 +164,6 @@ pub const BN254_FP2_MUL: u32 = 0x01_01_00_2B; /// Executes the `POSEIDON2_PERMUTE` precompile. pub const POSEIDON2_PERMUTE: u32 = 0x00_01_00_30; + +/// Executes the `AES128_ENCRYPT` precompile. +pub const AES128_ENCRYPT: u32 = 0x01_01_00_31; diff --git a/crates/zkvm/lib/src/aes128.rs b/crates/zkvm/lib/src/aes128.rs new file mode 100644 index 000000000..3deeb21bd --- /dev/null +++ b/crates/zkvm/lib/src/aes128.rs @@ -0,0 +1,47 @@ +use crate::syscall_aes128_encrypt; + +pub fn aes128_encrypt(state: &mut [u8; 16], key: &[u8; 16]) { + // convert to u32 to align the memory + let mut state_u32 = [0u32; 4]; + let mut key_u32 = [0u32; 4]; + + let sbox: [u32; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, + ]; + + for i in 0..4 { + state_u32[i] = u32::from_le_bytes([ + state[i * 4], + state[i * 4 + 1], + state[i * 4 + 2], + state[i * 4 + 3], + ]); + key_u32[i] = u32::from_le_bytes([ + key[i * 4], + key[i * 4 + 1], + key[i * 4 + 2], + key[i * 4 + 3], + ]); + } + unsafe { + syscall_aes128_encrypt(&mut state_u32, &key_u32, &sbox) + } + for i in 0..4 { + state[4 * i..4 * i + 4].copy_from_slice(&state_u32[i].to_le_bytes()); + } +} \ No newline at end of file diff --git a/crates/zkvm/lib/src/lib.rs b/crates/zkvm/lib/src/lib.rs index 35dd71a0d..4940a4b0b 100644 --- a/crates/zkvm/lib/src/lib.rs +++ b/crates/zkvm/lib/src/lib.rs @@ -19,6 +19,7 @@ pub mod unconstrained; pub mod utils; #[cfg(feature = "verify")] pub mod verify; +pub mod aes128; extern "C" { /// Halts the program with the given exit code. @@ -78,6 +79,9 @@ extern "C" { /// Executes the Poseidon2 permutation pub fn syscall_poseidon2_permute(state: *mut [u32; 16]); + /// Executes the AES-128 encryption on the given state with the given key. + pub fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4], sbox: *const [u32; 256]); + /// Executes an uint256 multiplication on the given inputs. pub fn syscall_uint256_mulmod(x: *mut [u32; 8], y: *const [u32; 8]); diff --git a/examples/Cargo.lock b/examples/Cargo.lock index ba80a1e3e..7e53311da 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -49,6 +49,24 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "aes128" +version = "1.1.0" +dependencies = [ + "zkm-zkvm", +] + +[[package]] +name = "aes128-host" +version = "1.1.0" +dependencies = [ + "hex", + "log", + "tracing", + "zkm-build", + "zkm-sdk", +] + [[package]] name = "aggregation" version = "1.1.0" @@ -8721,6 +8739,7 @@ dependencies = [ name = "zkm-core-executor" version = "1.2.0" dependencies = [ + "aes", "anyhow", "bincode", "bytemuck", diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 504354292..6cb0ee51a 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,5 +1,7 @@ [workspace] members = [ + "aes128/guest", + "aes128/host", "aggregation/guest", "aggregation/host", "bn254/guest", diff --git a/examples/aes128/guest/Cargo.toml b/examples/aes128/guest/Cargo.toml new file mode 100644 index 000000000..602f6a47e --- /dev/null +++ b/examples/aes128/guest/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "aes128" +version = "1.1.0" +edition = "2021" +publish = false + +[dependencies] +zkm-zkvm = { path = "../../../crates/zkvm/entrypoint" } + diff --git a/examples/aes128/guest/src/main.rs b/examples/aes128/guest/src/main.rs new file mode 100644 index 000000000..dce3131d0 --- /dev/null +++ b/examples/aes128/guest/src/main.rs @@ -0,0 +1,40 @@ +#![no_std] +#![no_main] + +extern crate alloc; + +use alloc::vec::Vec; +use core::cmp::min; +use zkm_zkvm::lib::aes128::aes128_encrypt; +zkm_zkvm::entrypoint!(main); + + +pub fn main() { + let plain_text: Vec = zkm_zkvm::io::read(); + let key: Vec = zkm_zkvm::io::read(); + let iv: Vec = zkm_zkvm::io::read(); + let expected_output: Vec = zkm_zkvm::io::read(); + zkm_zkvm::io::commit::>(&plain_text); + zkm_zkvm::io::commit::>(&key); + zkm_zkvm::io::commit::>(&iv); + + assert_eq!(key.len(), 16); + assert_eq!(iv.len(), 16); + let key_array: [u8; 16] = key.as_slice().try_into().unwrap(); + let iv_array: [u8; 16] = iv.as_slice().try_into().unwrap(); + let output = cipher_block_chaining(&plain_text, &key_array, &iv_array); + assert_eq!(expected_output, output.to_vec()); + zkm_zkvm::io::commit::>(&output.to_vec()); +} + +fn cipher_block_chaining(input: &[u8], key: &[u8; 16], iv: &[u8; 16]) -> [u8; 16] { + let mut block = iv.clone(); + for chunk in input.chunks(16) { + for i in 0..min(chunk.len(), 16) { + block[i] ^= chunk[i]; + } + aes128_encrypt(&mut block, key); + } + block +} + diff --git a/examples/aes128/host/Cargo.toml b/examples/aes128/host/Cargo.toml new file mode 100644 index 000000000..dc0fe3c83 --- /dev/null +++ b/examples/aes128/host/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "aes128-host" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[dependencies] +zkm-sdk = { workspace = true } +tracing = { workspace = true } +hex = "0.4.3" +log = "0.4.27" + +[build-dependencies] +zkm-build = { workspace = true } diff --git a/examples/aes128/host/build.rs b/examples/aes128/host/build.rs new file mode 100644 index 000000000..e7ea36ace --- /dev/null +++ b/examples/aes128/host/build.rs @@ -0,0 +1,3 @@ +fn main() { + zkm_build::build_program("../guest"); +} diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs new file mode 100644 index 000000000..53c76031c --- /dev/null +++ b/examples/aes128/host/src/main.rs @@ -0,0 +1,60 @@ +use std::env; +use zkm_sdk::{include_elf, utils, ProverClient, ZKMProofWithPublicValues, ZKMStdin}; + +/// The ELF we want to execute inside the zkVM. +const ELF: &[u8] = include_elf!("aes128"); +fn prove_aes128_rust() { + let mut stdin = ZKMStdin::new(); + + // load input + let plain_text = vec![ + 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, + 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7 + ]; + let key = vec![0_u8; 16]; + let iv = vec![0_u8; 16]; + + let expected_output = vec![ + 97_u8, 203, 140, 117, 36, 211, 41, 97, + 177, 36, 93, 148, 107, 228, 201, 129 + ]; + + stdin.write(&plain_text); + stdin.write(&key); + stdin.write(&iv); + stdin.write(&expected_output); + + // Create a `ProverClient` method. + let client = ProverClient::new(); + + // Execute the program using the `ProverClient.execute` method, without generating a proof. + let (_, report) = client.execute(ELF, stdin.clone()).run().unwrap(); + println!("executed program with {} cycles", report.total_instruction_count()); + + // // Generate the proof for the given program and input. + // let (pk, vk) = client.setup(ELF); + // let mut proof = client.prove(&pk, stdin).run().unwrap(); + // println!("generated proof"); + // + // // Read and verify the output. + // // + // // Note that this output is read from values committed to in the program using + // // `zkm_zkvm::io::commit`. + // let plain_text = proof.public_values.read::>(); + // let key = proof.public_values.read::>(); + // let iv = proof.public_values.read::>(); + // let public_input = proof.public_values.read::>(); + // println!("plaintext: {:?}", plain_text); + // println!("key: {:?}", key); + // println!("iv: {:?}", iv); + // assert_eq!(expected_output, public_input); + // + // // Verify proof and public values + // client.verify(&proof, &vk).expect("verification failed"); + // println!("successfully generated and verified proof for the program!"); +} + +fn main() { + utils::setup_logger(); + prove_aes128_rust(); +} From ba9230abbc4ccefb9fc613ae275557f5a43f21f7 Mon Sep 17 00:00:00 2001 From: vanhger Date: Thu, 18 Sep 2025 11:24:48 +0700 Subject: [PATCH 02/12] chore: add trace generation --- .../executor/src/events/precompiles/aes128.rs | 5 +- crates/core/executor/src/record.rs | 1 + .../syscalls/precompiles/aes128/encrypt.rs | 27 ++- .../src/syscalls/precompiles/aes128/mod.rs | 2 +- crates/core/machine/src/mips/mod.rs | 2 + .../machine/src/operations/aes/round_key.rs | 2 +- .../syscall/precompiles/aes128_encrypt/air.rs | 24 ++- .../precompiles/aes128_encrypt/columns.rs | 2 + .../precompiles/aes128_encrypt/trace.rs | 196 ++++++++++++++++-- crates/stark/src/opts.rs | 3 + .../test-artifacts/guests/aes128/src/main.rs | 4 +- examples/aes128/host/src/main.rs | 21 +- 12 files changed, 252 insertions(+), 37 deletions(-) diff --git a/crates/core/executor/src/events/precompiles/aes128.rs b/crates/core/executor/src/events/precompiles/aes128.rs index aecc727f6..413702661 100644 --- a/crates/core/executor/src/events/precompiles/aes128.rs +++ b/crates/core/executor/src/events/precompiles/aes128.rs @@ -5,7 +5,8 @@ use crate::events::{ MemoryLocalEvent, }; -pub(crate) const AES_128_BLOCK_U32S: usize = 4; +pub const AES_128_BLOCK_U32S: usize = 4; +pub const AES_128_BLOCK_BYTES: usize = 16; /// AES128 Encrypt Event /// @@ -31,7 +32,7 @@ pub struct AES128EncryptEvent { /// The output block as a [u32; AES_128_BLOCK_U32S] words. pub output: [u32; AES_128_BLOCK_U32S], /// The sbox reads - pub sbox_reads: Vec, + pub sbox_reads: Vec, /// The memory records for the input pub input_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], /// The memory records for the key diff --git a/crates/core/executor/src/record.rs b/crates/core/executor/src/record.rs index a98c7c3d6..811a3d778 100644 --- a/crates/core/executor/src/record.rs +++ b/crates/core/executor/src/record.rs @@ -124,6 +124,7 @@ impl ExecutionRecord { SyscallCode::KECCAK_SPONGE => opts.keccak, SyscallCode::SHA_EXTEND => opts.sha_extend, SyscallCode::SHA_COMPRESS => opts.sha_compress, + SyscallCode::AES128_ENCRYPT => opts.aes128_encrypt, _ => opts.deferred, }; diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs index 5c30a9c7a..969eac75b 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -44,7 +44,7 @@ impl Syscall for AES128EncryptSyscall { let mut state = Vec::new(); let mut key = Vec::new(); let mut output = Vec::new(); - let mut sbox = Vec::new(); + let mut sbox: Vec = Vec::new(); // read sbox ptr let (sbox_ptr_memory, sbox_ptr) = rt.mr(A2 as u32); @@ -74,19 +74,29 @@ impl Syscall for AES128EncryptSyscall { for i in 0..24 { let (record, value) = rt.mr(sbox_ptr + i as u32 * 4); sbox_read_records.push(record); - sbox.push(value); + assert!(value <= u8::MAX as u32); + sbox.push(value as u8); } // perform AES let mut round_key = key; for i in 1..11 { // compute round key - Self::compute_round_key(rt, &mut round_key, &mut sbox_read_records, sbox_ptr, i - 1); + Self::compute_round_key( + rt, + &mut round_key, + &mut sbox_read_records, + &mut sbox, + sbox_ptr, + i - 1 + ); // Subs_bytes for i in 0..state.len() { let (record, value) = rt.mr(sbox_ptr + state[i] as u32 * 4); sbox_read_records.push(record); + assert!(value <= u8::MAX as u32); + sbox.push(value as u8); state[i] = value as u8; } @@ -98,7 +108,7 @@ impl Syscall for AES128EncryptSyscall { state[12], state[1], state[6], state[11], ].to_vec(); - // Mix column + // Mix columns let mix_columns = if i != 10 { let mut mixed = shift_row.clone(); for col in 0..4 { @@ -127,13 +137,15 @@ impl Syscall for AES128EncryptSyscall { for j in i * 24..i * 24 + 24 { let (record, value) = rt.mr(sbox_ptr as u32 + j as u32 * 4); sbox_read_records.push(record); - sbox.push(value); + assert!(value <= u8::MAX as u32); + sbox.push(value as u8); } } else { for j in i * 24..256 { let (record, value) = rt.mr(sbox_ptr as u32 + j as u32 * 4); sbox_read_records.push(record); - sbox.push(value); + assert!(value <= u8::MAX as u32); + sbox.push(value as u8); } } } @@ -182,6 +194,7 @@ impl AES128EncryptSyscall { rt: &mut SyscallContext, previous_key: &mut [u8], sbox_records: &mut Vec, + sbox: &mut Vec, sbox_ptr: u32, round: usize ) { @@ -194,6 +207,8 @@ impl AES128EncryptSyscall { for (i, rcon) in AES128_RCON[round].iter().enumerate() { let (record, value) = rt.mr(sbox_ptr + result[i] as u32 * 4); sbox_records.push(record); + assert!(value <= u8::MAX as u32); + sbox.push(value as u8); result[i] = (value as u8) ^ rcon; } result diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs index e4fbb9d57..d0d2382ed 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs @@ -1,2 +1,2 @@ pub mod encrypt; -mod utils; \ No newline at end of file +pub mod utils; \ No newline at end of file diff --git a/crates/core/machine/src/mips/mod.rs b/crates/core/machine/src/mips/mod.rs index 6a594996f..b96eff36e 100644 --- a/crates/core/machine/src/mips/mod.rs +++ b/crates/core/machine/src/mips/mod.rs @@ -276,6 +276,7 @@ impl MipsAir { chips.push(poseidon2_permute); let aes128_encrypt = Chip::new(MipsAir::Aes128Encrypt(AES128EncryptChip::new())); + // log::info!("aes128 cost: {:?}", aes128_encrypt.cost()); costs.insert(aes128_encrypt.name(), 11 * aes128_encrypt.cost()); chips.push(aes128_encrypt); @@ -581,6 +582,7 @@ impl MipsAir { Self::Sha256Compress(_) => 80, Self::Sha256Extend(_) => 48, Self::KeccakSponge(_) => 24, + Self::Aes128Encrypt(_) => 11, _ => 1, } } diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs index fd42ef69a..8fdc55f94 100644 --- a/crates/core/machine/src/operations/aes/round_key.rs +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -32,7 +32,7 @@ impl NextRoundKey { ) -> [u8; 16] { // check sbox values let sbox_values: [u32; 4] = byte_subs_records.map(|m| m.value); - let all_in_u8 = sbox_values.iter().all(|&v| v <= 0xFF); + let all_in_u8 = sbox_values.iter().all(|&v| v <= u8::MAX as u32); if !all_in_u8 { panic!("Not all sbox_values fit in u8"); } diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 54cfc11b4..4654e71ed 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -1,8 +1,13 @@ +use std::borrow::Borrow; +use log::__private_api::loc; use p3_air::{Air, BaseAir}; +use p3_field::FieldAlgebra; +use p3_matrix::Matrix; use tempfile::Builder; -use zkm_stark::ZKMAirBuilder; +use zkm_stark::{MachineAir, ZKMAirBuilder}; +use crate::KeccakSpongeChip; use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; -use crate::syscall::precompiles::aes128_encrypt::columns::NUM_AES128_ENCRYPTION_COLS; +use crate::syscall::precompiles::aes128_encrypt::columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}; impl BaseAir for AES128EncryptChip { fn width(&self) -> usize { @@ -15,6 +20,19 @@ where AB: ZKMAirBuilder, { fn eval(&self, builder: &mut AB) { - todo!() + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local: &AES128EncryptionCols = (*local).borrow(); + let next: &AES128EncryptionCols = (*next).borrow(); + + let first_round = local.round[0]; + let last_round = local.round[10]; + builder.assert_eq(first_round * local.is_real, local.receive_syscall); + + } +} + +impl AES128EncryptChip { + } \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs index 0adbbb797..1cc132e73 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -14,6 +14,8 @@ pub struct AES128EncryptionCols { pub key_address: T, pub block_address: T, pub sbox_address: T, + pub receive_syscall: T, + pub sbox_addr_read: MemoryReadCols, pub key: [MemoryReadCols; 4], pub block: [MemoryReadWriteCols; 4], pub sbox: [MemoryReadCols; 24], //24 * 11 = 264 > 256 Sbox elements. diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index b4cc48c52..32f9bfa3a 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -6,23 +6,18 @@ use p3_air::BaseAir; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; -use zkm_core_executor::{ - events::{ByteLookupEvent, ByteRecord, PrecompileEvent, ShaCompressEvent}, - syscalls::SyscallCode, - ExecutionRecord, Program, -}; -use zkm_core_executor::events::{AES128EncryptEvent, KeccakSpongeEvent}; -use zkm_stark::{air::MachineAir, Word}; - -use super::{columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}, AES128EncryptChip, AES_SBOX}; -use crate::utils::pad_rows_fixed; +use zkm_core_executor::{events::{ByteLookupEvent, ByteRecord, PrecompileEvent}, syscalls::SyscallCode, ByteOpcode, ExecutionRecord, Program}; +use zkm_core_executor::events::{AES128EncryptEvent, MemoryRecordEnum, AES_128_BLOCK_BYTES, AES_128_BLOCK_U32S}; +use zkm_stark::{air::MachineAir}; +use crate::syscall::precompiles::aes128_encrypt::columns::AES128EncryptionCols; +use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip, AES_SBOX}; impl MachineAir for AES128EncryptChip { type Record = ExecutionRecord; type Program = Program; fn name(&self) -> String { - "AES128Encryption".to_string() + "Aes128Encrypt".to_string() } fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { @@ -51,6 +46,7 @@ impl MachineAir for AES128EncryptChip { fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { let rows = Vec::new(); + log::info!("generate trace"); let mut wrapped_rows = Some(rows); for (_, event) in input.get_precompile_events(SyscallCode::AES128_ENCRYPT) { @@ -67,7 +63,7 @@ impl MachineAir for AES128EncryptChip { for i in num_real_rows..padded_num_rows { let mut row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; // let cols: &mut AES128EncryptionCols = row.as_mut_slice().borrow_mut(); - // todo!() + // rows.push(row); } RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_AES128_ENCRYPTION_COLS) @@ -91,6 +87,180 @@ impl AES128EncryptChip { rows: &mut Option>, blu: &mut impl ByteRecord, ) { - todo!() + let num_round = 11; + let mut state = [0_u8; AES_128_BLOCK_BYTES]; + let mut round_key = [0_u8; AES_128_BLOCK_BYTES]; + for i in 0..AES_128_BLOCK_U32S { + state[i * 4..i * 4 + 4].copy_from_slice(&event.input[i].to_le_bytes()); + round_key[i * 4..i * 4 + 4].copy_from_slice(&event.key[i].to_le_bytes()); + } + let mut sbox_read_index = 0_usize; + for round in 0..num_round { + let mut row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; + let cols: &mut AES128EncryptionCols = row.as_mut_slice().borrow_mut(); + cols.shard = F::from_canonical_u32(event.shard); + cols.clk = F::from_canonical_u32(event.clk); + cols.is_real = F::ONE; + cols.key_address = F::from_canonical_u32(event.key_addr); + cols.block_address = F::from_canonical_u32(event.block_addr); + cols.sbox_address = F::from_canonical_u32(event.sbox_addr); + cols.round = [F::ZERO; 11]; + cols.round[round] = F::ONE; + cols.receive_syscall = F::from_bool(round == 0); + + for i in 0..AES_128_BLOCK_BYTES { + cols.state_matrix[i] = F::from_canonical_u8(state[i]); + cols.round_key_matrix[i] = F::from_canonical_u8(round_key[i]); + } + + if round == 0 { + // read the input + for i in 0..AES_128_BLOCK_U32S { + cols.block[i].populate( + MemoryRecordEnum::Read(event.input_read_records[i]), + blu + ); + } + // read the key + for i in 0..AES_128_BLOCK_U32S { + cols.key[i].populate( + event.key_read_records[i], + blu + ); + } + // read the sbox address + cols.sbox_addr_read.populate( + event.sbox_addr_memory, + blu + ); + // compute the add_round_key + for i in 0..AES_128_BLOCK_BYTES { + let tmp = state[i] ^ round_key[i]; + cols.add_round_key[i] = F::from_canonical_u8(state[i]); + let byte_lookup_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: tmp as u16, + a2: 0, + b: state[i], + c: round_key[i], + }; + blu.add_byte_lookup_event(byte_lookup_event); + state[i] = tmp; + } + } else + { + // subs_bytes + for i in 0..AES_128_BLOCK_BYTES { + cols.state_subs_bytes[i].populate( + event.sbox_read_records[sbox_read_index + i], + blu + ); + state[i] = event.sbox_reads[sbox_read_index + i]; + } + sbox_read_index += AES_128_BLOCK_BYTES; + + // shift_rows + let shifted_row = [ + state[0], state[5], state[10], state[15], + state[4], state[9], state[14], state[3], + state[8], state[13], state[2], state[7], + state[12], state[1], state[6], state[11], + ]; + + // Mix columns + let mixed_columns = if round != 10 { + cols.mix_column.populate( + blu, + &shifted_row + ) + } else { + shifted_row + }; + + // Add round key + for i in 0..AES_128_BLOCK_BYTES { + state[i] = mixed_columns[i] ^ round_key[i]; + cols.add_round_key[i] = F::from_canonical_u8(state[i]); + let byte_lookup_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: state[i] as u16, + a2: 0, + b: mixed_columns[i], + c: round_key[i], + }; + blu.add_byte_lookup_event(byte_lookup_event); + } + } + + if round != 10 { + // read 24 sbox elements for each, except the last round + for i in sbox_read_index..sbox_read_index + 24 { + cols.sbox[i - sbox_read_index].populate( + event.sbox_read_records[i], + blu + ) + } + sbox_read_index += 24; + + + // compute next round key + let next_round_key = cols.next_round_key.populate( + blu, + &round_key, + event.sbox_read_records[sbox_read_index..sbox_read_index + 4] + .try_into() + .expect("Slice length must be exactly 4"), + round as u8, + ); + + // read the round key byte subs + for i in 0..4 { + cols.roundkey_subs_bytes[i].populate( + event.sbox_read_records[sbox_read_index + i], + blu + ); + } + sbox_read_index += 4; + + round_key = next_round_key; + } else + { + for i in sbox_read_index..(sbox_read_index + 16) { + cols.sbox[i - sbox_read_index].populate( + event.sbox_read_records[i], + blu + ) + } + sbox_read_index += 16; + assert_eq!(sbox_read_index, 456); + + for i in 0..4 { + // check output + let tmp = event.output_write_records[i].value.to_le_bytes(); + for j in 0..4 { + assert_eq!(state[i * 4], tmp[0]); + assert_eq!(state[i * 4 + 1], tmp[1]); + assert_eq!(state[i * 4 + 2], tmp[2]); + assert_eq!(state[i * 4 + 3], tmp[3]); + } + } + + // write output + for i in 0..AES_128_BLOCK_U32S { + cols.block[i].populate( + MemoryRecordEnum::Write(event.output_write_records[i]), + blu + ); + } + + } + + if rows.as_ref().is_some() { + rows.as_mut().unwrap().push(row); + } + } + if rows.as_ref().is_some() { + log::info!("rows height: {:?}", rows.as_ref().unwrap().len()); + } } } \ No newline at end of file diff --git a/crates/stark/src/opts.rs b/crates/stark/src/opts.rs index 78f284dcf..5524d49a4 100644 --- a/crates/stark/src/opts.rs +++ b/crates/stark/src/opts.rs @@ -204,6 +204,8 @@ pub struct SplitOpts { pub sha_extend: usize, /// The threshold for sha compress events. pub sha_compress: usize, + /// The threshold for aes128 encrypt events. + pub aes128_encrypt: usize, /// The threshold for memory events. pub memory: usize, } @@ -217,6 +219,7 @@ impl SplitOpts { keccak: 8 * deferred_split_threshold / 24, sha_extend: 32 * deferred_split_threshold / 48, sha_compress: 32 * deferred_split_threshold / 80, + aes128_encrypt: 8 * deferred_split_threshold / 11, memory: 64 * deferred_split_threshold, } } diff --git a/crates/test-artifacts/guests/aes128/src/main.rs b/crates/test-artifacts/guests/aes128/src/main.rs index 941eccb56..d90db5ad3 100644 --- a/crates/test-artifacts/guests/aes128/src/main.rs +++ b/crates/test-artifacts/guests/aes128/src/main.rs @@ -5,8 +5,8 @@ zkm_zkvm::entrypoint!(main); use zkm_zkvm::lib::aes128::aes128_encrypt; pub fn main() { - for _ in 0..25 { - let mut state = [1u8; 16]; + for _ in 0..1 { + let mut state = [0u8; 16]; let key = [0u8; 16]; aes128_encrypt(&mut state, &key); } diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs index 53c76031c..7539a6cac 100644 --- a/examples/aes128/host/src/main.rs +++ b/examples/aes128/host/src/main.rs @@ -8,16 +8,19 @@ fn prove_aes128_rust() { // load input let plain_text = vec![ - 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, - 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7 + // 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, + // 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7 + 0; 16 ]; let key = vec![0_u8; 16]; let iv = vec![0_u8; 16]; - let expected_output = vec![ - 97_u8, 203, 140, 117, 36, 211, 41, 97, - 177, 36, 93, 148, 107, 228, 201, 129 - ]; + // let expected_output = vec![ + // 97_u8, 203, 140, 117, 36, 211, 41, 97, + // 177, 36, 93, 148, 107, 228, 201, 129 + // ]; + + let expected_output = vec![102_u8, 233, 75, 212, 239, 138, 44, 59, 136, 76, 250, 89, 202, 52, 43, 46]; stdin.write(&plain_text); stdin.write(&key); @@ -32,9 +35,9 @@ fn prove_aes128_rust() { println!("executed program with {} cycles", report.total_instruction_count()); // // Generate the proof for the given program and input. - // let (pk, vk) = client.setup(ELF); - // let mut proof = client.prove(&pk, stdin).run().unwrap(); - // println!("generated proof"); + let (pk, vk) = client.setup(ELF); + let mut proof = client.prove(&pk, stdin).run().unwrap(); + println!("generated proof"); // // // Read and verify the output. // // From 6f1be69068089c6bab1136b08c70d66d8c790e65 Mon Sep 17 00:00:00 2001 From: vanhger Date: Thu, 18 Sep 2025 14:32:33 +0700 Subject: [PATCH 03/12] feat: add memory constraints and change mips air id. --- crates/core/executor/src/air.rs | 2 +- .../syscall/precompiles/aes128_encrypt/air.rs | 113 +++++++++++++++++- .../precompiles/aes128_encrypt/columns.rs | 1 + .../precompiles/aes128_encrypt/trace.rs | 4 +- 4 files changed, 113 insertions(+), 7 deletions(-) diff --git a/crates/core/executor/src/air.rs b/crates/core/executor/src/air.rs index 3b89a6836..aceb3392a 100644 --- a/crates/core/executor/src/air.rs +++ b/crates/core/executor/src/air.rs @@ -42,7 +42,7 @@ pub enum MipsAirId { /// The Poseidon2 Permute chip Poseidon2Permute = 46, /// The AES-128 Encrypt chip - Aes128Encrypt = 47, + Aes128Encrypt = 49, /// The Keccak sponge chip. KeccakSponge = 48, /// The bn254 add assign chip. diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 4654e71ed..9fa3e1d7d 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -4,7 +4,9 @@ use p3_air::{Air, BaseAir}; use p3_field::FieldAlgebra; use p3_matrix::Matrix; use tempfile::Builder; +use zkm_core_executor::events::AES_128_BLOCK_BYTES; use zkm_stark::{MachineAir, ZKMAirBuilder}; +use crate::air::MemoryAirBuilder; use crate::KeccakSpongeChip; use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; use crate::syscall::precompiles::aes128_encrypt::columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}; @@ -25,14 +27,119 @@ where let local: &AES128EncryptionCols = (*local).borrow(); let next: &AES128EncryptionCols = (*next).borrow(); - let first_round = local.round[0]; - let last_round = local.round[10]; - builder.assert_eq(first_round * local.is_real, local.receive_syscall); + // Constrain memory } } impl AES128EncryptChip { + fn eval_flags( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let first_round = local.round[0]; + let last_round = local.round[10]; + for i in 0..11 { + builder.assert_bool(local.round[i]); + } + builder.assert_bool(local.round_1to9); + builder.assert_eq(first_round * local.is_real, local.receive_syscall); + builder.assert_eq(last_round + first_round + local.round_1to9, local.is_real); + } + fn eval_memory_access( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let mut round = AB::Expr::ZERO; + for i in 0..11 { + round = round + local.round[i] * AB::F::from_canonical_u32(i as u32); + } + + // if this is the first row, populate reading key + for i in 0..4 { + builder.eval_memory_access( + local.shard, + local.clk, + local.key_address + AB::F::from_canonical_u32((i * 4) as u32), + &local.key[i], + local.round[0], + ); + } + + // if this is the first row, populate reading sbox_addr + builder.eval_memory_access( + local.shard, + local.clk, + local.sbox_address, + &local.sbox_addr_read, + local.round[0], + ); + + // if this is the first row, populate reading input + for i in 0..4 { + builder.eval_memory_access( + local.shard, + local.clk, + local.block_address + AB::F::from_canonical_u32((i * 4) as u32), + &local.block[i], + local.round[0], + ); + } + + // if this is the last row, populate writing output + for i in 0..4 { + builder.eval_memory_access( + local.shard, + local.clk + AB::Expr::ONE, + local.block_address + AB::F::from_canonical_u32((i * 4) as u32), + &local.block[i], + local.round[10] + ); + } + + let round_1to10 = local.round[10] + local.round_1to9; + // subs_bytes for state matrix + + for i in 0..AES_128_BLOCK_BYTES { + let index = local.state_matrix[i]; + builder.eval_memory_access( + local.shard, + local.clk, + local.sbox_address + index * AB::F::from_canonical_u8(4), + &local.state_subs_bytes[i], + round_1to10.clone(), + ) + } + + // sbox elements + let round_0to9 = local.round_1to9 + local.round[0]; + let start = round * AB::F::from_canonical_u32(24); + for i in 0..24 { + builder.eval_memory_access( + local.shard, + local.clk, + local.sbox_address + + (start.clone() + AB::F::from_canonical_u32(i as u32)) * AB::F::from_canonical_u8(4), + &local.sbox[i], + round_0to9.clone() + ); + } + for i in 0..16 { + builder.eval_memory_access( + local.shard, + local.clk, + local.sbox_address + + (start.clone() + AB::F::from_canonical_u32(i as u32)) * AB::F::from_canonical_u8(4), + &local.sbox[i], + local.round[10].clone(), + ); + } + + // round key subs bytes + todo!() + } } \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs index 1cc132e73..6063b009e 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -20,6 +20,7 @@ pub struct AES128EncryptionCols { pub block: [MemoryReadWriteCols; 4], pub sbox: [MemoryReadCols; 24], //24 * 11 = 264 > 256 Sbox elements. pub round: [T; 11], // [0,..10] + pub round_1to9: T, // 1 to 9 pub state_matrix: [T; 16], pub round_key_matrix: [T; 16], pub next_round_key: NextRoundKey, diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index 32f9bfa3a..6462a998d 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -107,6 +107,7 @@ impl AES128EncryptChip { cols.round = [F::ZERO; 11]; cols.round[round] = F::ONE; cols.receive_syscall = F::from_bool(round == 0); + cols.round_1to9 = F::from_bool(round >= 1 && round <= 9); for i in 0..AES_128_BLOCK_BYTES { cols.state_matrix[i] = F::from_canonical_u8(state[i]); @@ -259,8 +260,5 @@ impl AES128EncryptChip { rows.as_mut().unwrap().push(row); } } - if rows.as_ref().is_some() { - log::info!("rows height: {:?}", rows.as_ref().unwrap().len()); - } } } \ No newline at end of file From 58c723b525b319b0c093bf66962f3b74aab88880 Mon Sep 17 00:00:00 2001 From: vanhger Date: Fri, 19 Sep 2025 09:57:41 +0700 Subject: [PATCH 04/12] fix: aes shape bug. --- .../syscall/precompiles/aes128_encrypt/air.rs | 18 ++++++++++++++---- examples/aes128/host/src/main.rs | 6 +++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 9fa3e1d7d..571888a82 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -5,7 +5,8 @@ use p3_field::FieldAlgebra; use p3_matrix::Matrix; use tempfile::Builder; use zkm_core_executor::events::AES_128_BLOCK_BYTES; -use zkm_stark::{MachineAir, ZKMAirBuilder}; +use zkm_core_executor::syscalls::SyscallCode; +use zkm_stark::{LookupScope, MachineAir, ZKMAirBuilder}; use crate::air::MemoryAirBuilder; use crate::KeccakSpongeChip; use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; @@ -27,9 +28,18 @@ where let local: &AES128EncryptionCols = (*local).borrow(); let next: &AES128EncryptionCols = (*next).borrow(); + self.eval_flags(builder, local); + self.eval_memory_access(builder, local); - - // Constrain memory + builder.receive_syscall( + local.shard, + local.clk, + AB::F::from_canonical_u32(SyscallCode::AES128_ENCRYPT.syscall_id()), + local.block_address, + local.key_address, + local.receive_syscall, + LookupScope::Local, + ); } } @@ -140,6 +150,6 @@ impl AES128EncryptChip { } // round key subs bytes - todo!() + // todo!() } } \ No newline at end of file diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs index 7539a6cac..ebe78a1bf 100644 --- a/examples/aes128/host/src/main.rs +++ b/examples/aes128/host/src/main.rs @@ -52,9 +52,9 @@ fn prove_aes128_rust() { // println!("iv: {:?}", iv); // assert_eq!(expected_output, public_input); // - // // Verify proof and public values - // client.verify(&proof, &vk).expect("verification failed"); - // println!("successfully generated and verified proof for the program!"); + // Verify proof and public values + client.verify(&proof, &vk).expect("verification failed"); + println!("successfully generated and verified proof for the program!"); } fn main() { From 0796bddea05a34ffc4969b826e64c28136aa9a4a Mon Sep 17 00:00:00 2001 From: vanhger Date: Sat, 20 Sep 2025 16:48:46 +0700 Subject: [PATCH 05/12] feat: add some constraint eval functions --- .../syscalls/precompiles/aes128/encrypt.rs | 33 +++-- .../machine/src/operations/aes/aes_mul2.rs | 31 +++- .../machine/src/operations/aes/round_key.rs | 13 +- .../syscall/precompiles/aes128_encrypt/air.rs | 133 +++++++++++++++--- .../precompiles/aes128_encrypt/columns.rs | 1 + .../precompiles/aes128_encrypt/trace.rs | 67 +++++---- examples/aes128/guest/src/main.rs | 2 +- 7 files changed, 200 insertions(+), 80 deletions(-) diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs index 969eac75b..1582b6fc9 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -65,10 +65,10 @@ impl Syscall for AES128EncryptSyscall { key.extend(value.to_le_bytes()); } - // Add Roundkey, Round 0 - for i in 0..state.len() { - state[i] = state[i] ^ key[i]; - } + // // Add Roundkey, Round 0 + // for i in 0..state.len() { + // state[i] = state[i] ^ key[i]; + // } // Read first 24 sbox elements, Round 0 for i in 0..24 { @@ -92,14 +92,14 @@ impl Syscall for AES128EncryptSyscall { ); // Subs_bytes - for i in 0..state.len() { - let (record, value) = rt.mr(sbox_ptr + state[i] as u32 * 4); + for j in 0..state.len() { + let (record, value) = rt.mr(sbox_ptr + state[j] as u32 * 4); sbox_read_records.push(record); assert!(value <= u8::MAX as u32); sbox.push(value as u8); - state[i] = value as u8; + state[j] = value as u8; } - + // Shift row let shift_row = [ state[0], state[5], state[10], state[15], @@ -127,22 +127,22 @@ impl Syscall for AES128EncryptSyscall { shift_row }; - // Add round key - for i in 0..state.len() { - state[i] = mix_columns[i] ^ round_key[i]; - } - + // // Add round key + // for i in 0..state.len() { + // state[i] = mix_columns[i] ^ round_key[i]; + // } + // // Read 24 sbox elements if i != 10 { for j in i * 24..i * 24 + 24 { - let (record, value) = rt.mr(sbox_ptr as u32 + j as u32 * 4); + let (record, value) = rt.mr(sbox_ptr + j as u32 * 4); sbox_read_records.push(record); assert!(value <= u8::MAX as u32); sbox.push(value as u8); } } else { for j in i * 24..256 { - let (record, value) = rt.mr(sbox_ptr as u32 + j as u32 * 4); + let (record, value) = rt.mr(sbox_ptr + j as u32 * 4); sbox_read_records.push(record); assert!(value <= u8::MAX as u32); sbox.push(value as u8); @@ -150,7 +150,6 @@ impl Syscall for AES128EncryptSyscall { } } - // write output // Increment the clk by 1 before writing because we read from memory at start_clk. rt.clk += 1; @@ -225,7 +224,7 @@ impl AES128EncryptSyscall { .collect::>().try_into().unwrap(); let w7: [u8; 4] = w6.iter().zip(w3.iter()).map(|(&a, &b)| a ^ b) .collect::>().try_into().unwrap(); - + previous_key[0..4].copy_from_slice(&w4); previous_key[4..8].copy_from_slice(&w5); previous_key[8..12].copy_from_slice(&w6); diff --git a/crates/core/machine/src/operations/aes/aes_mul2.rs b/crates/core/machine/src/operations/aes/aes_mul2.rs index 925b4c9bc..0f6350e8e 100644 --- a/crates/core/machine/src/operations/aes/aes_mul2.rs +++ b/crates/core/machine/src/operations/aes/aes_mul2.rs @@ -11,6 +11,7 @@ use zkm_stark::{air::ZKMAirBuilder, Word}; pub struct MulBy2InAES { pub and_0x80: T, pub left_shift_1: T, + pub is_xor: T, // 0 or 1 pub xor_0x1b: T // also the result } @@ -18,11 +19,18 @@ impl MulBy2InAES { pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8) -> u8 { let and_0x80 = x & 0x80; let left_shift_1 = x << 1; - let xor_0x1b = if and_0x80 != 0 { left_shift_1 ^ 0x1b } else { left_shift_1 }; + let mut is_xor = 0_u8; + let xor_0x1b = if and_0x80 != 0 { + is_xor = 1; + left_shift_1 ^ 0x1b + } else { + left_shift_1 + }; self.and_0x80 = F::from_canonical_u8(and_0x80); self.left_shift_1 = F::from_canonical_u8(left_shift_1); self.xor_0x1b = F::from_canonical_u8(xor_0x1b); + self.is_xor = F::from_canonical_u8(is_xor); // Byte lookup events let byte_event_and = ByteLookupEvent { @@ -43,7 +51,7 @@ impl MulBy2InAES { }; record.add_byte_lookup_event(byte_event_ssl); - if and_0x80 != 0 { + if is_xor == 1 { let byte_event_xor = ByteLookupEvent { opcode: ByteOpcode::XOR, a1: xor_0x1b as u16, @@ -71,6 +79,7 @@ impl MulBy2InAES { AB::F::from_canonical_u32(0x80), is_real, ); + builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::SLL as u32), cols.left_shift_1, @@ -78,14 +87,26 @@ impl MulBy2InAES { AB::F::from_canonical_u32(1), is_real, ); - builder - .when(cols.and_0x80).inner + + builder.assert_bool(cols.is_xor); + // if cols.is_xor == 1, then and_0x80 == 128, else and_0x80 = 0 + builder.assert_eq( + cols.and_0x80, + cols.is_xor * AB::Expr::from_canonical_u8(128u8) + ); + + builder.assert_eq( + (AB::Expr::ONE- is_real.into()) * cols.is_xor, + AB::Expr::ZERO + ); + + builder .send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), cols.xor_0x1b, cols.left_shift_1, AB::F::from_canonical_u32(0x1b), - is_real, + cols.is_xor, ); } } diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs index 8fdc55f94..e1932ca1a 100644 --- a/crates/core/machine/src/operations/aes/round_key.rs +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -6,14 +6,13 @@ use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; use crate::memory::{MemoryCols, MemoryReadCols}; -pub const ROUND_CONST: [u8; 10] = [ - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, +pub const ROUND_CONST: [u8; 11] = [ + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x00 ]; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct NextRoundKey { - pub round_const: T, pub add_round_const: T, // XOR // new round key pub w4: [T; 4], @@ -45,7 +44,6 @@ impl NextRoundKey { let mut sub_rot_w3 = sbox_values.map(|u| u as u8); let rcon = ROUND_CONST[round as usize]; - self.round_const = F::from_canonical_u8(rcon); let first_byte = sub_rot_w3[0] ^ rcon; self.add_round_const = F::from_canonical_u8(first_byte); let first_byte_xor_event = ByteLookupEvent { @@ -125,7 +123,7 @@ impl NextRoundKey { cols: NextRoundKey, prev_round_key: [AB::Var; 16], sbox_read: &[MemoryReadCols; 4], - round: usize, + rcon: AB::Var, is_real: AB::Var, ) { let w0 = &prev_round_key[0..4]; @@ -134,16 +132,13 @@ impl NextRoundKey { let w3 = &prev_round_key[12..16]; // round const - let rcon = AB::F::from_canonical_u32(ROUND_CONST[round] as u32); - builder.when(is_real).assert_eq(cols.round_const, rcon); - let sbox_values: [AB::Var; 4] = sbox_read.map(|m| m.value().0[0]); builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), cols.add_round_const, sbox_values[0], - cols.round_const, + rcon, is_real, ); diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 571888a82..ada130cd9 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -1,14 +1,18 @@ use std::borrow::Borrow; use log::__private_api::loc; -use p3_air::{Air, BaseAir}; +use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::FieldAlgebra; use p3_matrix::Matrix; use tempfile::Builder; +use zkm_core_executor::ByteOpcode; use zkm_core_executor::events::AES_128_BLOCK_BYTES; use zkm_core_executor::syscalls::SyscallCode; use zkm_stark::{LookupScope, MachineAir, ZKMAirBuilder}; use crate::air::MemoryAirBuilder; use crate::KeccakSpongeChip; +use crate::memory::MemoryCols; +use crate::operations::mix_column::MixColumn; +use crate::operations::round_key::{NextRoundKey, ROUND_CONST}; use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; use crate::syscall::precompiles::aes128_encrypt::columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}; @@ -28,9 +32,6 @@ where let local: &AES128EncryptionCols = (*local).borrow(); let next: &AES128EncryptionCols = (*next).borrow(); - self.eval_flags(builder, local); - self.eval_memory_access(builder, local); - builder.receive_syscall( local.shard, local.clk, @@ -40,6 +41,14 @@ where local.receive_syscall, LookupScope::Local, ); + + self.eval_flags(builder, local); + self.eval_memory_access(builder, local); + + self.eval_mix_column(builder, local); + self.eval_compute_round_key(builder, local); + // self.eval_add_round_key(builder, local); + } } @@ -59,6 +68,7 @@ impl AES128EncryptChip { builder.assert_eq(first_round * local.is_real, local.receive_syscall); builder.assert_eq(last_round + first_round + local.round_1to9, local.is_real); } + fn eval_memory_access( &self, builder: &mut AB, @@ -74,32 +84,32 @@ impl AES128EncryptChip { builder.eval_memory_access( local.shard, local.clk, - local.key_address + AB::F::from_canonical_u32((i * 4) as u32), + local.key_address + AB::F::from_canonical_u32((i as u32 * 4)), &local.key[i], local.round[0], ); } - // if this is the first row, populate reading sbox_addr - builder.eval_memory_access( - local.shard, - local.clk, - local.sbox_address, - &local.sbox_addr_read, - local.round[0], - ); - // if this is the first row, populate reading input for i in 0..4 { builder.eval_memory_access( local.shard, local.clk, - local.block_address + AB::F::from_canonical_u32((i * 4) as u32), + local.block_address + AB::F::from_canonical_u32(i as u32 * 4), &local.block[i], local.round[0], ); } + // if this is the first row, populate reading sbox_addr + builder.eval_memory_access( + local.shard, + local.clk, + AB::F::from_canonical_u8(6), + &local.sbox_addr_read, + local.round[0], + ); + // if this is the last row, populate writing output for i in 0..4 { builder.eval_memory_access( @@ -112,8 +122,8 @@ impl AES128EncryptChip { } let round_1to10 = local.round[10] + local.round_1to9; - // subs_bytes for state matrix + // subs_bytes for state matrix for i in 0..AES_128_BLOCK_BYTES { let index = local.state_matrix[i]; builder.eval_memory_access( @@ -132,7 +142,7 @@ impl AES128EncryptChip { builder.eval_memory_access( local.shard, local.clk, - local.sbox_address + local.sbox_address + (start.clone() + AB::F::from_canonical_u32(i as u32)) * AB::F::from_canonical_u8(4), &local.sbox[i], round_0to9.clone() @@ -150,6 +160,93 @@ impl AES128EncryptChip { } // round key subs bytes - // todo!() + let key_id = [13, 14, 15, 12]; + for (i, id) in key_id.iter().enumerate() { + builder.eval_memory_access( + local.shard, + local.clk, + local.sbox_address + local.round_key_matrix[*id as usize] * AB::F::from_canonical_u8(4), + &local.roundkey_subs_bytes[i], + round_0to9.clone(), + ); + } + } + + fn eval_mix_column( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let shifted_state = [ + local.state_subs_bytes[0].value()[0], + local.state_subs_bytes[5].value()[0], + local.state_subs_bytes[10].value()[0], + local.state_subs_bytes[15].value()[0], + local.state_subs_bytes[4].value()[0], + local.state_subs_bytes[9].value()[0], + local.state_subs_bytes[14].value()[0], + local.state_subs_bytes[3].value()[0], + local.state_subs_bytes[8].value()[0], + local.state_subs_bytes[13].value()[0], + local.state_subs_bytes[2].value()[0], + local.state_subs_bytes[7].value()[0], + local.state_subs_bytes[12].value()[0], + local.state_subs_bytes[1].value()[0], + local.state_subs_bytes[6].value()[0], + local.state_subs_bytes[11].value()[0], + ]; + MixColumn::::eval( + builder, + shifted_state, + local.mix_column, + local.round_1to9 + ); + } + + fn eval_compute_round_key( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + NextRoundKey::::eval( + builder, + local.next_round_key, + local.round_key_matrix, + &local.roundkey_subs_bytes, + local.round_const, + local.round[0], + ); + NextRoundKey::::eval( + builder, + local.next_round_key, + local.round_key_matrix, + &local.roundkey_subs_bytes, + local.round_const, + local.round_1to9, + ); + + for i in 0..11 { + builder.when(local.round[i]).assert_eq( + local.round_const, + AB::F::from_canonical_u8(ROUND_CONST[i]) + ); + } + + } + + fn eval_add_round_key( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + for i in 0..AES_128_BLOCK_BYTES { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + local.add_round_key[i], + local.mix_column.xor_byte4s[i].value, + local.round_key_matrix[i], + local.is_real, + ) + } } } \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs index 6063b009e..912146efc 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -21,6 +21,7 @@ pub struct AES128EncryptionCols { pub sbox: [MemoryReadCols; 24], //24 * 11 = 264 > 256 Sbox elements. pub round: [T; 11], // [0,..10] pub round_1to9: T, // 1 to 9 + pub round_const: T, pub state_matrix: [T; 16], pub round_key_matrix: [T; 16], pub next_round_key: NextRoundKey, diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index 6462a998d..4e54d6e84 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -9,6 +9,7 @@ use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; use zkm_core_executor::{events::{ByteLookupEvent, ByteRecord, PrecompileEvent}, syscalls::SyscallCode, ByteOpcode, ExecutionRecord, Program}; use zkm_core_executor::events::{AES128EncryptEvent, MemoryRecordEnum, AES_128_BLOCK_BYTES, AES_128_BLOCK_U32S}; use zkm_stark::{air::MachineAir}; +use crate::operations::round_key::ROUND_CONST; use crate::syscall::precompiles::aes128_encrypt::columns::AES128EncryptionCols; use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip, AES_SBOX}; @@ -108,6 +109,7 @@ impl AES128EncryptChip { cols.round[round] = F::ONE; cols.receive_syscall = F::from_bool(round == 0); cols.round_1to9 = F::from_bool(round >= 1 && round <= 9); + cols.round_const = F::from_canonical_u8(ROUND_CONST[round]); for i in 0..AES_128_BLOCK_BYTES { cols.state_matrix[i] = F::from_canonical_u8(state[i]); @@ -134,20 +136,26 @@ impl AES128EncryptChip { event.sbox_addr_memory, blu ); - // compute the add_round_key + + // the mix column value should be the state for i in 0..AES_128_BLOCK_BYTES { - let tmp = state[i] ^ round_key[i]; - cols.add_round_key[i] = F::from_canonical_u8(state[i]); - let byte_lookup_event = ByteLookupEvent { - opcode: ByteOpcode::XOR, - a1: tmp as u16, - a2: 0, - b: state[i], - c: round_key[i], - }; - blu.add_byte_lookup_event(byte_lookup_event); - state[i] = tmp; + cols.mix_column.xor_byte4s[i].value = F::from_canonical_u8(state[i]); } + + // add_round_key + // for i in 0..AES_128_BLOCK_BYTES { + // let tmp = state[i] ^ round_key[i]; + // cols.add_round_key[i] = F::from_canonical_u8(tmp); + // let byte_lookup_event = ByteLookupEvent { + // opcode: ByteOpcode::XOR, + // a1: tmp as u16, + // a2: 0, + // b: state[i], + // c: round_key[i], + // }; + // blu.add_byte_lookup_event(byte_lookup_event); + // state[i] = tmp; + // } } else { // subs_bytes @@ -178,19 +186,19 @@ impl AES128EncryptChip { shifted_row }; - // Add round key - for i in 0..AES_128_BLOCK_BYTES { - state[i] = mixed_columns[i] ^ round_key[i]; - cols.add_round_key[i] = F::from_canonical_u8(state[i]); - let byte_lookup_event = ByteLookupEvent { - opcode: ByteOpcode::XOR, - a1: state[i] as u16, - a2: 0, - b: mixed_columns[i], - c: round_key[i], - }; - blu.add_byte_lookup_event(byte_lookup_event); - } + // // Add round key + // for i in 0..AES_128_BLOCK_BYTES { + // state[i] = mixed_columns[i] ^ round_key[i]; + // cols.add_round_key[i] = F::from_canonical_u8(state[i]); + // let byte_lookup_event = ByteLookupEvent { + // opcode: ByteOpcode::XOR, + // a1: state[i] as u16, + // a2: 0, + // b: mixed_columns[i], + // c: round_key[i], + // }; + // blu.add_byte_lookup_event(byte_lookup_event); + // } } if round != 10 { @@ -222,8 +230,8 @@ impl AES128EncryptChip { ); } sbox_read_index += 4; - - round_key = next_round_key; + // + // round_key = next_round_key; } else { for i in sbox_read_index..(sbox_read_index + 16) { @@ -233,12 +241,12 @@ impl AES128EncryptChip { ) } sbox_read_index += 16; - assert_eq!(sbox_read_index, 456); + // assert_eq!(sbox_read_index, 456); for i in 0..4 { // check output let tmp = event.output_write_records[i].value.to_le_bytes(); - for j in 0..4 { + for _ in 0..4 { assert_eq!(state[i * 4], tmp[0]); assert_eq!(state[i * 4 + 1], tmp[1]); assert_eq!(state[i * 4 + 2], tmp[2]); @@ -253,7 +261,6 @@ impl AES128EncryptChip { blu ); } - } if rows.as_ref().is_some() { diff --git a/examples/aes128/guest/src/main.rs b/examples/aes128/guest/src/main.rs index dce3131d0..540e1811b 100644 --- a/examples/aes128/guest/src/main.rs +++ b/examples/aes128/guest/src/main.rs @@ -23,7 +23,7 @@ pub fn main() { let key_array: [u8; 16] = key.as_slice().try_into().unwrap(); let iv_array: [u8; 16] = iv.as_slice().try_into().unwrap(); let output = cipher_block_chaining(&plain_text, &key_array, &iv_array); - assert_eq!(expected_output, output.to_vec()); + // assert_eq!(expected_output, output.to_vec()); zkm_zkvm::io::commit::>(&output.to_vec()); } From c2408df82ec69246bc132a9d6fdb4f2b24832ea5 Mon Sep 17 00:00:00 2001 From: vanhger Date: Sat, 20 Sep 2025 21:26:16 +0700 Subject: [PATCH 06/12] feat: add full main constraints --- .../syscalls/precompiles/aes128/encrypt.rs | 18 ++--- .../syscall/precompiles/aes128_encrypt/air.rs | 13 ++-- .../precompiles/aes128_encrypt/trace.rs | 76 +++++++++---------- 3 files changed, 54 insertions(+), 53 deletions(-) diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs index 1582b6fc9..8c86df302 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -66,9 +66,9 @@ impl Syscall for AES128EncryptSyscall { } // // Add Roundkey, Round 0 - // for i in 0..state.len() { - // state[i] = state[i] ^ key[i]; - // } + for i in 0..state.len() { + state[i] = state[i] ^ key[i]; + } // Read first 24 sbox elements, Round 0 for i in 0..24 { @@ -127,11 +127,11 @@ impl Syscall for AES128EncryptSyscall { shift_row }; - // // Add round key - // for i in 0..state.len() { - // state[i] = mix_columns[i] ^ round_key[i]; - // } - // + // Add round key + for j in 0..state.len() { + state[j] = mix_columns[j] ^ round_key[j]; + } + // Read 24 sbox elements if i != 10 { for j in i * 24..i * 24 + 24 { @@ -224,7 +224,7 @@ impl AES128EncryptSyscall { .collect::>().try_into().unwrap(); let w7: [u8; 4] = w6.iter().zip(w3.iter()).map(|(&a, &b)| a ^ b) .collect::>().try_into().unwrap(); - + previous_key[0..4].copy_from_slice(&w4); previous_key[4..8].copy_from_slice(&w5); previous_key[8..12].copy_from_slice(&w6); diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index ada130cd9..03a8fd667 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -47,7 +47,7 @@ where self.eval_mix_column(builder, local); self.eval_compute_round_key(builder, local); - // self.eval_add_round_key(builder, local); + self.eval_add_round_key(builder, local); } } @@ -109,7 +109,7 @@ impl AES128EncryptChip { &local.sbox_addr_read, local.round[0], ); - + // if this is the last row, populate writing output for i in 0..4 { builder.eval_memory_access( @@ -122,7 +122,7 @@ impl AES128EncryptChip { } let round_1to10 = local.round[10] + local.round_1to9; - + // subs_bytes for state matrix for i in 0..AES_128_BLOCK_BYTES { let index = local.state_matrix[i]; @@ -134,7 +134,7 @@ impl AES128EncryptChip { round_1to10.clone(), ) } - + // sbox elements let round_0to9 = local.round_1to9 + local.round[0]; let start = round * AB::F::from_canonical_u32(24); @@ -158,7 +158,7 @@ impl AES128EncryptChip { local.round[10].clone(), ); } - + // round key subs bytes let key_id = [13, 14, 15, 12]; for (i, id) in key_id.iter().enumerate() { @@ -239,13 +239,14 @@ impl AES128EncryptChip { builder: &mut AB, local: &AES128EncryptionCols, ) { + for i in 0..AES_128_BLOCK_BYTES { builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), local.add_round_key[i], local.mix_column.xor_byte4s[i].value, local.round_key_matrix[i], - local.is_real, + local.is_real ) } } diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index 4e54d6e84..ded9b8f37 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -143,19 +143,19 @@ impl AES128EncryptChip { } // add_round_key - // for i in 0..AES_128_BLOCK_BYTES { - // let tmp = state[i] ^ round_key[i]; - // cols.add_round_key[i] = F::from_canonical_u8(tmp); - // let byte_lookup_event = ByteLookupEvent { - // opcode: ByteOpcode::XOR, - // a1: tmp as u16, - // a2: 0, - // b: state[i], - // c: round_key[i], - // }; - // blu.add_byte_lookup_event(byte_lookup_event); - // state[i] = tmp; - // } + for i in 0..AES_128_BLOCK_BYTES { + let tmp = state[i] ^ round_key[i]; + cols.add_round_key[i] = F::from_canonical_u8(tmp); + let byte_lookup_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: tmp as u16, + a2: 0, + b: state[i], + c: round_key[i], + }; + blu.add_byte_lookup_event(byte_lookup_event); + state[i] = tmp; + } } else { // subs_bytes @@ -183,22 +183,25 @@ impl AES128EncryptChip { &shifted_row ) } else { + for i in 0..AES_128_BLOCK_BYTES { + cols.mix_column.xor_byte4s[i].value = F::from_canonical_u8(shifted_row[i]); + } shifted_row }; - // // Add round key - // for i in 0..AES_128_BLOCK_BYTES { - // state[i] = mixed_columns[i] ^ round_key[i]; - // cols.add_round_key[i] = F::from_canonical_u8(state[i]); - // let byte_lookup_event = ByteLookupEvent { - // opcode: ByteOpcode::XOR, - // a1: state[i] as u16, - // a2: 0, - // b: mixed_columns[i], - // c: round_key[i], - // }; - // blu.add_byte_lookup_event(byte_lookup_event); - // } + // Add round key + for i in 0..AES_128_BLOCK_BYTES { + state[i] = mixed_columns[i] ^ round_key[i]; + cols.add_round_key[i] = F::from_canonical_u8(state[i]); + let byte_lookup_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: state[i] as u16, + a2: 0, + b: mixed_columns[i], + c: round_key[i], + }; + blu.add_byte_lookup_event(byte_lookup_event); + } } if round != 10 { @@ -211,6 +214,13 @@ impl AES128EncryptChip { } sbox_read_index += 24; + // read the round key byte subs + for i in 0..4 { + cols.roundkey_subs_bytes[i].populate( + event.sbox_read_records[sbox_read_index + i], + blu + ); + } // compute next round key let next_round_key = cols.next_round_key.populate( @@ -221,17 +231,8 @@ impl AES128EncryptChip { .expect("Slice length must be exactly 4"), round as u8, ); - - // read the round key byte subs - for i in 0..4 { - cols.roundkey_subs_bytes[i].populate( - event.sbox_read_records[sbox_read_index + i], - blu - ); - } sbox_read_index += 4; - // - // round_key = next_round_key; + round_key = next_round_key; } else { for i in sbox_read_index..(sbox_read_index + 16) { @@ -241,7 +242,7 @@ impl AES128EncryptChip { ) } sbox_read_index += 16; - // assert_eq!(sbox_read_index, 456); + assert_eq!(sbox_read_index, 456); for i in 0..4 { // check output @@ -253,7 +254,6 @@ impl AES128EncryptChip { assert_eq!(state[i * 4 + 3], tmp[3]); } } - // write output for i in 0..AES_128_BLOCK_U32S { cols.block[i].populate( From e593d0239d05a15b9b925ba74f46b640fc596ded Mon Sep 17 00:00:00 2001 From: vanhger Date: Mon, 22 Sep 2025 11:21:41 +0700 Subject: [PATCH 07/12] feat: add constraints --- .../machine/src/operations/aes/round_key.rs | 2 +- .../syscall/precompiles/aes128_encrypt/air.rs | 194 ++++++++++++++++-- .../syscall/precompiles/aes128_encrypt/mod.rs | 2 +- .../precompiles/aes128_encrypt/trace.rs | 11 +- examples/aes128/guest/src/main.rs | 2 +- examples/aes128/host/src/main.rs | 37 ++-- 6 files changed, 202 insertions(+), 46 deletions(-) diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs index e1932ca1a..ba91de44c 100644 --- a/crates/core/machine/src/operations/aes/round_key.rs +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -131,7 +131,7 @@ impl NextRoundKey { let w2 = &prev_round_key[8..12]; let w3 = &prev_round_key[12..16]; - // round const + // sbox substitution bytes. let sbox_values: [AB::Var; 4] = sbox_read.map(|m| m.value().0[0]); builder.send_byte( diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 03a8fd667..da7e52c9c 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -7,13 +7,13 @@ use tempfile::Builder; use zkm_core_executor::ByteOpcode; use zkm_core_executor::events::AES_128_BLOCK_BYTES; use zkm_core_executor::syscalls::SyscallCode; -use zkm_stark::{LookupScope, MachineAir, ZKMAirBuilder}; -use crate::air::MemoryAirBuilder; +use zkm_stark::{ByteAirBuilder, LookupScope, MachineAir, ZKMAirBuilder}; +use crate::air::{MemoryAirBuilder, WordAirBuilder}; use crate::KeccakSpongeChip; use crate::memory::MemoryCols; use crate::operations::mix_column::MixColumn; use crate::operations::round_key::{NextRoundKey, ROUND_CONST}; -use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; +use crate::syscall::precompiles::aes128_encrypt::{AES128EncryptChip, AES_SBOX}; use crate::syscall::precompiles::aes128_encrypt::columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}; impl BaseAir for AES128EncryptChip { @@ -44,11 +44,12 @@ where self.eval_flags(builder, local); self.eval_memory_access(builder, local); - self.eval_mix_column(builder, local); self.eval_compute_round_key(builder, local); self.eval_add_round_key(builder, local); - + self.eval_input_output(builder, local); + self.eval_sbox_values(builder, local); + self.eval_transition(builder, local, next); } } @@ -64,7 +65,11 @@ impl AES128EncryptChip { builder.assert_bool(local.round[i]); } builder.assert_bool(local.round_1to9); - + let mut computed_1to9 = AB::Expr::ZERO; + for i in 1..10 { + computed_1to9 = computed_1to9 + local.round[i]; + } + builder.assert_eq(computed_1to9, local.round_1to9); builder.assert_eq(first_round * local.is_real, local.receive_syscall); builder.assert_eq(last_round + first_round + local.round_1to9, local.is_real); } @@ -84,7 +89,7 @@ impl AES128EncryptChip { builder.eval_memory_access( local.shard, local.clk, - local.key_address + AB::F::from_canonical_u32((i as u32 * 4)), + local.key_address + AB::F::from_canonical_u32(i as u32 * 4), &local.key[i], local.round[0], ); @@ -122,8 +127,9 @@ impl AES128EncryptChip { } let round_1to10 = local.round[10] + local.round_1to9; + let round_0to9 = local.round_1to9 + local.round[0]; - // subs_bytes for state matrix + // subs bytes for state matrix for i in 0..AES_128_BLOCK_BYTES { let index = local.state_matrix[i]; builder.eval_memory_access( @@ -135,8 +141,19 @@ impl AES128EncryptChip { ) } + // subs bytes for round key computation + let key_id = [13_usize, 14, 15, 12]; + for (i, id) in key_id.iter().enumerate() { + builder.eval_memory_access( + local.shard, + local.clk, + local.sbox_address + local.round_key_matrix[*id] * AB::F::from_canonical_u8(4), + &local.roundkey_subs_bytes[i], + round_0to9.clone(), + ); + } + // sbox elements - let round_0to9 = local.round_1to9 + local.round[0]; let start = round * AB::F::from_canonical_u32(24); for i in 0..24 { builder.eval_memory_access( @@ -159,17 +176,6 @@ impl AES128EncryptChip { ); } - // round key subs bytes - let key_id = [13, 14, 15, 12]; - for (i, id) in key_id.iter().enumerate() { - builder.eval_memory_access( - local.shard, - local.clk, - local.sbox_address + local.round_key_matrix[*id as usize] * AB::F::from_canonical_u8(4), - &local.roundkey_subs_bytes[i], - round_0to9.clone(), - ); - } } fn eval_mix_column( @@ -239,7 +245,6 @@ impl AES128EncryptChip { builder: &mut AB, local: &AES128EncryptionCols, ) { - for i in 0..AES_128_BLOCK_BYTES { builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), @@ -250,4 +255,151 @@ impl AES128EncryptChip { ) } } + + fn eval_input_output( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + // In round 0, all state matrix values and key values are in [0, 255] + for i in 0..AES_128_BLOCK_BYTES { + builder.send_byte( + AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8), + AB::Expr::ZERO, + AB::Expr::ZERO, + local.state_matrix[i], + local.round[0], + ); + + builder.send_byte( + AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8), + AB::Expr::ZERO, + AB::Expr::ZERO, + local.round_key_matrix[i], + local.round[0], + ); + } + + // In round 0, state and key matrix should be derived from cols.block and cols.key + for i in 0..4 { + for j in 0..4 { + let idx = i * 4 + j; + builder.when(local.round[0]).assert_eq( + local.state_matrix[idx], + local.block[i].access.value[j] + ); + builder.when(local.round[0]).assert_eq( + local.round_key_matrix[idx], + (local.key[i].access.value[j]) + ); + } + } + + // In round 1-9, block should remain the same + for i in 0..4 { + builder.when(local.round_1to9).assert_word_eq( + *local.block[i].prev_value(), + *local.block[i].value() + ); + } + + // In round 10, output block should be derived from state matrix + for i in 0..4 { + for j in 0..4 { + let idx = i * 4 + j; + builder.when(local.round[10]).assert_eq( + local.block[i].access.value[j], + local.add_round_key[idx] + ); + } + } + } + + fn eval_sbox_values( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + // sbox values are true + for i in 0..24 { + for round in 0..10 { + builder.when(local.round[round]).assert_eq( + local.sbox[i].access.value[0], + AB::Expr::from_canonical_u8(AES_SBOX[round * 24 + i]) + ); + } + } + + for i in 0..16 { + builder.when(local.round[10]).assert_eq( + local.sbox[i].access.value[0], + AB::Expr::from_canonical_u8(AES_SBOX[240 + i]) + ); + } + + for i in 0..24 { + for j in 1..4 { + builder.assert_eq( + local.sbox[i].access.value[j], + AB::Expr::ZERO + ); + } + } + } + + fn eval_transition( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + next: &AES128EncryptionCols, + ) { + // if it's not the last round, shard, clk remain the same + let round_0to9 = local.round_1to9 + local.round[0]; + builder.when(round_0to9.clone()).assert_eq(next.shard, local.shard); + builder.when(round_0to9.clone()).assert_eq(next.clk, local.clk); + + // key_address, block_address, sbox_address remain the same + builder.when(round_0to9.clone()).assert_eq(next.key_address, local.key_address); + builder.when(round_0to9.clone()).assert_eq(next.block_address, local.block_address); + builder.when(round_0to9.clone()).assert_eq(next.sbox_address, local.sbox_address); + + // round transition + for i in 0..10 { + builder.when(round_0to9.clone()).assert_eq( + local.round[i], + next.round[i + 1] + ); + } + + // state transition + for i in 0..AES_128_BLOCK_BYTES { + builder.when(round_0to9.clone()).assert_eq( + local.add_round_key[i], + next.state_matrix[i] + ); + } + + // round key transition + for i in 0..4 { + builder.when(round_0to9.clone()).assert_eq( + local.next_round_key.w4[i], + next.round_key_matrix[i] + ); + + builder.when(round_0to9.clone()).assert_eq( + local.next_round_key.w5[i], + next.round_key_matrix[i + 4] + ); + + builder.when(round_0to9.clone()).assert_eq( + local.next_round_key.w6[i], + next.round_key_matrix[i + 8] + ); + + builder.when(round_0to9.clone()).assert_eq( + local.next_round_key.w7[i], + next.round_key_matrix[i + 12] + ); + } + } } \ No newline at end of file diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs index 6d09b3642..30f9b8d5c 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs @@ -2,7 +2,7 @@ mod air; mod columns; mod trace; -pub const AES_SBOX: [u32; 256] = [ +pub const AES_SBOX: [u8; 256] = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index ded9b8f37..445507323 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -11,7 +11,7 @@ use zkm_core_executor::events::{AES128EncryptEvent, MemoryRecordEnum, AES_128_BL use zkm_stark::{air::MachineAir}; use crate::operations::round_key::ROUND_CONST; use crate::syscall::precompiles::aes128_encrypt::columns::AES128EncryptionCols; -use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip, AES_SBOX}; +use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip}; impl MachineAir for AES128EncryptChip { type Record = ExecutionRecord; @@ -115,6 +115,13 @@ impl AES128EncryptChip { cols.state_matrix[i] = F::from_canonical_u8(state[i]); cols.round_key_matrix[i] = F::from_canonical_u8(round_key[i]); } + // all the state matrix values and key should be in [0, 255] in round 0 + if round == 0 { + for i in 0..AES_128_BLOCK_BYTES { + blu.add_u8_range_check(0, state[i]); + blu.add_u8_range_check(0, round_key[i]); + } + } if round == 0 { // read the input @@ -203,7 +210,7 @@ impl AES128EncryptChip { blu.add_byte_lookup_event(byte_lookup_event); } } - + if round != 10 { // read 24 sbox elements for each, except the last round for i in sbox_read_index..sbox_read_index + 24 { diff --git a/examples/aes128/guest/src/main.rs b/examples/aes128/guest/src/main.rs index 540e1811b..dce3131d0 100644 --- a/examples/aes128/guest/src/main.rs +++ b/examples/aes128/guest/src/main.rs @@ -23,7 +23,7 @@ pub fn main() { let key_array: [u8; 16] = key.as_slice().try_into().unwrap(); let iv_array: [u8; 16] = iv.as_slice().try_into().unwrap(); let output = cipher_block_chaining(&plain_text, &key_array, &iv_array); - // assert_eq!(expected_output, output.to_vec()); + assert_eq!(expected_output, output.to_vec()); zkm_zkvm::io::commit::>(&output.to_vec()); } diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs index ebe78a1bf..31dac0fe2 100644 --- a/examples/aes128/host/src/main.rs +++ b/examples/aes128/host/src/main.rs @@ -8,19 +8,16 @@ fn prove_aes128_rust() { // load input let plain_text = vec![ - // 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, - // 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7 - 0; 16 + 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, + 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7 ]; let key = vec![0_u8; 16]; let iv = vec![0_u8; 16]; - // let expected_output = vec![ - // 97_u8, 203, 140, 117, 36, 211, 41, 97, - // 177, 36, 93, 148, 107, 228, 201, 129 - // ]; - - let expected_output = vec![102_u8, 233, 75, 212, 239, 138, 44, 59, 136, 76, 250, 89, 202, 52, 43, 46]; + let expected_output = vec![ + 97_u8, 203, 140, 117, 36, 211, 41, 97, + 177, 36, 93, 148, 107, 228, 201, 129 + ]; stdin.write(&plain_text); stdin.write(&key); @@ -38,20 +35,20 @@ fn prove_aes128_rust() { let (pk, vk) = client.setup(ELF); let mut proof = client.prove(&pk, stdin).run().unwrap(); println!("generated proof"); - // - // // Read and verify the output. - // // - // // Note that this output is read from values committed to in the program using - // // `zkm_zkvm::io::commit`. - // let plain_text = proof.public_values.read::>(); - // let key = proof.public_values.read::>(); - // let iv = proof.public_values.read::>(); - // let public_input = proof.public_values.read::>(); + + // Read and verify the output. + // + // Note that this output is read from values committed to in the program using + // `zkm_zkvm::io::commit`. + let _plain_text = proof.public_values.read::>(); + let _key = proof.public_values.read::>(); + let _iv = proof.public_values.read::>(); + let public_input = proof.public_values.read::>(); // println!("plaintext: {:?}", plain_text); // println!("key: {:?}", key); // println!("iv: {:?}", iv); - // assert_eq!(expected_output, public_input); - // + assert_eq!(expected_output, public_input); + // Verify proof and public values client.verify(&proof, &vk).expect("verification failed"); println!("successfully generated and verified proof for the program!"); From d20a49a2de6cf332bacb511beb4f4a7e81699f7d Mon Sep 17 00:00:00 2001 From: vanhger Date: Mon, 22 Sep 2025 17:31:57 +0700 Subject: [PATCH 08/12] feat: add subs byte op --- crates/core/machine/src/operations/aes/mod.rs | 3 +- .../machine/src/operations/aes/subs_byte.rs | 120 ++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 crates/core/machine/src/operations/aes/subs_byte.rs diff --git a/crates/core/machine/src/operations/aes/mod.rs b/crates/core/machine/src/operations/aes/mod.rs index 36c0b8253..4b76b2794 100644 --- a/crates/core/machine/src/operations/aes/mod.rs +++ b/crates/core/machine/src/operations/aes/mod.rs @@ -2,4 +2,5 @@ pub mod aes_mul2; pub mod aes_mul3; pub mod mix_column; pub mod xor_byte_4; -pub mod round_key; \ No newline at end of file +pub mod round_key; +pub mod subs_byte; \ No newline at end of file diff --git a/crates/core/machine/src/operations/aes/subs_byte.rs b/crates/core/machine/src/operations/aes/subs_byte.rs new file mode 100644 index 000000000..7d33b9151 --- /dev/null +++ b/crates/core/machine/src/operations/aes/subs_byte.rs @@ -0,0 +1,120 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::ByteOpcode; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord, MemoryReadRecord}; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; +use crate::memory::{MemoryCols, MemoryReadCols}; +pub const AES_SBOX: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct SubsByte { + pub positions: [[T; 32]; 4], + // if this byte is in range [0..=127] + pub is_left: T, + // the substituted byte + pub value: T, +} + +impl SubsByte { + pub fn populate( + &mut self, + byte: u8, + ) -> u8 { + let mut pos = byte; + if byte > 127 { + self.is_left = F::ZERO; + pos -= 128; + } else { + self.is_left = F::ONE; + } + + for i in 0..128 { + let row = i / 4; + let col = i % 4; + if i as u8 == pos { + self.positions[row][col] = F::ONE; + } else { + self.positions[row][col] = F::ZERO; + } + } + + let substituted = AES_SBOX[byte as usize]; + self.value = F::from_canonical_u8(substituted); + substituted + } + + pub fn eval( + builder: &mut AB, + cols: SubsByte, + byte: AB::Var, + is_real: AB::Var, + ) { + + builder.assert_bool(cols.is_left); + builder.assert_bool(is_real); + // if is_real = 0 then is_left must be 0 + builder.assert_eq( + (AB::Expr::ONE - is_real.into()) * cols.is_left, + AB::Expr::ZERO, + ); + + // exactly one position is selected + let sum_positions = cols.positions.iter().map(|&pos| { + pos.iter().map(|&p| p.into()).sum::() + }).sum::(); + builder.assert_eq( + sum_positions, + is_real.into(), + ); + + for i in 0..128 { + // positions are boolean + let row = i / 4; + let col = i % 4; + builder.assert_bool(cols.positions[row][col]); + builder.assert_eq( + (AB::Expr::ONE - is_real.into()) * cols.positions[row][col], + AB::Expr::ZERO, + ); + // if is_left = 1 then byte = i else byte = i+128 + builder.assert_eq( + cols.is_left * (AB::Expr::from_canonical_usize(i) - byte.clone()) * cols.positions[row][col].clone(), + AB::Expr::ZERO, + ); + builder.assert_eq( + (AB::Expr::ONE - cols.is_left.clone()) * (AB::Expr::from_canonical_usize(i + 128) - byte.clone()) * cols.positions[row][col].clone(), + AB::Expr::ZERO, + ); + + // value = SBOX[byte] + builder.assert_eq( + cols.is_left * (AB::Expr::from_canonical_u8(AES_SBOX[i]) - cols.value.clone()) * cols.positions[row][col].clone(), + AB::Expr::ZERO, + ); + builder.assert_eq( + (AB::Expr::ONE - cols.is_left.clone()) + * (AB::Expr::from_canonical_u8(AES_SBOX[i + 128]) - cols.value.clone()) + * cols.positions[row][col].clone(), + AB::Expr::ZERO, + ); + } + } +} \ No newline at end of file From d5db7c1ad356513455cdd376b90ff50b98296e88 Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 23 Sep 2025 11:03:57 +0700 Subject: [PATCH 09/12] feat: redesign. --- .../executor/src/artifacts/mips_costs.json | 1 + .../executor/src/events/precompiles/aes128.rs | 8 - .../syscalls/precompiles/aes128/encrypt.rs | 114 ++++------- crates/core/machine/src/mips/mod.rs | 1 - .../machine/src/operations/aes/aes_mul2.rs | 2 +- .../machine/src/operations/aes/aes_mul3.rs | 2 +- .../machine/src/operations/aes/mix_column.rs | 7 +- .../machine/src/operations/aes/round_key.rs | 47 +++-- .../machine/src/operations/aes/subs_byte.rs | 25 ++- .../syscall/precompiles/aes128_encrypt/air.rs | 192 +++++------------- .../precompiles/aes128_encrypt/columns.rs | 11 +- .../syscall/precompiles/aes128_encrypt/mod.rs | 19 -- .../precompiles/aes128_encrypt/trace.rs | 58 +----- crates/stark/src/opts.rs | 2 +- crates/zkvm/entrypoint/src/syscalls/aes128.rs | 3 +- crates/zkvm/lib/src/aes128.rs | 21 +- crates/zkvm/lib/src/lib.rs | 2 +- examples/aes128/host/src/main.rs | 2 +- 18 files changed, 156 insertions(+), 361 deletions(-) diff --git a/crates/core/executor/src/artifacts/mips_costs.json b/crates/core/executor/src/artifacts/mips_costs.json index 2beecc80f..ea53f0afe 100644 --- a/crates/core/executor/src/artifacts/mips_costs.json +++ b/crates/core/executor/src/artifacts/mips_costs.json @@ -9,6 +9,7 @@ "Secp256r1Decompress": 2686, "Secp256k1Decompress": 2686, "KeccakSponge": 102216, + "Aes128Encrypt": 38764, "Bn254AddAssign": 4013, "Bitwise": 42, "ShiftLeft": 68, diff --git a/crates/core/executor/src/events/precompiles/aes128.rs b/crates/core/executor/src/events/precompiles/aes128.rs index 413702661..ab2ef5826 100644 --- a/crates/core/executor/src/events/precompiles/aes128.rs +++ b/crates/core/executor/src/events/precompiles/aes128.rs @@ -21,24 +21,16 @@ pub struct AES128EncryptEvent { pub block_addr: u32, /// The address of the key pub key_addr: u32, - /// The address of sbox - pub sbox_addr: u32, - /// The memory records for sbox address - pub sbox_addr_memory: MemoryReadRecord, /// The input block as a [u32; AES_128_BLOCK_U32S] words. pub input: [u32; AES_128_BLOCK_U32S], /// The key as a [u32; AES_128_BLOCK_U32S] words. pub key: [u32; AES_128_BLOCK_U32S], /// The output block as a [u32; AES_128_BLOCK_U32S] words. pub output: [u32; AES_128_BLOCK_U32S], - /// The sbox reads - pub sbox_reads: Vec, /// The memory records for the input pub input_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], /// The memory records for the key pub key_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], - /// The memory records for reading sbox - pub sbox_read_records: Vec, /// The memory records for the output pub output_write_records: [MemoryWriteRecord; AES_128_BLOCK_U32S], /// The local memory access records. diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs index 8c86df302..2661e8116 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -1,6 +1,4 @@ -use log::info; -use crate::events::{AES128EncryptEvent, MemoryReadRecord, PrecompileEvent, AES_128_BLOCK_U32S}; -use crate::Register::A2; +use crate::events::{AES128EncryptEvent, PrecompileEvent, AES_128_BLOCK_U32S}; use crate::syscalls::{Syscall, SyscallCode, SyscallContext}; use crate::syscalls::precompiles::aes128::utils::mul_md5; @@ -19,6 +17,25 @@ pub const AES128_RCON: [[u8; 4]; 10] = [ [0x36, 0x00, 0x00, 0x00], ]; +pub const AES_SBOX: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + impl Syscall for AES128EncryptSyscall { fn num_extra_cycles(&self) -> u32 { 1 @@ -36,7 +53,6 @@ impl Syscall for AES128EncryptSyscall { let mut input_read_records = Vec::new(); let mut key_read_records = Vec::new(); - let mut sbox_read_records = Vec::new(); let mut output_write_records = Vec::new(); let mut input = Vec::new(); @@ -44,10 +60,7 @@ impl Syscall for AES128EncryptSyscall { let mut state = Vec::new(); let mut key = Vec::new(); let mut output = Vec::new(); - let mut sbox: Vec = Vec::new(); - // read sbox ptr - let (sbox_ptr_memory, sbox_ptr) = rt.mr(A2 as u32); // read block input for i in 0..AES_128_BLOCK_U32S { @@ -70,33 +83,19 @@ impl Syscall for AES128EncryptSyscall { state[i] = state[i] ^ key[i]; } - // Read first 24 sbox elements, Round 0 - for i in 0..24 { - let (record, value) = rt.mr(sbox_ptr + i as u32 * 4); - sbox_read_records.push(record); - assert!(value <= u8::MAX as u32); - sbox.push(value as u8); - } - // perform AES let mut round_key = key; for i in 1..11 { // compute round key Self::compute_round_key( - rt, &mut round_key, - &mut sbox_read_records, - &mut sbox, - sbox_ptr, i - 1 ); // Subs_bytes for j in 0..state.len() { - let (record, value) = rt.mr(sbox_ptr + state[j] as u32 * 4); - sbox_read_records.push(record); - assert!(value <= u8::MAX as u32); - sbox.push(value as u8); + assert!(state[j] <= u8::MAX); + let value = AES_SBOX[state[j] as usize]; state[j] = value as u8; } @@ -131,30 +130,12 @@ impl Syscall for AES128EncryptSyscall { for j in 0..state.len() { state[j] = mix_columns[j] ^ round_key[j]; } - - // Read 24 sbox elements - if i != 10 { - for j in i * 24..i * 24 + 24 { - let (record, value) = rt.mr(sbox_ptr + j as u32 * 4); - sbox_read_records.push(record); - assert!(value <= u8::MAX as u32); - sbox.push(value as u8); - } - } else { - for j in i * 24..256 { - let (record, value) = rt.mr(sbox_ptr + j as u32 * 4); - sbox_read_records.push(record); - assert!(value <= u8::MAX as u32); - sbox.push(value as u8); - } - } } // write output // Increment the clk by 1 before writing because we read from memory at start_clk. rt.clk += 1; assert_eq!(state.len(), 16); - log::info!("AES128 Encrypt output: {:?}", state); for chunk in state.chunks(4) { let value = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); output.push(value); @@ -169,15 +150,11 @@ impl Syscall for AES128EncryptSyscall { clk: start_clk, block_addr: block_ptr, key_addr: key_ptr, - sbox_addr: sbox_ptr, - sbox_addr_memory: sbox_ptr_memory, input: input.as_slice().try_into().unwrap(), key: key_u32s.as_slice().try_into().unwrap(), output: output.as_slice().try_into().unwrap(), - sbox_reads: sbox, input_read_records: input_read_records.as_slice().try_into().unwrap(), key_read_records: key_read_records.as_slice().try_into().unwrap(), - sbox_read_records, output_write_records: output_write_records.as_slice().try_into().unwrap(), local_mem_access: rt.postprocess(), }); @@ -190,11 +167,7 @@ impl Syscall for AES128EncryptSyscall { impl AES128EncryptSyscall { fn compute_round_key( - rt: &mut SyscallContext, previous_key: &mut [u8], - sbox_records: &mut Vec, - sbox: &mut Vec, - sbox_ptr: u32, round: usize ) { if previous_key.len() != 16 { @@ -204,30 +177,29 @@ impl AES128EncryptSyscall { let g_w3 = { let mut result = [previous_key[13], previous_key[14], previous_key[15], previous_key[12]]; for (i, rcon) in AES128_RCON[round].iter().enumerate() { - let (record, value) = rt.mr(sbox_ptr + result[i] as u32 * 4); - sbox_records.push(record); - assert!(value <= u8::MAX as u32); - sbox.push(value as u8); - result[i] = (value as u8) ^ rcon; + assert!(result[i] <= u8::MAX); + let value = AES_SBOX[result[i] as usize]; + result[i] = value ^ rcon; } result }; - let w0 = [previous_key[0], previous_key[1], previous_key[2], previous_key[3]]; - let w1 = [previous_key[4], previous_key[5], previous_key[6], previous_key[7]]; - let w2 = [previous_key[8], previous_key[9], previous_key[10], previous_key[11]]; - let w3 = [previous_key[12], previous_key[13], previous_key[14], previous_key[15]]; - let w4: [u8; 4] = w0.iter().zip(g_w3.iter()).map(|(&a, &b)| a ^ b) - .collect::>().try_into().unwrap(); - let w5: [u8; 4] = w4.iter().zip(w1.iter()).map(|(&a, &b)| a ^ b) - .collect::>().try_into().unwrap(); - let w6: [u8; 4] = w5.iter().zip(w2.iter()).map(|(&a, &b)| a ^ b) - .collect::>().try_into().unwrap(); - let w7: [u8; 4] = w6.iter().zip(w3.iter()).map(|(&a, &b)| a ^ b) - .collect::>().try_into().unwrap(); - - previous_key[0..4].copy_from_slice(&w4); - previous_key[4..8].copy_from_slice(&w5); - previous_key[8..12].copy_from_slice(&w6); - previous_key[12..16].copy_from_slice(&w7); + + let prev = previous_key.to_vec().clone(); + for i in 0..4 { + let w = if i == 0 { + prev[0..4] + .iter() + .zip(g_w3.iter()) + .map(|(&a, &b)| a ^ b) + .collect::>() + } else { + prev[i * 4..(i + 1) * 4] + .iter() + .zip(previous_key[(i - 1) * 4..i * 4].iter()) + .map(|(&a, &b)| a ^ b) + .collect::>() + }; + previous_key[i * 4..(i + 1) * 4].copy_from_slice(&w); + } } } diff --git a/crates/core/machine/src/mips/mod.rs b/crates/core/machine/src/mips/mod.rs index b96eff36e..c30e59dcf 100644 --- a/crates/core/machine/src/mips/mod.rs +++ b/crates/core/machine/src/mips/mod.rs @@ -276,7 +276,6 @@ impl MipsAir { chips.push(poseidon2_permute); let aes128_encrypt = Chip::new(MipsAir::Aes128Encrypt(AES128EncryptChip::new())); - // log::info!("aes128 cost: {:?}", aes128_encrypt.cost()); costs.insert(aes128_encrypt.name(), 11 * aes128_encrypt.cost()); chips.push(aes128_encrypt); diff --git a/crates/core/machine/src/operations/aes/aes_mul2.rs b/crates/core/machine/src/operations/aes/aes_mul2.rs index 0f6350e8e..7c7c9b970 100644 --- a/crates/core/machine/src/operations/aes/aes_mul2.rs +++ b/crates/core/machine/src/operations/aes/aes_mul2.rs @@ -4,7 +4,7 @@ use zkm_core_executor::{ ByteOpcode, }; use zkm_derive::AlignedBorrow; -use zkm_stark::{air::ZKMAirBuilder, Word}; +use zkm_stark::air::ZKMAirBuilder; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] diff --git a/crates/core/machine/src/operations/aes/aes_mul3.rs b/crates/core/machine/src/operations/aes/aes_mul3.rs index 389efcb53..ab6f991e1 100644 --- a/crates/core/machine/src/operations/aes/aes_mul3.rs +++ b/crates/core/machine/src/operations/aes/aes_mul3.rs @@ -4,7 +4,7 @@ use zkm_core_executor::{ ByteOpcode, }; use zkm_derive::AlignedBorrow; -use zkm_stark::{air::ZKMAirBuilder, Word}; +use zkm_stark::air::ZKMAirBuilder; use crate::operations::aes_mul2::MulBy2InAES; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] diff --git a/crates/core/machine/src/operations/aes/mix_column.rs b/crates/core/machine/src/operations/aes/mix_column.rs index 9d5ac46eb..c32a2ef5f 100644 --- a/crates/core/machine/src/operations/aes/mix_column.rs +++ b/crates/core/machine/src/operations/aes/mix_column.rs @@ -1,8 +1,5 @@ -use p3_field::{Field, FieldAlgebra}; -use zkm_core_executor::{ - events::{ByteLookupEvent, ByteRecord}, - ByteOpcode, -}; +use p3_field::Field; +use zkm_core_executor::events::ByteRecord; use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; use crate::operations::aes::xor_byte_4::XorByte4; diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs index ba91de44c..8ea0cf97a 100644 --- a/crates/core/machine/src/operations/aes/round_key.rs +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -1,10 +1,9 @@ -use p3_air::AirBuilder; use p3_field::{Field, FieldAlgebra}; use zkm_core_executor::ByteOpcode; -use zkm_core_executor::events::{ByteLookupEvent, ByteRecord, MemoryReadRecord}; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord}; use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; -use crate::memory::{MemoryCols, MemoryReadCols}; +use crate::operations::subs_byte::SubsByte; pub const ROUND_CONST: [u8; 11] = [ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x00 @@ -14,6 +13,7 @@ pub const ROUND_CONST: [u8; 11] = [ #[repr(C)] pub struct NextRoundKey { pub add_round_const: T, // XOR + pub w3_subs_byte: [SubsByte; 4], // for round key // new round key pub w4: [T; 4], pub w5: [T; 4], @@ -26,14 +26,13 @@ impl NextRoundKey { &mut self, records: &mut impl ByteRecord, prev_round_key: &[u8; 16], - byte_subs_records: &[MemoryReadRecord; 4], round: u8, ) -> [u8; 16] { // check sbox values - let sbox_values: [u32; 4] = byte_subs_records.map(|m| m.value); - let all_in_u8 = sbox_values.iter().all(|&v| v <= u8::MAX as u32); - if !all_in_u8 { - panic!("Not all sbox_values fit in u8"); + let mut sub_rot_w3 = [0u8; 4]; + let shifted_w3 = [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; + for i in 0..4 { + sub_rot_w3[i] = self.w3_subs_byte[i].populate(shifted_w3[i]); } // previous round key @@ -42,7 +41,6 @@ impl NextRoundKey { let w2 = &prev_round_key[8..12]; let w3 = &prev_round_key[12..16]; - let mut sub_rot_w3 = sbox_values.map(|u| u as u8); let rcon = ROUND_CONST[round as usize]; let first_byte = sub_rot_w3[0] ^ rcon; self.add_round_const = F::from_canonical_u8(first_byte); @@ -122,24 +120,33 @@ impl NextRoundKey { builder: &mut AB, cols: NextRoundKey, prev_round_key: [AB::Var; 16], - sbox_read: &[MemoryReadCols; 4], rcon: AB::Var, - is_real: AB::Var, + is_real: AB::Expr, ) { let w0 = &prev_round_key[0..4]; let w1 = &prev_round_key[4..8]; let w2 = &prev_round_key[8..12]; let w3 = &prev_round_key[12..16]; + let shifted_w3 = [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; + for i in 0..4 { + SubsByte::::eval( + builder, + cols.w3_subs_byte[i], + shifted_w3[i], + is_real.clone(), + ); + } + // sbox substitution bytes. - let sbox_values: [AB::Var; 4] = sbox_read.map(|m| m.value().0[0]); + let subs_byte_values: [AB::Var; 4] = cols.w3_subs_byte.map(|m| m.value); builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), cols.add_round_const, - sbox_values[0], + subs_byte_values[0], rcon, - is_real, + is_real.clone(), ); builder.send_byte( @@ -147,7 +154,7 @@ impl NextRoundKey { cols.w4[0], w0[0], cols.add_round_const, - is_real, + is_real.clone(), ); for i in 1..4 { @@ -155,8 +162,8 @@ impl NextRoundKey { AB::F::from_canonical_u32(ByteOpcode::XOR as u32), cols.w4[i], w0[i], - sbox_values[i], - is_real, + subs_byte_values[i], + is_real.clone(), ) } @@ -166,7 +173,7 @@ impl NextRoundKey { cols.w5[i], cols.w4[i], w1[i], - is_real, + is_real.clone(), ) } @@ -176,7 +183,7 @@ impl NextRoundKey { cols.w6[i], cols.w5[i], w2[i], - is_real, + is_real.clone(), ) } @@ -186,7 +193,7 @@ impl NextRoundKey { cols.w7[i], cols.w6[i], w3[i], - is_real, + is_real.clone(), ) } } diff --git a/crates/core/machine/src/operations/aes/subs_byte.rs b/crates/core/machine/src/operations/aes/subs_byte.rs index 7d33b9151..eebd7850b 100644 --- a/crates/core/machine/src/operations/aes/subs_byte.rs +++ b/crates/core/machine/src/operations/aes/subs_byte.rs @@ -1,9 +1,6 @@ use p3_field::{Field, FieldAlgebra}; -use zkm_core_executor::ByteOpcode; -use zkm_core_executor::events::{ByteLookupEvent, ByteRecord, MemoryReadRecord}; use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; -use crate::memory::{MemoryCols, MemoryReadCols}; pub const AES_SBOX: [u8; 256] = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, @@ -47,8 +44,8 @@ impl SubsByte { } for i in 0..128 { - let row = i / 4; - let col = i % 4; + let row = i / 32; + let col = i % 32; if i as u8 == pos { self.positions[row][col] = F::ONE; } else { @@ -65,14 +62,14 @@ impl SubsByte { builder: &mut AB, cols: SubsByte, byte: AB::Var, - is_real: AB::Var, + is_real: AB::Expr, ) { builder.assert_bool(cols.is_left); - builder.assert_bool(is_real); + builder.assert_bool(is_real.clone()); // if is_real = 0 then is_left must be 0 builder.assert_eq( - (AB::Expr::ONE - is_real.into()) * cols.is_left, + (AB::Expr::ONE - is_real.clone()) * cols.is_left, AB::Expr::ZERO, ); @@ -82,16 +79,16 @@ impl SubsByte { }).sum::(); builder.assert_eq( sum_positions, - is_real.into(), + is_real.clone(), ); for i in 0..128 { // positions are boolean - let row = i / 4; - let col = i % 4; + let row = i / 32; + let col = i % 32; builder.assert_bool(cols.positions[row][col]); builder.assert_eq( - (AB::Expr::ONE - is_real.into()) * cols.positions[row][col], + (AB::Expr::ONE - is_real.clone()) * cols.positions[row][col], AB::Expr::ZERO, ); // if is_left = 1 then byte = i else byte = i+128 @@ -106,7 +103,9 @@ impl SubsByte { // value = SBOX[byte] builder.assert_eq( - cols.is_left * (AB::Expr::from_canonical_u8(AES_SBOX[i]) - cols.value.clone()) * cols.positions[row][col].clone(), + cols.is_left + * (AB::Expr::from_canonical_u8(AES_SBOX[i]) - cols.value.clone()) + * cols.positions[row][col].clone(), AB::Expr::ZERO, ); builder.assert_eq( diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index da7e52c9c..4a12a8992 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -1,19 +1,17 @@ use std::borrow::Borrow; -use log::__private_api::loc; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::FieldAlgebra; use p3_matrix::Matrix; -use tempfile::Builder; use zkm_core_executor::ByteOpcode; use zkm_core_executor::events::AES_128_BLOCK_BYTES; use zkm_core_executor::syscalls::SyscallCode; -use zkm_stark::{ByteAirBuilder, LookupScope, MachineAir, ZKMAirBuilder}; +use zkm_stark::{LookupScope, ZKMAirBuilder}; use crate::air::{MemoryAirBuilder, WordAirBuilder}; -use crate::KeccakSpongeChip; use crate::memory::MemoryCols; use crate::operations::mix_column::MixColumn; use crate::operations::round_key::{NextRoundKey, ROUND_CONST}; -use crate::syscall::precompiles::aes128_encrypt::{AES128EncryptChip, AES_SBOX}; +use crate::operations::subs_byte::SubsByte; +use crate::syscall::precompiles::aes128_encrypt::{AES128EncryptChip}; use crate::syscall::precompiles::aes128_encrypt::columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}; impl BaseAir for AES128EncryptChip { @@ -44,11 +42,11 @@ where self.eval_flags(builder, local); self.eval_memory_access(builder, local); + self.eval_subs_byte(builder, local); self.eval_mix_column(builder, local); - self.eval_compute_round_key(builder, local); self.eval_add_round_key(builder, local); + self.eval_compute_round_key(builder, local); self.eval_input_output(builder, local); - self.eval_sbox_values(builder, local); self.eval_transition(builder, local, next); } } @@ -106,15 +104,6 @@ impl AES128EncryptChip { ); } - // if this is the first row, populate reading sbox_addr - builder.eval_memory_access( - local.shard, - local.clk, - AB::F::from_canonical_u8(6), - &local.sbox_addr_read, - local.round[0], - ); - // if this is the last row, populate writing output for i in 0..4 { builder.eval_memory_access( @@ -125,57 +114,22 @@ impl AES128EncryptChip { local.round[10] ); } + } - let round_1to10 = local.round[10] + local.round_1to9; - let round_0to9 = local.round_1to9 + local.round[0]; - - // subs bytes for state matrix - for i in 0..AES_128_BLOCK_BYTES { - let index = local.state_matrix[i]; - builder.eval_memory_access( - local.shard, - local.clk, - local.sbox_address + index * AB::F::from_canonical_u8(4), - &local.state_subs_bytes[i], - round_1to10.clone(), - ) - } - - // subs bytes for round key computation - let key_id = [13_usize, 14, 15, 12]; - for (i, id) in key_id.iter().enumerate() { - builder.eval_memory_access( - local.shard, - local.clk, - local.sbox_address + local.round_key_matrix[*id] * AB::F::from_canonical_u8(4), - &local.roundkey_subs_bytes[i], - round_0to9.clone(), - ); - } - - // sbox elements - let start = round * AB::F::from_canonical_u32(24); - for i in 0..24 { - builder.eval_memory_access( - local.shard, - local.clk, - local.sbox_address - + (start.clone() + AB::F::from_canonical_u32(i as u32)) * AB::F::from_canonical_u8(4), - &local.sbox[i], - round_0to9.clone() - ); - } + fn eval_subs_byte( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let round_1to10 = local.round_1to9 + local.round[10]; for i in 0..16 { - builder.eval_memory_access( - local.shard, - local.clk, - local.sbox_address - + (start.clone() + AB::F::from_canonical_u32(i as u32)) * AB::F::from_canonical_u8(4), - &local.sbox[i], - local.round[10].clone(), + SubsByte::::eval( + builder, + local.state_subs_byte[i], + local.state_matrix[i], + round_1to10.clone(), ); } - } fn eval_mix_column( @@ -184,22 +138,22 @@ impl AES128EncryptChip { local: &AES128EncryptionCols, ) { let shifted_state = [ - local.state_subs_bytes[0].value()[0], - local.state_subs_bytes[5].value()[0], - local.state_subs_bytes[10].value()[0], - local.state_subs_bytes[15].value()[0], - local.state_subs_bytes[4].value()[0], - local.state_subs_bytes[9].value()[0], - local.state_subs_bytes[14].value()[0], - local.state_subs_bytes[3].value()[0], - local.state_subs_bytes[8].value()[0], - local.state_subs_bytes[13].value()[0], - local.state_subs_bytes[2].value()[0], - local.state_subs_bytes[7].value()[0], - local.state_subs_bytes[12].value()[0], - local.state_subs_bytes[1].value()[0], - local.state_subs_bytes[6].value()[0], - local.state_subs_bytes[11].value()[0], + local.state_subs_byte[0].value, + local.state_subs_byte[5].value, + local.state_subs_byte[10].value, + local.state_subs_byte[15].value, + local.state_subs_byte[4].value, + local.state_subs_byte[9].value, + local.state_subs_byte[14].value, + local.state_subs_byte[3].value, + local.state_subs_byte[8].value, + local.state_subs_byte[13].value, + local.state_subs_byte[2].value, + local.state_subs_byte[7].value, + local.state_subs_byte[12].value, + local.state_subs_byte[1].value, + local.state_subs_byte[6].value, + local.state_subs_byte[11].value, ]; MixColumn::::eval( builder, @@ -209,26 +163,34 @@ impl AES128EncryptChip { ); } + fn eval_add_round_key( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + for i in 0..AES_128_BLOCK_BYTES { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + local.add_round_key[i], + local.mix_column.xor_byte4s[i].value, + local.round_key_matrix[i], + local.is_real + ) + } + } + fn eval_compute_round_key( &self, builder: &mut AB, local: &AES128EncryptionCols, ) { + let round_0to9 = local.round_1to9 + local.round[0].clone(); NextRoundKey::::eval( builder, local.next_round_key, local.round_key_matrix, - &local.roundkey_subs_bytes, - local.round_const, - local.round[0], - ); - NextRoundKey::::eval( - builder, - local.next_round_key, - local.round_key_matrix, - &local.roundkey_subs_bytes, local.round_const, - local.round_1to9, + round_0to9, ); for i in 0..11 { @@ -237,23 +199,6 @@ impl AES128EncryptChip { AB::F::from_canonical_u8(ROUND_CONST[i]) ); } - - } - - fn eval_add_round_key( - &self, - builder: &mut AB, - local: &AES128EncryptionCols, - ) { - for i in 0..AES_128_BLOCK_BYTES { - builder.send_byte( - AB::F::from_canonical_u32(ByteOpcode::XOR as u32), - local.add_round_key[i], - local.mix_column.xor_byte4s[i].value, - local.round_key_matrix[i], - local.is_real - ) - } } fn eval_input_output( @@ -290,7 +235,7 @@ impl AES128EncryptChip { ); builder.when(local.round[0]).assert_eq( local.round_key_matrix[idx], - (local.key[i].access.value[j]) + local.key[i].access.value[j] ); } } @@ -315,38 +260,6 @@ impl AES128EncryptChip { } } - fn eval_sbox_values( - &self, - builder: &mut AB, - local: &AES128EncryptionCols, - ) { - // sbox values are true - for i in 0..24 { - for round in 0..10 { - builder.when(local.round[round]).assert_eq( - local.sbox[i].access.value[0], - AB::Expr::from_canonical_u8(AES_SBOX[round * 24 + i]) - ); - } - } - - for i in 0..16 { - builder.when(local.round[10]).assert_eq( - local.sbox[i].access.value[0], - AB::Expr::from_canonical_u8(AES_SBOX[240 + i]) - ); - } - - for i in 0..24 { - for j in 1..4 { - builder.assert_eq( - local.sbox[i].access.value[j], - AB::Expr::ZERO - ); - } - } - } - fn eval_transition( &self, builder: &mut AB, @@ -361,7 +274,6 @@ impl AES128EncryptChip { // key_address, block_address, sbox_address remain the same builder.when(round_0to9.clone()).assert_eq(next.key_address, local.key_address); builder.when(round_0to9.clone()).assert_eq(next.block_address, local.block_address); - builder.when(round_0to9.clone()).assert_eq(next.sbox_address, local.sbox_address); // round transition for i in 0..10 { diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs index 912146efc..83b686b4e 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -2,6 +2,7 @@ use zkm_derive::AlignedBorrow; use crate::memory::{MemoryReadCols, MemoryReadWriteCols}; use crate::operations::mix_column::MixColumn; use crate::operations::round_key::NextRoundKey; +use crate::operations::subs_byte::SubsByte; /// AES128EncryptCols is the column layout for the AES128 encryption. /// The number of rows equal to the number of block. @@ -13,22 +14,18 @@ pub struct AES128EncryptionCols { pub is_real: T, pub key_address: T, pub block_address: T, - pub sbox_address: T, pub receive_syscall: T, - pub sbox_addr_read: MemoryReadCols, pub key: [MemoryReadCols; 4], pub block: [MemoryReadWriteCols; 4], - pub sbox: [MemoryReadCols; 24], //24 * 11 = 264 > 256 Sbox elements. - pub round: [T; 11], // [0,..10] + pub round: [T; 11], // [0,..,10] pub round_1to9: T, // 1 to 9 pub round_const: T, pub state_matrix: [T; 16], pub round_key_matrix: [T; 16], + pub state_subs_byte: [SubsByte; 16], pub next_round_key: NextRoundKey, - pub roundkey_subs_bytes: [MemoryReadCols; 4], // byte subs for round key - pub state_subs_bytes: [MemoryReadCols; 16], // byte subs for state pub mix_column: MixColumn, pub add_round_key: [T; 16], // result of this round } -pub const NUM_AES128_ENCRYPTION_COLS: usize = core::mem::size_of::>(); +pub const NUM_AES128_ENCRYPTION_COLS: usize = size_of::>(); diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs index 30f9b8d5c..ca79949bb 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs @@ -2,25 +2,6 @@ mod air; mod columns; mod trace; -pub const AES_SBOX: [u8; 256] = [ - 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, - 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, - 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, - 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, - 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, - 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, - 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, - 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, - 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, - 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, - 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, - 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, - 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, - 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, - 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, - 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, -]; - #[derive(Default)] pub struct AES128EncryptChip; diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index 445507323..94fe196e6 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -2,7 +2,6 @@ use std::borrow::BorrowMut; use hashbrown::HashMap; use itertools::Itertools; -use p3_air::BaseAir; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; @@ -45,7 +44,7 @@ impl MachineAir for AES128EncryptChip { output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec()); } - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + fn generate_trace(&self, input: &Self::Record, _output: &mut Self::Record) -> RowMajorMatrix { let rows = Vec::new(); log::info!("generate trace"); @@ -60,11 +59,9 @@ impl MachineAir for AES128EncryptChip { } let mut rows = wrapped_rows.unwrap(); let num_real_rows = rows.len(); - let mut padded_num_rows = num_real_rows.next_power_of_two(); - for i in num_real_rows..padded_num_rows { - let mut row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; - // let cols: &mut AES128EncryptionCols = row.as_mut_slice().borrow_mut(); - // + let padded_num_rows = num_real_rows.next_power_of_two(); + for _ in num_real_rows..padded_num_rows { + let row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; rows.push(row); } RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_AES128_ENCRYPTION_COLS) @@ -95,7 +92,6 @@ impl AES128EncryptChip { state[i * 4..i * 4 + 4].copy_from_slice(&event.input[i].to_le_bytes()); round_key[i * 4..i * 4 + 4].copy_from_slice(&event.key[i].to_le_bytes()); } - let mut sbox_read_index = 0_usize; for round in 0..num_round { let mut row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; let cols: &mut AES128EncryptionCols = row.as_mut_slice().borrow_mut(); @@ -104,7 +100,6 @@ impl AES128EncryptChip { cols.is_real = F::ONE; cols.key_address = F::from_canonical_u32(event.key_addr); cols.block_address = F::from_canonical_u32(event.block_addr); - cols.sbox_address = F::from_canonical_u32(event.sbox_addr); cols.round = [F::ZERO; 11]; cols.round[round] = F::ONE; cols.receive_syscall = F::from_bool(round == 0); @@ -138,11 +133,6 @@ impl AES128EncryptChip { blu ); } - // read the sbox address - cols.sbox_addr_read.populate( - event.sbox_addr_memory, - blu - ); // the mix column value should be the state for i in 0..AES_128_BLOCK_BYTES { @@ -167,13 +157,11 @@ impl AES128EncryptChip { { // subs_bytes for i in 0..AES_128_BLOCK_BYTES { - cols.state_subs_bytes[i].populate( - event.sbox_read_records[sbox_read_index + i], - blu + let subs_value = cols.state_subs_byte[i].populate( + state[i] ); - state[i] = event.sbox_reads[sbox_read_index + i]; + state[i] = subs_value; } - sbox_read_index += AES_128_BLOCK_BYTES; // shift_rows let shifted_row = [ @@ -212,45 +200,15 @@ impl AES128EncryptChip { } if round != 10 { - // read 24 sbox elements for each, except the last round - for i in sbox_read_index..sbox_read_index + 24 { - cols.sbox[i - sbox_read_index].populate( - event.sbox_read_records[i], - blu - ) - } - sbox_read_index += 24; - - // read the round key byte subs - for i in 0..4 { - cols.roundkey_subs_bytes[i].populate( - event.sbox_read_records[sbox_read_index + i], - blu - ); - } - // compute next round key let next_round_key = cols.next_round_key.populate( blu, &round_key, - event.sbox_read_records[sbox_read_index..sbox_read_index + 4] - .try_into() - .expect("Slice length must be exactly 4"), - round as u8, + round as u8 ); - sbox_read_index += 4; round_key = next_round_key; } else { - for i in sbox_read_index..(sbox_read_index + 16) { - cols.sbox[i - sbox_read_index].populate( - event.sbox_read_records[i], - blu - ) - } - sbox_read_index += 16; - assert_eq!(sbox_read_index, 456); - for i in 0..4 { // check output let tmp = event.output_write_records[i].value.to_le_bytes(); diff --git a/crates/stark/src/opts.rs b/crates/stark/src/opts.rs index 5524d49a4..3d380cd9d 100644 --- a/crates/stark/src/opts.rs +++ b/crates/stark/src/opts.rs @@ -219,7 +219,7 @@ impl SplitOpts { keccak: 8 * deferred_split_threshold / 24, sha_extend: 32 * deferred_split_threshold / 48, sha_compress: 32 * deferred_split_threshold / 80, - aes128_encrypt: 8 * deferred_split_threshold / 11, + aes128_encrypt: 32 * deferred_split_threshold / 11, memory: 64 * deferred_split_threshold, } } diff --git a/crates/zkvm/entrypoint/src/syscalls/aes128.rs b/crates/zkvm/entrypoint/src/syscalls/aes128.rs index 35a34f81e..97b9fda17 100644 --- a/crates/zkvm/entrypoint/src/syscalls/aes128.rs +++ b/crates/zkvm/entrypoint/src/syscalls/aes128.rs @@ -9,7 +9,7 @@ use core::arch::asm; /// a four byte boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4], sbox: *const [u32; 256]) { +pub extern "C" fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -17,7 +17,6 @@ pub extern "C" fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; in("$2") crate::syscalls::AES128_ENCRYPT, in("$4") state, in("$5") key, - in("$6") sbox, ); } diff --git a/crates/zkvm/lib/src/aes128.rs b/crates/zkvm/lib/src/aes128.rs index 3deeb21bd..11d3f9033 100644 --- a/crates/zkvm/lib/src/aes128.rs +++ b/crates/zkvm/lib/src/aes128.rs @@ -5,25 +5,6 @@ pub fn aes128_encrypt(state: &mut [u8; 16], key: &[u8; 16]) { let mut state_u32 = [0u32; 4]; let mut key_u32 = [0u32; 4]; - let sbox: [u32; 256] = [ - 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, - 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, - 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, - 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, - 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, - 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, - 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, - 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, - 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, - 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, - 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, - 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, - 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, - 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, - 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, - 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, - ]; - for i in 0..4 { state_u32[i] = u32::from_le_bytes([ state[i * 4], @@ -39,7 +20,7 @@ pub fn aes128_encrypt(state: &mut [u8; 16], key: &[u8; 16]) { ]); } unsafe { - syscall_aes128_encrypt(&mut state_u32, &key_u32, &sbox) + syscall_aes128_encrypt(&mut state_u32, &key_u32) } for i in 0..4 { state[4 * i..4 * i + 4].copy_from_slice(&state_u32[i].to_le_bytes()); diff --git a/crates/zkvm/lib/src/lib.rs b/crates/zkvm/lib/src/lib.rs index 4940a4b0b..cf0e7eb0b 100644 --- a/crates/zkvm/lib/src/lib.rs +++ b/crates/zkvm/lib/src/lib.rs @@ -80,7 +80,7 @@ extern "C" { pub fn syscall_poseidon2_permute(state: *mut [u32; 16]); /// Executes the AES-128 encryption on the given state with the given key. - pub fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4], sbox: *const [u32; 256]); + pub fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4]); /// Executes an uint256 multiplication on the given inputs. pub fn syscall_uint256_mulmod(x: *mut [u32; 8], y: *const [u32; 8]); diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs index 31dac0fe2..e752bb0f5 100644 --- a/examples/aes128/host/src/main.rs +++ b/examples/aes128/host/src/main.rs @@ -1,5 +1,5 @@ use std::env; -use zkm_sdk::{include_elf, utils, ProverClient, ZKMProofWithPublicValues, ZKMStdin}; +use zkm_sdk::{include_elf, utils, ProverClient, ZKMStdin}; /// The ELF we want to execute inside the zkVM. const ELF: &[u8] = include_elf!("aes128"); From 4e7f6da7fa4472151ea56caa6598538cf3fb892a Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 23 Sep 2025 11:46:10 +0700 Subject: [PATCH 10/12] style: fmt fix --- .../executor/src/events/precompiles/mod.rs | 2 +- crates/core/executor/src/syscalls/mod.rs | 2 +- .../syscalls/precompiles/aes128/encrypt.rs | 43 +++---- .../src/syscalls/precompiles/aes128/mod.rs | 2 +- .../src/syscalls/precompiles/aes128/utils.rs | 2 +- crates/core/machine/src/mips/mod.rs | 4 +- .../machine/src/operations/aes/aes_mul2.rs | 35 ++---- .../machine/src/operations/aes/aes_mul3.rs | 20 +--- .../machine/src/operations/aes/mix_column.rs | 72 +++--------- crates/core/machine/src/operations/aes/mod.rs | 4 +- .../machine/src/operations/aes/round_key.rs | 28 ++--- .../machine/src/operations/aes/subs_byte.rs | 34 +++--- .../machine/src/operations/aes/xor_byte_4.rs | 26 ++--- .../syscall/precompiles/aes128_encrypt/air.rs | 110 ++++++++---------- .../precompiles/aes128_encrypt/columns.rs | 4 +- .../syscall/precompiles/aes128_encrypt/mod.rs | 2 +- .../precompiles/aes128_encrypt/trace.rs | 80 ++++++------- crates/zkvm/lib/src/aes128.rs | 14 +-- crates/zkvm/lib/src/lib.rs | 3 +- examples/aes128/guest/src/main.rs | 2 - examples/aes128/host/src/main.rs | 10 +- examples/fibonacci_c_lib/host/build.rs | 2 +- examples/large-sum/host/Cargo.toml | 30 ++--- 23 files changed, 198 insertions(+), 333 deletions(-) diff --git a/crates/core/executor/src/events/precompiles/mod.rs b/crates/core/executor/src/events/precompiles/mod.rs index 1747b0bcc..9f1d1d76a 100644 --- a/crates/core/executor/src/events/precompiles/mod.rs +++ b/crates/core/executor/src/events/precompiles/mod.rs @@ -1,3 +1,4 @@ +mod aes128; mod ec; mod edwards; mod fptower; @@ -8,7 +9,6 @@ mod sha256_compress; mod sha256_extend; mod u256x2048_mul; mod uint256; -mod aes128; use super::{MemoryLocalEvent, SyscallEvent}; use crate::syscalls::SyscallCode; diff --git a/crates/core/executor/src/syscalls/mod.rs b/crates/core/executor/src/syscalls/mod.rs index cbc48849f..d21ecdf03 100644 --- a/crates/core/executor/src/syscalls/mod.rs +++ b/crates/core/executor/src/syscalls/mod.rs @@ -102,7 +102,7 @@ pub fn default_syscall_map() -> HashMap> { syscall_map.insert(SyscallCode::HALT, Arc::new(HaltSyscall)); syscall_map.insert(SyscallCode::POSEIDON2_PERMUTE, Arc::new(Poseidon2PermuteSyscall)); - + syscall_map.insert(SyscallCode::AES128_ENCRYPT, Arc::new(AES128EncryptSyscall)); syscall_map.insert(SyscallCode::KECCAK_SPONGE, Arc::new(KeccakSpongeSyscall)); diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs index 2661e8116..18d11b801 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -1,6 +1,6 @@ use crate::events::{AES128EncryptEvent, PrecompileEvent, AES_128_BLOCK_U32S}; -use crate::syscalls::{Syscall, SyscallCode, SyscallContext}; use crate::syscalls::precompiles::aes128::utils::mul_md5; +use crate::syscalls::{Syscall, SyscallCode, SyscallContext}; pub(crate) struct AES128EncryptSyscall; @@ -61,7 +61,6 @@ impl Syscall for AES128EncryptSyscall { let mut key = Vec::new(); let mut output = Vec::new(); - // read block input for i in 0..AES_128_BLOCK_U32S { let (record, value) = rt.mr(block_ptr + i as u32 * 4); @@ -87,10 +86,7 @@ impl Syscall for AES128EncryptSyscall { let mut round_key = key; for i in 1..11 { // compute round key - Self::compute_round_key( - &mut round_key, - i - 1 - ); + Self::compute_round_key(&mut round_key, i - 1); // Subs_bytes for j in 0..state.len() { @@ -101,11 +97,10 @@ impl Syscall for AES128EncryptSyscall { // Shift row let shift_row = [ - state[0], state[5], state[10], state[15], - state[4], state[9], state[14], state[3], - state[8], state[13], state[2], state[7], - state[12], state[1], state[6], state[11], - ].to_vec(); + state[0], state[5], state[10], state[15], state[4], state[9], state[14], state[3], + state[8], state[13], state[2], state[7], state[12], state[1], state[6], state[11], + ] + .to_vec(); // Mix columns let mix_columns = if i != 10 { @@ -116,10 +111,14 @@ impl Syscall for AES128EncryptSyscall { let s1 = shift_row[col_start + 1]; let s2 = shift_row[col_start + 2]; let s3 = shift_row[col_start + 3]; - mixed[col_start] = mul_md5(s0, 2) ^ mul_md5(s1, 3) ^ mul_md5(s2, 1) ^ mul_md5(s3, 1); - mixed[col_start + 1] = mul_md5(s0, 1) ^ mul_md5(s1, 2) ^ mul_md5(s2, 3) ^ mul_md5(s3, 1); - mixed[col_start + 2] = mul_md5(s0, 1) ^ mul_md5(s1, 1) ^ mul_md5(s2, 2) ^ mul_md5(s3, 3); - mixed[col_start + 3] = mul_md5(s0, 3) ^ mul_md5(s1, 1) ^ mul_md5(s2, 1) ^ mul_md5(s3, 2); + mixed[col_start] = + mul_md5(s0, 2) ^ mul_md5(s1, 3) ^ mul_md5(s2, 1) ^ mul_md5(s3, 1); + mixed[col_start + 1] = + mul_md5(s0, 1) ^ mul_md5(s1, 2) ^ mul_md5(s2, 3) ^ mul_md5(s3, 1); + mixed[col_start + 2] = + mul_md5(s0, 1) ^ mul_md5(s1, 1) ^ mul_md5(s2, 2) ^ mul_md5(s3, 3); + mixed[col_start + 3] = + mul_md5(s0, 3) ^ mul_md5(s1, 1) ^ mul_md5(s2, 1) ^ mul_md5(s3, 2); } mixed } else { @@ -166,16 +165,14 @@ impl Syscall for AES128EncryptSyscall { } impl AES128EncryptSyscall { - fn compute_round_key( - previous_key: &mut [u8], - round: usize - ) { + fn compute_round_key(previous_key: &mut [u8], round: usize) { if previous_key.len() != 16 { panic!("AES128: wrong previous key length"); } // First 4 bytes let g_w3 = { - let mut result = [previous_key[13], previous_key[14], previous_key[15], previous_key[12]]; + let mut result = + [previous_key[13], previous_key[14], previous_key[15], previous_key[12]]; for (i, rcon) in AES128_RCON[round].iter().enumerate() { assert!(result[i] <= u8::MAX); let value = AES_SBOX[result[i] as usize]; @@ -187,11 +184,7 @@ impl AES128EncryptSyscall { let prev = previous_key.to_vec().clone(); for i in 0..4 { let w = if i == 0 { - prev[0..4] - .iter() - .zip(g_w3.iter()) - .map(|(&a, &b)| a ^ b) - .collect::>() + prev[0..4].iter().zip(g_w3.iter()).map(|(&a, &b)| a ^ b).collect::>() } else { prev[i * 4..(i + 1) * 4] .iter() diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs index d0d2382ed..46178d5e0 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs @@ -1,2 +1,2 @@ pub mod encrypt; -pub mod utils; \ No newline at end of file +pub mod utils; diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs b/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs index d77e6ad32..86858f98e 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs @@ -15,4 +15,4 @@ pub fn mul_md5(x: u8, by: u8) -> u8 { 3 => xtime(x) ^ x, // 3*x = (2*x) ⊕ x _ => panic!("Only supports multipliers 1, 2, or 3"), } -} \ No newline at end of file +} diff --git a/crates/core/machine/src/mips/mod.rs b/crates/core/machine/src/mips/mod.rs index c30e59dcf..b913871ba 100644 --- a/crates/core/machine/src/mips/mod.rs +++ b/crates/core/machine/src/mips/mod.rs @@ -2,9 +2,9 @@ use crate::{ global::GlobalChip, memory::{MemoryChipType, MemoryLocalChip, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW}, syscall::precompiles::{ + aes128_encrypt::AES128EncryptChip, fptower::{Fp2AddSubAssignChip, Fp2MulAssignChip, FpOpChip}, poseidon2::Poseidon2PermuteChip, - aes128_encrypt::AES128EncryptChip, }, }; use core::fmt; @@ -274,7 +274,7 @@ impl MipsAir { let poseidon2_permute = Chip::new(MipsAir::Poseidon2Permute(Poseidon2PermuteChip::new())); costs.insert(poseidon2_permute.name(), poseidon2_permute.cost()); chips.push(poseidon2_permute); - + let aes128_encrypt = Chip::new(MipsAir::Aes128Encrypt(AES128EncryptChip::new())); costs.insert(aes128_encrypt.name(), 11 * aes128_encrypt.cost()); chips.push(aes128_encrypt); diff --git a/crates/core/machine/src/operations/aes/aes_mul2.rs b/crates/core/machine/src/operations/aes/aes_mul2.rs index 7c7c9b970..591c5a546 100644 --- a/crates/core/machine/src/operations/aes/aes_mul2.rs +++ b/crates/core/machine/src/operations/aes/aes_mul2.rs @@ -11,8 +11,8 @@ use zkm_stark::air::ZKMAirBuilder; pub struct MulBy2InAES { pub and_0x80: T, pub left_shift_1: T, - pub is_xor: T, // 0 or 1 - pub xor_0x1b: T // also the result + pub is_xor: T, // 0 or 1 + pub xor_0x1b: T, // also the result } impl MulBy2InAES { @@ -33,22 +33,12 @@ impl MulBy2InAES { self.is_xor = F::from_canonical_u8(is_xor); // Byte lookup events - let byte_event_and = ByteLookupEvent { - opcode: ByteOpcode::AND, - a1: and_0x80 as u16, - a2: 0, - b: x, - c: 0x80, - }; + let byte_event_and = + ByteLookupEvent { opcode: ByteOpcode::AND, a1: and_0x80 as u16, a2: 0, b: x, c: 0x80 }; record.add_byte_lookup_event(byte_event_and); - let byte_event_ssl = ByteLookupEvent { - opcode: ByteOpcode::SLL, - a1: left_shift_1 as u16, - a2: 0, - b: x, - c: 1, - }; + let byte_event_ssl = + ByteLookupEvent { opcode: ByteOpcode::SLL, a1: left_shift_1 as u16, a2: 0, b: x, c: 1 }; record.add_byte_lookup_event(byte_event_ssl); if is_xor == 1 { @@ -90,18 +80,11 @@ impl MulBy2InAES { builder.assert_bool(cols.is_xor); // if cols.is_xor == 1, then and_0x80 == 128, else and_0x80 = 0 - builder.assert_eq( - cols.and_0x80, - cols.is_xor * AB::Expr::from_canonical_u8(128u8) - ); + builder.assert_eq(cols.and_0x80, cols.is_xor * AB::Expr::from_canonical_u8(128u8)); - builder.assert_eq( - (AB::Expr::ONE- is_real.into()) * cols.is_xor, - AB::Expr::ZERO - ); + builder.assert_eq((AB::Expr::ONE - is_real.into()) * cols.is_xor, AB::Expr::ZERO); - builder - .send_byte( + builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), cols.xor_0x1b, cols.left_shift_1, diff --git a/crates/core/machine/src/operations/aes/aes_mul3.rs b/crates/core/machine/src/operations/aes/aes_mul3.rs index ab6f991e1..10caabbd8 100644 --- a/crates/core/machine/src/operations/aes/aes_mul3.rs +++ b/crates/core/machine/src/operations/aes/aes_mul3.rs @@ -1,3 +1,4 @@ +use crate::operations::aes_mul2::MulBy2InAES; use p3_field::{Field, FieldAlgebra}; use zkm_core_executor::{ events::{ByteLookupEvent, ByteRecord}, @@ -5,13 +6,12 @@ use zkm_core_executor::{ }; use zkm_derive::AlignedBorrow; use zkm_stark::air::ZKMAirBuilder; -use crate::operations::aes_mul2::MulBy2InAES; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct MulBy3InAES { pub mul_by_2: MulBy2InAES, - pub xor_x: T // also the result + pub xor_x: T, // also the result } impl MulBy3InAES { @@ -21,13 +21,8 @@ impl MulBy3InAES { self.xor_x = F::from_canonical_u8(xor_x); // Byte lookup event for the final XOR - let byte_event_xor = ByteLookupEvent { - opcode: ByteOpcode::XOR, - a1: xor_x as u16, - a2: 0, - b: x2, - c: x, - }; + let byte_event_xor = + ByteLookupEvent { opcode: ByteOpcode::XOR, a1: xor_x as u16, a2: 0, b: x2, c: x }; record.add_byte_lookup_event(byte_event_xor); xor_x } @@ -39,12 +34,7 @@ impl MulBy3InAES { cols: MulBy3InAES, is_real: AB::Var, ) { - MulBy2InAES::::eval( - builder, - x, - cols.mul_by_2, - is_real, - ); + MulBy2InAES::::eval(builder, x, cols.mul_by_2, is_real); builder.send_byte( AB::F::from_canonical_u32(ByteOpcode::XOR as u32), diff --git a/crates/core/machine/src/operations/aes/mix_column.rs b/crates/core/machine/src/operations/aes/mix_column.rs index c32a2ef5f..0fae83ec8 100644 --- a/crates/core/machine/src/operations/aes/mix_column.rs +++ b/crates/core/machine/src/operations/aes/mix_column.rs @@ -1,10 +1,10 @@ +use crate::operations::aes::xor_byte_4::XorByte4; +use crate::operations::aes_mul2::MulBy2InAES; +use crate::operations::aes_mul3::MulBy3InAES; use p3_field::Field; use zkm_core_executor::events::ByteRecord; use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; -use crate::operations::aes::xor_byte_4::XorByte4; -use crate::operations::aes_mul2::MulBy2InAES; -use crate::operations::aes_mul3::MulBy3InAES; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] @@ -15,11 +15,7 @@ pub struct MixColumn { } impl MixColumn { - pub fn populate( - &mut self, - record: &mut impl ByteRecord, - shifted_state: &[u8; 16], - ) -> [u8; 16] { + pub fn populate(&mut self, record: &mut impl ByteRecord, shifted_state: &[u8; 16]) -> [u8; 16] { let mut mixed = [0u8; 16]; for col in 0..4 { let col_start = col * 4; @@ -74,18 +70,8 @@ impl MixColumn { // col_start { - MulBy2InAES::::eval( - builder, - s0, - cols.mul_by_2s[col_start], - is_real, - ); - MulBy3InAES::::eval( - builder, - s1, - cols.mul_by_3s[col_start], - is_real, - ); + MulBy2InAES::::eval(builder, s0, cols.mul_by_2s[col_start], is_real); + MulBy3InAES::::eval(builder, s1, cols.mul_by_3s[col_start], is_real); XorByte4::::eval( builder, cols.mul_by_2s[col_start].xor_0x1b, @@ -99,18 +85,8 @@ impl MixColumn { // col_start + 1 { - MulBy2InAES::::eval( - builder, - s1, - cols.mul_by_2s[col_start + 1], - is_real, - ); - MulBy3InAES::::eval( - builder, - s2, - cols.mul_by_3s[col_start + 1], - is_real, - ); + MulBy2InAES::::eval(builder, s1, cols.mul_by_2s[col_start + 1], is_real); + MulBy3InAES::::eval(builder, s2, cols.mul_by_3s[col_start + 1], is_real); XorByte4::::eval( builder, s0, @@ -124,18 +100,8 @@ impl MixColumn { // col_start + 2 { - MulBy2InAES::::eval( - builder, - s2, - cols.mul_by_2s[col_start + 2], - is_real, - ); - MulBy3InAES::::eval( - builder, - s3, - cols.mul_by_3s[col_start + 2], - is_real, - ); + MulBy2InAES::::eval(builder, s2, cols.mul_by_2s[col_start + 2], is_real); + MulBy3InAES::::eval(builder, s3, cols.mul_by_3s[col_start + 2], is_real); XorByte4::::eval( builder, s0, @@ -149,18 +115,8 @@ impl MixColumn { // col_start + 3 { - MulBy3InAES::::eval( - builder, - s0, - cols.mul_by_3s[col_start + 3], - is_real, - ); - MulBy2InAES::::eval( - builder, - s3, - cols.mul_by_2s[col_start + 3], - is_real, - ); + MulBy3InAES::::eval(builder, s0, cols.mul_by_3s[col_start + 3], is_real); + MulBy2InAES::::eval(builder, s3, cols.mul_by_2s[col_start + 3], is_real); XorByte4::::eval( builder, cols.mul_by_3s[col_start + 3].xor_x, @@ -172,7 +128,5 @@ impl MixColumn { ) } } - } - -} \ No newline at end of file +} diff --git a/crates/core/machine/src/operations/aes/mod.rs b/crates/core/machine/src/operations/aes/mod.rs index 4b76b2794..1ee16f771 100644 --- a/crates/core/machine/src/operations/aes/mod.rs +++ b/crates/core/machine/src/operations/aes/mod.rs @@ -1,6 +1,6 @@ pub mod aes_mul2; pub mod aes_mul3; pub mod mix_column; -pub mod xor_byte_4; pub mod round_key; -pub mod subs_byte; \ No newline at end of file +pub mod subs_byte; +pub mod xor_byte_4; diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs index 8ea0cf97a..113b2650f 100644 --- a/crates/core/machine/src/operations/aes/round_key.rs +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -1,19 +1,18 @@ +use crate::operations::subs_byte::SubsByte; use p3_field::{Field, FieldAlgebra}; -use zkm_core_executor::ByteOpcode; use zkm_core_executor::events::{ByteLookupEvent, ByteRecord}; +use zkm_core_executor::ByteOpcode; use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; -use crate::operations::subs_byte::SubsByte; -pub const ROUND_CONST: [u8; 11] = [ - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x00 -]; +pub const ROUND_CONST: [u8; 11] = + [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x00]; #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct NextRoundKey { - pub add_round_const: T, // XOR - pub w3_subs_byte: [SubsByte; 4], // for round key + pub add_round_const: T, // XOR + pub w3_subs_byte: [SubsByte; 4], // for round key // new round key pub w4: [T; 4], pub w5: [T; 4], @@ -21,7 +20,7 @@ pub struct NextRoundKey { pub w7: [T; 4], } -impl NextRoundKey { +impl NextRoundKey { pub fn populate( &mut self, records: &mut impl ByteRecord, @@ -30,7 +29,8 @@ impl NextRoundKey { ) -> [u8; 16] { // check sbox values let mut sub_rot_w3 = [0u8; 4]; - let shifted_w3 = [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; + let shifted_w3 = + [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; for i in 0..4 { sub_rot_w3[i] = self.w3_subs_byte[i].populate(shifted_w3[i]); } @@ -128,14 +128,10 @@ impl NextRoundKey { let w2 = &prev_round_key[8..12]; let w3 = &prev_round_key[12..16]; - let shifted_w3 = [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; + let shifted_w3 = + [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; for i in 0..4 { - SubsByte::::eval( - builder, - cols.w3_subs_byte[i], - shifted_w3[i], - is_real.clone(), - ); + SubsByte::::eval(builder, cols.w3_subs_byte[i], shifted_w3[i], is_real.clone()); } // sbox substitution bytes. diff --git a/crates/core/machine/src/operations/aes/subs_byte.rs b/crates/core/machine/src/operations/aes/subs_byte.rs index eebd7850b..eaf26ba8d 100644 --- a/crates/core/machine/src/operations/aes/subs_byte.rs +++ b/crates/core/machine/src/operations/aes/subs_byte.rs @@ -31,10 +31,7 @@ pub struct SubsByte { } impl SubsByte { - pub fn populate( - &mut self, - byte: u8, - ) -> u8 { + pub fn populate(&mut self, byte: u8) -> u8 { let mut pos = byte; if byte > 127 { self.is_left = F::ZERO; @@ -64,23 +61,18 @@ impl SubsByte { byte: AB::Var, is_real: AB::Expr, ) { - builder.assert_bool(cols.is_left); builder.assert_bool(is_real.clone()); // if is_real = 0 then is_left must be 0 - builder.assert_eq( - (AB::Expr::ONE - is_real.clone()) * cols.is_left, - AB::Expr::ZERO, - ); + builder.assert_eq((AB::Expr::ONE - is_real.clone()) * cols.is_left, AB::Expr::ZERO); // exactly one position is selected - let sum_positions = cols.positions.iter().map(|&pos| { - pos.iter().map(|&p| p.into()).sum::() - }).sum::(); - builder.assert_eq( - sum_positions, - is_real.clone(), - ); + let sum_positions = cols + .positions + .iter() + .map(|&pos| pos.iter().map(|&p| p.into()).sum::()) + .sum::(); + builder.assert_eq(sum_positions, is_real.clone()); for i in 0..128 { // positions are boolean @@ -93,11 +85,15 @@ impl SubsByte { ); // if is_left = 1 then byte = i else byte = i+128 builder.assert_eq( - cols.is_left * (AB::Expr::from_canonical_usize(i) - byte.clone()) * cols.positions[row][col].clone(), + cols.is_left + * (AB::Expr::from_canonical_usize(i) - byte.clone()) + * cols.positions[row][col].clone(), AB::Expr::ZERO, ); builder.assert_eq( - (AB::Expr::ONE - cols.is_left.clone()) * (AB::Expr::from_canonical_usize(i + 128) - byte.clone()) * cols.positions[row][col].clone(), + (AB::Expr::ONE - cols.is_left.clone()) + * (AB::Expr::from_canonical_usize(i + 128) - byte.clone()) + * cols.positions[row][col].clone(), AB::Expr::ZERO, ); @@ -116,4 +112,4 @@ impl SubsByte { ); } } -} \ No newline at end of file +} diff --git a/crates/core/machine/src/operations/aes/xor_byte_4.rs b/crates/core/machine/src/operations/aes/xor_byte_4.rs index c13e5b6a1..f1b18bc95 100644 --- a/crates/core/machine/src/operations/aes/xor_byte_4.rs +++ b/crates/core/machine/src/operations/aes/xor_byte_4.rs @@ -1,6 +1,6 @@ use p3_field::{Field, FieldAlgebra}; -use zkm_core_executor::ByteOpcode; use zkm_core_executor::events::{ByteLookupEvent, ByteRecord}; +use zkm_core_executor::ByteOpcode; use zkm_derive::AlignedBorrow; use zkm_stark::ZKMAirBuilder; @@ -10,27 +10,15 @@ use zkm_stark::ZKMAirBuilder; pub struct XorByte4 { pub interm1: T, pub interm2: T, - pub value: T + pub value: T, } impl XorByte4 { - pub fn populate( - &mut self, - record: &mut impl ByteRecord, - x: u8, - y: u8, - z: u8, - w: u8, - ) -> u8 { + pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8, y: u8, z: u8, w: u8) -> u8 { let xor_inter1 = x ^ y; self.interm1 = F::from_canonical_u8(xor_inter1); - let byte_event = ByteLookupEvent { - opcode: ByteOpcode::XOR, - a1: xor_inter1 as u16, - a2: 0, - b: x, - c: y, - }; + let byte_event = + ByteLookupEvent { opcode: ByteOpcode::XOR, a1: xor_inter1 as u16, a2: 0, b: x, c: y }; record.add_byte_lookup_event(byte_event); let xor_inter2 = xor_inter1 ^ z; @@ -56,7 +44,7 @@ impl XorByte4 { record.add_byte_lookup_event(byte_event); result } - + pub fn eval( builder: &mut AB, x: AB::Var, @@ -88,4 +76,4 @@ impl XorByte4 { is_real, ); } -} \ No newline at end of file +} diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 4a12a8992..16401c539 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -1,18 +1,20 @@ -use std::borrow::Borrow; +use crate::air::{MemoryAirBuilder, WordAirBuilder}; +use crate::memory::MemoryCols; +use crate::operations::mix_column::MixColumn; +use crate::operations::round_key::{NextRoundKey, ROUND_CONST}; +use crate::operations::subs_byte::SubsByte; +use crate::syscall::precompiles::aes128_encrypt::columns::{ + AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS, +}; +use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::FieldAlgebra; use p3_matrix::Matrix; -use zkm_core_executor::ByteOpcode; +use std::borrow::Borrow; use zkm_core_executor::events::AES_128_BLOCK_BYTES; use zkm_core_executor::syscalls::SyscallCode; +use zkm_core_executor::ByteOpcode; use zkm_stark::{LookupScope, ZKMAirBuilder}; -use crate::air::{MemoryAirBuilder, WordAirBuilder}; -use crate::memory::MemoryCols; -use crate::operations::mix_column::MixColumn; -use crate::operations::round_key::{NextRoundKey, ROUND_CONST}; -use crate::operations::subs_byte::SubsByte; -use crate::syscall::precompiles::aes128_encrypt::{AES128EncryptChip}; -use crate::syscall::precompiles::aes128_encrypt::columns::{AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS}; impl BaseAir for AES128EncryptChip { fn width(&self) -> usize { @@ -98,7 +100,7 @@ impl AES128EncryptChip { builder.eval_memory_access( local.shard, local.clk, - local.block_address + AB::F::from_canonical_u32(i as u32 * 4), + local.block_address + AB::F::from_canonical_u32(i as u32 * 4), &local.block[i], local.round[0], ); @@ -111,7 +113,7 @@ impl AES128EncryptChip { local.clk + AB::Expr::ONE, local.block_address + AB::F::from_canonical_u32((i * 4) as u32), &local.block[i], - local.round[10] + local.round[10], ); } } @@ -155,12 +157,7 @@ impl AES128EncryptChip { local.state_subs_byte[6].value, local.state_subs_byte[11].value, ]; - MixColumn::::eval( - builder, - shifted_state, - local.mix_column, - local.round_1to9 - ); + MixColumn::::eval(builder, shifted_state, local.mix_column, local.round_1to9); } fn eval_add_round_key( @@ -174,7 +171,7 @@ impl AES128EncryptChip { local.add_round_key[i], local.mix_column.xor_byte4s[i].value, local.round_key_matrix[i], - local.is_real + local.is_real, ) } } @@ -194,10 +191,9 @@ impl AES128EncryptChip { ); for i in 0..11 { - builder.when(local.round[i]).assert_eq( - local.round_const, - AB::F::from_canonical_u8(ROUND_CONST[i]) - ); + builder + .when(local.round[i]) + .assert_eq(local.round_const, AB::F::from_canonical_u8(ROUND_CONST[i])); } } @@ -229,33 +225,29 @@ impl AES128EncryptChip { for i in 0..4 { for j in 0..4 { let idx = i * 4 + j; - builder.when(local.round[0]).assert_eq( - local.state_matrix[idx], - local.block[i].access.value[j] - ); - builder.when(local.round[0]).assert_eq( - local.round_key_matrix[idx], - local.key[i].access.value[j] - ); + builder + .when(local.round[0]) + .assert_eq(local.state_matrix[idx], local.block[i].access.value[j]); + builder + .when(local.round[0]) + .assert_eq(local.round_key_matrix[idx], local.key[i].access.value[j]); } } // In round 1-9, block should remain the same for i in 0..4 { - builder.when(local.round_1to9).assert_word_eq( - *local.block[i].prev_value(), - *local.block[i].value() - ); + builder + .when(local.round_1to9) + .assert_word_eq(*local.block[i].prev_value(), *local.block[i].value()); } // In round 10, output block should be derived from state matrix for i in 0..4 { for j in 0..4 { let idx = i * 4 + j; - builder.when(local.round[10]).assert_eq( - local.block[i].access.value[j], - local.add_round_key[idx] - ); + builder + .when(local.round[10]) + .assert_eq(local.block[i].access.value[j], local.add_round_key[idx]); } } } @@ -277,41 +269,33 @@ impl AES128EncryptChip { // round transition for i in 0..10 { - builder.when(round_0to9.clone()).assert_eq( - local.round[i], - next.round[i + 1] - ); + builder.when(round_0to9.clone()).assert_eq(local.round[i], next.round[i + 1]); } // state transition for i in 0..AES_128_BLOCK_BYTES { - builder.when(round_0to9.clone()).assert_eq( - local.add_round_key[i], - next.state_matrix[i] - ); + builder + .when(round_0to9.clone()) + .assert_eq(local.add_round_key[i], next.state_matrix[i]); } // round key transition for i in 0..4 { - builder.when(round_0to9.clone()).assert_eq( - local.next_round_key.w4[i], - next.round_key_matrix[i] - ); + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w4[i], next.round_key_matrix[i]); - builder.when(round_0to9.clone()).assert_eq( - local.next_round_key.w5[i], - next.round_key_matrix[i + 4] - ); + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w5[i], next.round_key_matrix[i + 4]); - builder.when(round_0to9.clone()).assert_eq( - local.next_round_key.w6[i], - next.round_key_matrix[i + 8] - ); + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w6[i], next.round_key_matrix[i + 8]); - builder.when(round_0to9.clone()).assert_eq( - local.next_round_key.w7[i], - next.round_key_matrix[i + 12] - ); + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w7[i], next.round_key_matrix[i + 12]); } } -} \ No newline at end of file +} diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs index 83b686b4e..4acf6c7a8 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -1,8 +1,8 @@ -use zkm_derive::AlignedBorrow; use crate::memory::{MemoryReadCols, MemoryReadWriteCols}; use crate::operations::mix_column::MixColumn; use crate::operations::round_key::NextRoundKey; use crate::operations::subs_byte::SubsByte; +use zkm_derive::AlignedBorrow; /// AES128EncryptCols is the column layout for the AES128 encryption. /// The number of rows equal to the number of block. @@ -18,7 +18,7 @@ pub struct AES128EncryptionCols { pub key: [MemoryReadCols; 4], pub block: [MemoryReadWriteCols; 4], pub round: [T; 11], // [0,..,10] - pub round_1to9: T, // 1 to 9 + pub round_1to9: T, // 1 to 9 pub round_const: T, pub state_matrix: [T; 16], pub round_key_matrix: [T; 16], diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs index ca79949bb..f7d54bc12 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs @@ -23,4 +23,4 @@ pub mod tests { let program = Program::from(AES128_ENCRYPT_ELF).unwrap(); run_test::>(program).unwrap(); } -} \ No newline at end of file +} diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index 94fe196e6..88bf0c598 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -1,16 +1,22 @@ use std::borrow::BorrowMut; +use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip}; +use crate::operations::round_key::ROUND_CONST; +use crate::syscall::precompiles::aes128_encrypt::columns::AES128EncryptionCols; use hashbrown::HashMap; use itertools::Itertools; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; -use zkm_core_executor::{events::{ByteLookupEvent, ByteRecord, PrecompileEvent}, syscalls::SyscallCode, ByteOpcode, ExecutionRecord, Program}; -use zkm_core_executor::events::{AES128EncryptEvent, MemoryRecordEnum, AES_128_BLOCK_BYTES, AES_128_BLOCK_U32S}; -use zkm_stark::{air::MachineAir}; -use crate::operations::round_key::ROUND_CONST; -use crate::syscall::precompiles::aes128_encrypt::columns::AES128EncryptionCols; -use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip}; +use zkm_core_executor::events::{ + AES128EncryptEvent, MemoryRecordEnum, AES_128_BLOCK_BYTES, AES_128_BLOCK_U32S, +}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord, PrecompileEvent}, + syscalls::SyscallCode, + ByteOpcode, ExecutionRecord, Program, +}; +use zkm_stark::air::MachineAir; impl MachineAir for AES128EncryptChip { type Record = ExecutionRecord; @@ -44,7 +50,11 @@ impl MachineAir for AES128EncryptChip { output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec()); } - fn generate_trace(&self, input: &Self::Record, _output: &mut Self::Record) -> RowMajorMatrix { + fn generate_trace( + &self, + input: &Self::Record, + _output: &mut Self::Record, + ) -> RowMajorMatrix { let rows = Vec::new(); log::info!("generate trace"); @@ -64,11 +74,12 @@ impl MachineAir for AES128EncryptChip { let row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; rows.push(row); } - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_AES128_ENCRYPTION_COLS) + RowMajorMatrix::new( + rows.into_iter().flatten().collect::>(), + NUM_AES128_ENCRYPTION_COLS, + ) } - - fn included(&self, shard: &Self::Record) -> bool { if let Some(shape) = shard.shape.as_ref() { shape.included::(self) @@ -121,17 +132,12 @@ impl AES128EncryptChip { if round == 0 { // read the input for i in 0..AES_128_BLOCK_U32S { - cols.block[i].populate( - MemoryRecordEnum::Read(event.input_read_records[i]), - blu - ); + cols.block[i] + .populate(MemoryRecordEnum::Read(event.input_read_records[i]), blu); } // read the key for i in 0..AES_128_BLOCK_U32S { - cols.key[i].populate( - event.key_read_records[i], - blu - ); + cols.key[i].populate(event.key_read_records[i], blu); } // the mix column value should be the state @@ -153,30 +159,23 @@ impl AES128EncryptChip { blu.add_byte_lookup_event(byte_lookup_event); state[i] = tmp; } - } else - { + } else { // subs_bytes for i in 0..AES_128_BLOCK_BYTES { - let subs_value = cols.state_subs_byte[i].populate( - state[i] - ); + let subs_value = cols.state_subs_byte[i].populate(state[i]); state[i] = subs_value; } // shift_rows let shifted_row = [ - state[0], state[5], state[10], state[15], - state[4], state[9], state[14], state[3], - state[8], state[13], state[2], state[7], - state[12], state[1], state[6], state[11], + state[0], state[5], state[10], state[15], state[4], state[9], state[14], + state[3], state[8], state[13], state[2], state[7], state[12], state[1], + state[6], state[11], ]; // Mix columns let mixed_columns = if round != 10 { - cols.mix_column.populate( - blu, - &shifted_row - ) + cols.mix_column.populate(blu, &shifted_row) } else { for i in 0..AES_128_BLOCK_BYTES { cols.mix_column.xor_byte4s[i].value = F::from_canonical_u8(shifted_row[i]); @@ -198,17 +197,12 @@ impl AES128EncryptChip { blu.add_byte_lookup_event(byte_lookup_event); } } - + if round != 10 { // compute next round key - let next_round_key = cols.next_round_key.populate( - blu, - &round_key, - round as u8 - ); + let next_round_key = cols.next_round_key.populate(blu, &round_key, round as u8); round_key = next_round_key; - } else - { + } else { for i in 0..4 { // check output let tmp = event.output_write_records[i].value.to_le_bytes(); @@ -221,10 +215,8 @@ impl AES128EncryptChip { } // write output for i in 0..AES_128_BLOCK_U32S { - cols.block[i].populate( - MemoryRecordEnum::Write(event.output_write_records[i]), - blu - ); + cols.block[i] + .populate(MemoryRecordEnum::Write(event.output_write_records[i]), blu); } } @@ -233,4 +225,4 @@ impl AES128EncryptChip { } } } -} \ No newline at end of file +} diff --git a/crates/zkvm/lib/src/aes128.rs b/crates/zkvm/lib/src/aes128.rs index 11d3f9033..648209d27 100644 --- a/crates/zkvm/lib/src/aes128.rs +++ b/crates/zkvm/lib/src/aes128.rs @@ -12,17 +12,11 @@ pub fn aes128_encrypt(state: &mut [u8; 16], key: &[u8; 16]) { state[i * 4 + 2], state[i * 4 + 3], ]); - key_u32[i] = u32::from_le_bytes([ - key[i * 4], - key[i * 4 + 1], - key[i * 4 + 2], - key[i * 4 + 3], - ]); - } - unsafe { - syscall_aes128_encrypt(&mut state_u32, &key_u32) + key_u32[i] = + u32::from_le_bytes([key[i * 4], key[i * 4 + 1], key[i * 4 + 2], key[i * 4 + 3]]); } + unsafe { syscall_aes128_encrypt(&mut state_u32, &key_u32) } for i in 0..4 { state[4 * i..4 * i + 4].copy_from_slice(&state_u32[i].to_le_bytes()); } -} \ No newline at end of file +} diff --git a/crates/zkvm/lib/src/lib.rs b/crates/zkvm/lib/src/lib.rs index cf0e7eb0b..dda6e537e 100644 --- a/crates/zkvm/lib/src/lib.rs +++ b/crates/zkvm/lib/src/lib.rs @@ -3,11 +3,11 @@ //! Documentation for these syscalls can be found in the zkVM entrypoint //! `zkm_zkvm::syscalls` module. +pub mod aes128; pub mod bls12381; pub mod bn254; #[cfg(feature = "ecdsa")] pub mod ecdsa; - pub mod ed25519; pub mod io; pub mod keccak256; @@ -19,7 +19,6 @@ pub mod unconstrained; pub mod utils; #[cfg(feature = "verify")] pub mod verify; -pub mod aes128; extern "C" { /// Halts the program with the given exit code. diff --git a/examples/aes128/guest/src/main.rs b/examples/aes128/guest/src/main.rs index dce3131d0..708afb7a2 100644 --- a/examples/aes128/guest/src/main.rs +++ b/examples/aes128/guest/src/main.rs @@ -8,7 +8,6 @@ use core::cmp::min; use zkm_zkvm::lib::aes128::aes128_encrypt; zkm_zkvm::entrypoint!(main); - pub fn main() { let plain_text: Vec = zkm_zkvm::io::read(); let key: Vec = zkm_zkvm::io::read(); @@ -37,4 +36,3 @@ fn cipher_block_chaining(input: &[u8], key: &[u8; 16], iv: &[u8; 16]) -> [u8; 16 } block } - diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs index e752bb0f5..5cce41f2a 100644 --- a/examples/aes128/host/src/main.rs +++ b/examples/aes128/host/src/main.rs @@ -8,16 +8,14 @@ fn prove_aes128_rust() { // load input let plain_text = vec![ - 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, - 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7 + 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, 97, 244, 57, 25, 124, + 234, 234, 234, 214, 134, 135, 246, 17, 29, 7, ]; let key = vec![0_u8; 16]; let iv = vec![0_u8; 16]; - let expected_output = vec![ - 97_u8, 203, 140, 117, 36, 211, 41, 97, - 177, 36, 93, 148, 107, 228, 201, 129 - ]; + let expected_output = + vec![97_u8, 203, 140, 117, 36, 211, 41, 97, 177, 36, 93, 148, 107, 228, 201, 129]; stdin.write(&plain_text); stdin.write(&key); diff --git a/examples/fibonacci_c_lib/host/build.rs b/examples/fibonacci_c_lib/host/build.rs index 6c6031ecb..eff897a02 100644 --- a/examples/fibonacci_c_lib/host/build.rs +++ b/examples/fibonacci_c_lib/host/build.rs @@ -3,4 +3,4 @@ use zkm_build::{build_program_with_args, BuildArgs}; fn main() { zkm_build::build_program("../guest"); -} \ No newline at end of file +} diff --git a/examples/large-sum/host/Cargo.toml b/examples/large-sum/host/Cargo.toml index 970134117..ca2f0cdab 100644 --- a/examples/large-sum/host/Cargo.toml +++ b/examples/large-sum/host/Cargo.toml @@ -13,21 +13,21 @@ zkm-sdk = { workspace = true } [build-dependencies] zkm-build = { workspace = true } -[[bin]] -name = "plonk_bn254" -path = "bin/plonk_bn254.rs" - -[[bin]] -name = "groth16_bn254" -path = "bin/groth16_bn254.rs" - -[[bin]] -name = "compressed" -path = "bin/compressed.rs" - -[[bin]] -name = "execute" -path = "bin/execute.rs" +#[[bin]] +#name = "plonk_bn254" +#path = "bin/plonk_bn254.rs" +# +#[[bin]] +#name = "groth16_bn254" +#path = "bin/groth16_bn254.rs" +# +#[[bin]] +#name = "compressed" +#path = "bin/compressed.rs" +# +#[[bin]] +#name = "execute" +#path = "bin/execute.rs" [[bin]] name = "large-sum-host" From 2063a812a38938f266d8bd3577f891bbfdf9f5e9 Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 23 Sep 2025 14:04:38 +0700 Subject: [PATCH 11/12] chore: change aes128 example --- examples/aes128/guest/src/main.rs | 18 ++++-------------- examples/aes128/host/src/main.rs | 13 +++---------- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/examples/aes128/guest/src/main.rs b/examples/aes128/guest/src/main.rs index 708afb7a2..97e499337 100644 --- a/examples/aes128/guest/src/main.rs +++ b/examples/aes128/guest/src/main.rs @@ -10,20 +10,10 @@ zkm_zkvm::entrypoint!(main); pub fn main() { let plain_text: Vec = zkm_zkvm::io::read(); - let key: Vec = zkm_zkvm::io::read(); - let iv: Vec = zkm_zkvm::io::read(); - let expected_output: Vec = zkm_zkvm::io::read(); - zkm_zkvm::io::commit::>(&plain_text); - zkm_zkvm::io::commit::>(&key); - zkm_zkvm::io::commit::>(&iv); - - assert_eq!(key.len(), 16); - assert_eq!(iv.len(), 16); - let key_array: [u8; 16] = key.as_slice().try_into().unwrap(); - let iv_array: [u8; 16] = iv.as_slice().try_into().unwrap(); - let output = cipher_block_chaining(&plain_text, &key_array, &iv_array); - assert_eq!(expected_output, output.to_vec()); - zkm_zkvm::io::commit::>(&output.to_vec()); + let key: [u8; 16] = zkm_zkvm::io::read(); + let iv: [u8; 16] = zkm_zkvm::io::read(); + let output = cipher_block_chaining(&plain_text, &key, &iv); + zkm_zkvm::io::commit::<[u8; 16]>(&output); } fn cipher_block_chaining(input: &[u8], key: &[u8; 16], iv: &[u8; 16]) -> [u8; 16] { diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs index 5cce41f2a..544d6a697 100644 --- a/examples/aes128/host/src/main.rs +++ b/examples/aes128/host/src/main.rs @@ -11,8 +11,8 @@ fn prove_aes128_rust() { 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, 97, 244, 57, 25, 124, 234, 234, 234, 214, 134, 135, 246, 17, 29, 7, ]; - let key = vec![0_u8; 16]; - let iv = vec![0_u8; 16]; + let key = [0_u8; 16]; + let iv = [0_u8; 16]; let expected_output = vec![97_u8, 203, 140, 117, 36, 211, 41, 97, 177, 36, 93, 148, 107, 228, 201, 129]; @@ -20,7 +20,6 @@ fn prove_aes128_rust() { stdin.write(&plain_text); stdin.write(&key); stdin.write(&iv); - stdin.write(&expected_output); // Create a `ProverClient` method. let client = ProverClient::new(); @@ -38,13 +37,7 @@ fn prove_aes128_rust() { // // Note that this output is read from values committed to in the program using // `zkm_zkvm::io::commit`. - let _plain_text = proof.public_values.read::>(); - let _key = proof.public_values.read::>(); - let _iv = proof.public_values.read::>(); - let public_input = proof.public_values.read::>(); - // println!("plaintext: {:?}", plain_text); - // println!("key: {:?}", key); - // println!("iv: {:?}", iv); + let public_input = proof.public_values.read::<[u8; 16]>(); assert_eq!(expected_output, public_input); // Verify proof and public values From 66834cc2ee707d3162beed9002d8ea4e0215c522 Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 23 Sep 2025 15:40:01 +0700 Subject: [PATCH 12/12] style: clippy fix --- .../syscalls/precompiles/aes128/encrypt.rs | 6 ++---- .../machine/src/operations/aes/subs_byte.rs | 20 +++++++++---------- .../syscall/precompiles/aes128_encrypt/air.rs | 2 +- .../precompiles/aes128_encrypt/trace.rs | 2 +- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs index 18d11b801..09d1215da 100644 --- a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -79,7 +79,7 @@ impl Syscall for AES128EncryptSyscall { // // Add Roundkey, Round 0 for i in 0..state.len() { - state[i] = state[i] ^ key[i]; + state[i] ^= key[i]; } // perform AES @@ -90,9 +90,8 @@ impl Syscall for AES128EncryptSyscall { // Subs_bytes for j in 0..state.len() { - assert!(state[j] <= u8::MAX); let value = AES_SBOX[state[j] as usize]; - state[j] = value as u8; + state[j] = value; } // Shift row @@ -174,7 +173,6 @@ impl AES128EncryptSyscall { let mut result = [previous_key[13], previous_key[14], previous_key[15], previous_key[12]]; for (i, rcon) in AES128_RCON[round].iter().enumerate() { - assert!(result[i] <= u8::MAX); let value = AES_SBOX[result[i] as usize]; result[i] = value ^ rcon; } diff --git a/crates/core/machine/src/operations/aes/subs_byte.rs b/crates/core/machine/src/operations/aes/subs_byte.rs index eaf26ba8d..27509603c 100644 --- a/crates/core/machine/src/operations/aes/subs_byte.rs +++ b/crates/core/machine/src/operations/aes/subs_byte.rs @@ -86,28 +86,28 @@ impl SubsByte { // if is_left = 1 then byte = i else byte = i+128 builder.assert_eq( cols.is_left - * (AB::Expr::from_canonical_usize(i) - byte.clone()) - * cols.positions[row][col].clone(), + * (AB::Expr::from_canonical_usize(i) - byte) + * cols.positions[row][col], AB::Expr::ZERO, ); builder.assert_eq( - (AB::Expr::ONE - cols.is_left.clone()) - * (AB::Expr::from_canonical_usize(i + 128) - byte.clone()) - * cols.positions[row][col].clone(), + (AB::Expr::ONE - cols.is_left) + * (AB::Expr::from_canonical_usize(i + 128) - byte) + * cols.positions[row][col], AB::Expr::ZERO, ); // value = SBOX[byte] builder.assert_eq( cols.is_left - * (AB::Expr::from_canonical_u8(AES_SBOX[i]) - cols.value.clone()) - * cols.positions[row][col].clone(), + * (AB::Expr::from_canonical_u8(AES_SBOX[i]) - cols.value) + * cols.positions[row][col], AB::Expr::ZERO, ); builder.assert_eq( - (AB::Expr::ONE - cols.is_left.clone()) - * (AB::Expr::from_canonical_u8(AES_SBOX[i + 128]) - cols.value.clone()) - * cols.positions[row][col].clone(), + (AB::Expr::ONE - cols.is_left) + * (AB::Expr::from_canonical_u8(AES_SBOX[i + 128]) - cols.value) + * cols.positions[row][col], AB::Expr::ZERO, ); } diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs index 16401c539..fa968a6eb 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -181,7 +181,7 @@ impl AES128EncryptChip { builder: &mut AB, local: &AES128EncryptionCols, ) { - let round_0to9 = local.round_1to9 + local.round[0].clone(); + let round_0to9 = local.round_1to9 + local.round[0]; NextRoundKey::::eval( builder, local.next_round_key, diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs index 88bf0c598..b90ff14d2 100644 --- a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -114,7 +114,7 @@ impl AES128EncryptChip { cols.round = [F::ZERO; 11]; cols.round[round] = F::ONE; cols.receive_syscall = F::from_bool(round == 0); - cols.round_1to9 = F::from_bool(round >= 1 && round <= 9); + cols.round_1to9 = F::from_bool((1..=9).contains(&round)); cols.round_const = F::from_canonical_u8(ROUND_CONST[round]); for i in 0..AES_128_BLOCK_BYTES {