diff --git a/Cargo.lock b/Cargo.lock index 06677f0a9..063318356 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8356,6 +8356,7 @@ dependencies = [ name = "zkm-core-executor" version = "1.2.0" dependencies = [ + "aes", "anyhow", "bincode", "bytemuck", diff --git a/crates/core/executor/Cargo.toml b/crates/core/executor/Cargo.toml index d03c7637c..95820350c 100644 --- a/crates/core/executor/Cargo.toml +++ b/crates/core/executor/Cargo.toml @@ -45,6 +45,7 @@ bytemuck = "1.16.3" tiny-keccak = { version = "2.0.2", features = ["keccak"] } vec_map = { version = "0.8.2", features = ["serde"] } enum-map = { version = "2.7.3", features = ["serde"] } +aes = "0.8.4" sha2 = { workspace = true } anyhow = { workspace = true } tracing-subscriber = "0.3.19" diff --git a/crates/core/executor/src/air.rs b/crates/core/executor/src/air.rs index 6af0b3bda..aceb3392a 100644 --- a/crates/core/executor/src/air.rs +++ b/crates/core/executor/src/air.rs @@ -41,6 +41,8 @@ pub enum MipsAirId { Secp256r1DoubleAssign = 11, /// The Poseidon2 Permute chip Poseidon2Permute = 46, + /// The AES-128 Encrypt chip + Aes128Encrypt = 49, /// The Keccak sponge chip. KeccakSponge = 48, /// The bn254 add assign chip. 
@@ -153,6 +155,7 @@ impl MipsAirId { Self::Secp256r1AddAssign => "Secp256r1AddAssign", Self::Secp256r1DoubleAssign => "Secp256r1DoubleAssign", Self::Poseidon2Permute => "Poseidon2Permute", + Self::Aes128Encrypt => "Aes128Encrypt", Self::KeccakSponge => "KeccakSponge", Self::Bn254AddAssign => "Bn254AddAssign", Self::Bn254DoubleAssign => "Bn254DoubleAssign", diff --git a/crates/core/executor/src/artifacts/mips_costs.json b/crates/core/executor/src/artifacts/mips_costs.json index 2beecc80f..ea53f0afe 100644 --- a/crates/core/executor/src/artifacts/mips_costs.json +++ b/crates/core/executor/src/artifacts/mips_costs.json @@ -9,6 +9,7 @@ "Secp256r1Decompress": 2686, "Secp256k1Decompress": 2686, "KeccakSponge": 102216, + "Aes128Encrypt": 38764, "Bn254AddAssign": 4013, "Bitwise": 42, "ShiftLeft": 68, diff --git a/crates/core/executor/src/events/precompiles/aes128.rs b/crates/core/executor/src/events/precompiles/aes128.rs new file mode 100644 index 000000000..ab2ef5826 --- /dev/null +++ b/crates/core/executor/src/events/precompiles/aes128.rs @@ -0,0 +1,38 @@ +use serde::{Deserialize, Serialize}; + +use crate::events::{ + memory::{MemoryReadRecord, MemoryWriteRecord}, + MemoryLocalEvent, +}; + +pub const AES_128_BLOCK_U32S: usize = 4; +pub const AES_128_BLOCK_BYTES: usize = 16; + +/// AES128 Encrypt Event +/// +/// This event is emitted when an AES128 encrypt operation is performed. +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct AES128EncryptEvent { + /// The shard number. + pub shard: u32, + /// The clock cycle. + pub clk: u32, + /// The address of the block + pub block_addr: u32, + /// The address of the key + pub key_addr: u32, + /// The input block as [u32; AES_128_BLOCK_U32S] words. + pub input: [u32; AES_128_BLOCK_U32S], + /// The key as [u32; AES_128_BLOCK_U32S] words. + pub key: [u32; AES_128_BLOCK_U32S], + /// The output block as [u32; AES_128_BLOCK_U32S] words. 
+ pub output: [u32; AES_128_BLOCK_U32S], + /// The memory records for the input + pub input_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], + /// The memory records for the key + pub key_read_records: [MemoryReadRecord; AES_128_BLOCK_U32S], + /// The memory records for the output + pub output_write_records: [MemoryWriteRecord; AES_128_BLOCK_U32S], + /// The local memory access records. + pub local_mem_access: Vec, +} diff --git a/crates/core/executor/src/events/precompiles/mod.rs b/crates/core/executor/src/events/precompiles/mod.rs index e7ae5e239..9f1d1d76a 100644 --- a/crates/core/executor/src/events/precompiles/mod.rs +++ b/crates/core/executor/src/events/precompiles/mod.rs @@ -1,3 +1,4 @@ +mod aes128; mod ec; mod edwards; mod fptower; @@ -11,6 +12,7 @@ mod uint256; use super::{MemoryLocalEvent, SyscallEvent}; use crate::syscalls::SyscallCode; +pub use aes128::*; pub use ec::*; pub use edwards::*; pub use fptower::*; @@ -34,6 +36,8 @@ pub enum PrecompileEvent { ShaCompress(ShaCompressEvent), /// Keccak sponge precompile event. KeccakSponge(KeccakSpongeEvent), + /// AES-128 encrypt precompile event. + Aes128Encrypt(AES128EncryptEvent), /// Edwards curve add precompile event. EdAdd(EllipticCurveAddEvent), /// Edwards curve decompress precompile event. 
@@ -105,6 +109,9 @@ impl PrecompileLocalMemory for Vec<(SyscallEvent, PrecompileEvent)> { PrecompileEvent::KeccakSponge(e) => { iterators.push(e.local_mem_access.iter()); } + PrecompileEvent::Aes128Encrypt(e) => { + iterators.push(e.local_mem_access.iter()); + } PrecompileEvent::EdDecompress(e) => { iterators.push(e.local_mem_access.iter()); } diff --git a/crates/core/executor/src/record.rs b/crates/core/executor/src/record.rs index a98c7c3d6..811a3d778 100644 --- a/crates/core/executor/src/record.rs +++ b/crates/core/executor/src/record.rs @@ -124,6 +124,7 @@ impl ExecutionRecord { SyscallCode::KECCAK_SPONGE => opts.keccak, SyscallCode::SHA_EXTEND => opts.sha_extend, SyscallCode::SHA_COMPRESS => opts.sha_compress, + SyscallCode::AES128_ENCRYPT => opts.aes128_encrypt, _ => opts.deferred, }; diff --git a/crates/core/executor/src/syscalls/code.rs b/crates/core/executor/src/syscalls/code.rs index b2bbe30d8..059e0fd55 100644 --- a/crates/core/executor/src/syscalls/code.rs +++ b/crates/core/executor/src/syscalls/code.rs @@ -169,6 +169,9 @@ pub enum SyscallCode { POSEIDON2_PERMUTE = 0x00_01_00_30, SYS_LINUX = 5000, + /// Executes the `AES128_ENCRYPT` precompile. 
+ AES128_ENCRYPT = 0x01_01_00_31, + UNIMPLEMENTED = 0xFF_FF_FF_FF, } @@ -190,6 +193,7 @@ impl SyscallCode { 0x01_01_00_07 => SyscallCode::ED_ADD, 0x00_01_00_08 => SyscallCode::ED_DECOMPRESS, 0x01_01_00_09 => SyscallCode::KECCAK_SPONGE, + 0x01_01_00_31 => SyscallCode::AES128_ENCRYPT, 0x01_01_00_0A => SyscallCode::SECP256K1_ADD, 0x00_01_00_0B => SyscallCode::SECP256K1_DOUBLE, 0x00_01_00_0C => SyscallCode::SECP256K1_DECOMPRESS, diff --git a/crates/core/executor/src/syscalls/mod.rs b/crates/core/executor/src/syscalls/mod.rs index b36f74ba7..d21ecdf03 100644 --- a/crates/core/executor/src/syscalls/mod.rs +++ b/crates/core/executor/src/syscalls/mod.rs @@ -22,6 +22,7 @@ pub use code::*; pub use context::*; use hint::{HintLenSyscall, HintReadSyscall}; use precompiles::{ + aes128::encrypt::AES128EncryptSyscall, edwards::{add::EdwardsAddAssignSyscall, decompress::EdwardsDecompressSyscall}, fptower::{Fp2AddSubSyscall, Fp2MulSyscall, FpOpSyscall}, keccak::sponge::KeccakSpongeSyscall, @@ -102,6 +103,8 @@ pub fn default_syscall_map() -> HashMap> { syscall_map.insert(SyscallCode::POSEIDON2_PERMUTE, Arc::new(Poseidon2PermuteSyscall)); + syscall_map.insert(SyscallCode::AES128_ENCRYPT, Arc::new(AES128EncryptSyscall)); + syscall_map.insert(SyscallCode::KECCAK_SPONGE, Arc::new(KeccakSpongeSyscall)); syscall_map.insert( diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs new file mode 100644 index 000000000..09d1215da --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/aes128/encrypt.rs @@ -0,0 +1,196 @@ +use crate::events::{AES128EncryptEvent, PrecompileEvent, AES_128_BLOCK_U32S}; +use crate::syscalls::precompiles::aes128::utils::mul_md5; +use crate::syscalls::{Syscall, SyscallCode, SyscallContext}; + +pub(crate) struct AES128EncryptSyscall; + +pub const AES128_RCON: [[u8; 4]; 10] = [ + [0x01, 0x00, 0x00, 0x00], + [0x02, 0x00, 0x00, 0x00], + [0x04, 0x00, 0x00, 0x00], + [0x08, 0x00, 
0x00, 0x00], + [0x10, 0x00, 0x00, 0x00], + [0x20, 0x00, 0x00, 0x00], + [0x40, 0x00, 0x00, 0x00], + [0x80, 0x00, 0x00, 0x00], + [0x1B, 0x00, 0x00, 0x00], + [0x36, 0x00, 0x00, 0x00], +]; + +pub const AES_SBOX: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + +impl Syscall for AES128EncryptSyscall { + fn num_extra_cycles(&self) -> u32 { + 1 + } + fn execute( + &self, + rt: &mut SyscallContext, + syscall_code: SyscallCode, + arg1: u32, + arg2: u32, + ) -> 
Option { + let start_clk = rt.clk; + let block_ptr = arg1; + let key_ptr = arg2; + + let mut input_read_records = Vec::new(); + let mut key_read_records = Vec::new(); + let mut output_write_records = Vec::new(); + + let mut input = Vec::new(); + let mut key_u32s = Vec::new(); + let mut state = Vec::new(); + let mut key = Vec::new(); + let mut output = Vec::new(); + + // read block input + for i in 0..AES_128_BLOCK_U32S { + let (record, value) = rt.mr(block_ptr + i as u32 * 4); + input_read_records.push(record); + input.push(value); + state.extend(value.to_le_bytes()); + } + + // read key + for i in 0..AES_128_BLOCK_U32S { + let (record, value) = rt.mr(key_ptr + i as u32 * 4); + key_read_records.push(record); + key_u32s.push(value); + key.extend(value.to_le_bytes()); + } + + // // Add Roundkey, Round 0 + for i in 0..state.len() { + state[i] ^= key[i]; + } + + // perform AES + let mut round_key = key; + for i in 1..11 { + // compute round key + Self::compute_round_key(&mut round_key, i - 1); + + // Subs_bytes + for j in 0..state.len() { + let value = AES_SBOX[state[j] as usize]; + state[j] = value; + } + + // Shift row + let shift_row = [ + state[0], state[5], state[10], state[15], state[4], state[9], state[14], state[3], + state[8], state[13], state[2], state[7], state[12], state[1], state[6], state[11], + ] + .to_vec(); + + // Mix columns + let mix_columns = if i != 10 { + let mut mixed = shift_row.clone(); + for col in 0..4 { + let col_start = col * 4; + let s0 = shift_row[col_start]; + let s1 = shift_row[col_start + 1]; + let s2 = shift_row[col_start + 2]; + let s3 = shift_row[col_start + 3]; + mixed[col_start] = + mul_md5(s0, 2) ^ mul_md5(s1, 3) ^ mul_md5(s2, 1) ^ mul_md5(s3, 1); + mixed[col_start + 1] = + mul_md5(s0, 1) ^ mul_md5(s1, 2) ^ mul_md5(s2, 3) ^ mul_md5(s3, 1); + mixed[col_start + 2] = + mul_md5(s0, 1) ^ mul_md5(s1, 1) ^ mul_md5(s2, 2) ^ mul_md5(s3, 3); + mixed[col_start + 3] = + mul_md5(s0, 3) ^ mul_md5(s1, 1) ^ mul_md5(s2, 1) ^ mul_md5(s3, 2); + } + 
mixed + } else { + shift_row + }; + + // Add round key + for j in 0..state.len() { + state[j] = mix_columns[j] ^ round_key[j]; + } + } + + // write output + // Increment the clk by 1 before writing because we read from memory at start_clk. + rt.clk += 1; + assert_eq!(state.len(), 16); + for chunk in state.chunks(4) { + let value = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); + output.push(value); + } + let write_records = rt.mw_slice(block_ptr, output.as_slice()); + output_write_records.extend_from_slice(&write_records); + + // Push the AES128 encrypt event. + let shard = rt.current_shard(); + let aes128_event = PrecompileEvent::Aes128Encrypt(AES128EncryptEvent { + shard, + clk: start_clk, + block_addr: block_ptr, + key_addr: key_ptr, + input: input.as_slice().try_into().unwrap(), + key: key_u32s.as_slice().try_into().unwrap(), + output: output.as_slice().try_into().unwrap(), + input_read_records: input_read_records.as_slice().try_into().unwrap(), + key_read_records: key_read_records.as_slice().try_into().unwrap(), + output_write_records: output_write_records.as_slice().try_into().unwrap(), + local_mem_access: rt.postprocess(), + }); + let aes128_syscall_event = + rt.rt.syscall_event(start_clk, None, rt.next_pc, syscall_code.syscall_id(), arg1, arg2); + rt.add_precompile_event(syscall_code, aes128_syscall_event, aes128_event); + None + } +} + +impl AES128EncryptSyscall { + fn compute_round_key(previous_key: &mut [u8], round: usize) { + if previous_key.len() != 16 { + panic!("AES128: wrong previous key length"); + } + // First 4 bytes + let g_w3 = { + let mut result = + [previous_key[13], previous_key[14], previous_key[15], previous_key[12]]; + for (i, rcon) in AES128_RCON[round].iter().enumerate() { + let value = AES_SBOX[result[i] as usize]; + result[i] = value ^ rcon; + } + result + }; + + let prev = previous_key.to_vec().clone(); + for i in 0..4 { + let w = if i == 0 { + prev[0..4].iter().zip(g_w3.iter()).map(|(&a, &b)| a ^ b).collect::>() + } 
else { + prev[i * 4..(i + 1) * 4] + .iter() + .zip(previous_key[(i - 1) * 4..i * 4].iter()) + .map(|(&a, &b)| a ^ b) + .collect::>() + }; + previous_key[i * 4..(i + 1) * 4].copy_from_slice(&w); + } + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs new file mode 100644 index 000000000..46178d5e0 --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/aes128/mod.rs @@ -0,0 +1,2 @@ +pub mod encrypt; +pub mod utils; diff --git a/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs b/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs new file mode 100644 index 000000000..86858f98e --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/aes128/utils.rs @@ -0,0 +1,18 @@ +/// Multiply a byte by 2 in GF(2^8) modulo the AES polynomial 0x11B (x^8 + x^4 + x^3 + x + 1) +fn xtime(x: u8) -> u8 { + if x & 0x80 != 0 { + (x << 1) ^ 0x1b + } else { + x << 1 + } +} + +/// Multiply a byte by 1, 2, or 3 in GF(2^8) +pub fn mul_md5(x: u8, by: u8) -> u8 { + match by { + 1 => x, + 2 => xtime(x), + 3 => xtime(x) ^ x, // 3*x = (2*x) ⊕ x + _ => panic!("Only supports multipliers 1, 2, or 3"), + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/mod.rs b/crates/core/executor/src/syscalls/precompiles/mod.rs index cc50a28f4..06eea78ab 100644 --- a/crates/core/executor/src/syscalls/precompiles/mod.rs +++ b/crates/core/executor/src/syscalls/precompiles/mod.rs @@ -1,3 +1,4 @@ +pub mod aes128; pub mod edwards; pub mod fptower; pub mod keccak; diff --git a/crates/core/machine/src/mips/mod.rs b/crates/core/machine/src/mips/mod.rs index bd99f2f0b..b913871ba 100644 --- a/crates/core/machine/src/mips/mod.rs +++ b/crates/core/machine/src/mips/mod.rs @@ -2,6 +2,7 @@ use crate::{ global::GlobalChip, memory::{MemoryChipType, MemoryLocalChip, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW}, syscall::precompiles::{ + aes128_encrypt::AES128EncryptChip, fptower::{Fp2AddSubAssignChip, Fp2MulAssignChip, FpOpChip}, 
poseidon2::Poseidon2PermuteChip, }, @@ -141,6 +142,8 @@ pub enum MipsAir { Secp256r1Double(WeierstrassDoubleAssignChip>), /// A precompile for the Poseidon2 permutation Poseidon2Permute(Poseidon2PermuteChip), + /// A precompile for AES-128 encryption + Aes128Encrypt(AES128EncryptChip), /// A precompile for the Keccak Sponge KeccakSponge(KeccakSpongeChip), /// A precompile for addition on the Elliptic curve bn254. @@ -272,6 +275,10 @@ impl MipsAir { costs.insert(poseidon2_permute.name(), poseidon2_permute.cost()); chips.push(poseidon2_permute); + let aes128_encrypt = Chip::new(MipsAir::Aes128Encrypt(AES128EncryptChip::new())); + costs.insert(aes128_encrypt.name(), 11 * aes128_encrypt.cost()); + chips.push(aes128_encrypt); + let keccak_sponge = Chip::new(MipsAir::KeccakSponge(KeccakSpongeChip::new())); costs.insert(keccak_sponge.name(), 24 * keccak_sponge.cost()); chips.push(keccak_sponge); @@ -574,6 +581,7 @@ impl MipsAir { Self::Sha256Compress(_) => 80, Self::Sha256Extend(_) => 48, Self::KeccakSponge(_) => 24, + Self::Aes128Encrypt(_) => 11, _ => 1, } } @@ -623,6 +631,7 @@ impl MipsAir { Self::Bls12381Fp2Mul(_) => SyscallCode::BLS12381_FP2_MUL, Self::Bls12381Fp2AddSub(_) => SyscallCode::BLS12381_FP2_ADD, Self::Poseidon2Permute(_) => SyscallCode::POSEIDON2_PERMUTE, + Self::Aes128Encrypt(_) => SyscallCode::AES128_ENCRYPT, Self::KeccakSponge(_) => SyscallCode::KECCAK_SPONGE, Self::SysLinux(_) => SyscallCode::SYS_LINUX, Self::Add(_) => unreachable!("Invalid for core chip"), diff --git a/crates/core/machine/src/operations/aes/aes_mul2.rs b/crates/core/machine/src/operations/aes/aes_mul2.rs new file mode 100644 index 000000000..591c5a546 --- /dev/null +++ b/crates/core/machine/src/operations/aes/aes_mul2.rs @@ -0,0 +1,102 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord}, + ByteOpcode, +}; +use zkm_derive::AlignedBorrow; +use zkm_stark::air::ZKMAirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] 
+#[repr(C)] +pub struct MulBy2InAES { + pub and_0x80: T, + pub left_shift_1: T, + pub is_xor: T, // 0 or 1 + pub xor_0x1b: T, // also the result +} + +impl MulBy2InAES { + pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8) -> u8 { + let and_0x80 = x & 0x80; + let left_shift_1 = x << 1; + let mut is_xor = 0_u8; + let xor_0x1b = if and_0x80 != 0 { + is_xor = 1; + left_shift_1 ^ 0x1b + } else { + left_shift_1 + }; + + self.and_0x80 = F::from_canonical_u8(and_0x80); + self.left_shift_1 = F::from_canonical_u8(left_shift_1); + self.xor_0x1b = F::from_canonical_u8(xor_0x1b); + self.is_xor = F::from_canonical_u8(is_xor); + + // Byte lookup events + let byte_event_and = + ByteLookupEvent { opcode: ByteOpcode::AND, a1: and_0x80 as u16, a2: 0, b: x, c: 0x80 }; + record.add_byte_lookup_event(byte_event_and); + + let byte_event_ssl = + ByteLookupEvent { opcode: ByteOpcode::SLL, a1: left_shift_1 as u16, a2: 0, b: x, c: 1 }; + record.add_byte_lookup_event(byte_event_ssl); + + if is_xor == 1 { + let byte_event_xor = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: xor_0x1b as u16, + a2: 0, + b: left_shift_1, + c: 0x1b, + }; + record.add_byte_lookup_event(byte_event_xor); + } + + xor_0x1b + } + + #[allow(unused_variables)] + pub fn eval( + builder: &mut AB, + x: AB::Var, + cols: MulBy2InAES, + is_real: AB::Var, + ) { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::AND as u32), + cols.and_0x80, + x, + AB::F::from_canonical_u32(0x80), + is_real, + ); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::SLL as u32), + cols.left_shift_1, + x, + AB::F::from_canonical_u32(1), + is_real, + ); + + builder.assert_bool(cols.is_xor); + // if cols.is_xor == 1, then and_0x80 == 128, else and_0x80 = 0 + builder.assert_eq(cols.and_0x80, cols.is_xor * AB::Expr::from_canonical_u8(128u8)); + + builder.assert_eq((AB::Expr::ONE - is_real.into()) * cols.is_xor, AB::Expr::ZERO); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + 
cols.xor_0x1b, + cols.left_shift_1, + AB::F::from_canonical_u32(0x1b), + cols.is_xor, + ); + + // The XOR lookup above is sent with `is_xor` as its multiplicity, so when `is_xor == 0` + // it constrains nothing: tie the result column to the shifted value directly. + builder.assert_eq( + (AB::Expr::ONE - cols.is_xor) * (cols.xor_0x1b - cols.left_shift_1), + AB::Expr::ZERO, + ); + } +} diff --git a/crates/core/machine/src/operations/aes/aes_mul3.rs b/crates/core/machine/src/operations/aes/aes_mul3.rs new file mode 100644 index 000000000..10caabbd8 --- /dev/null +++ b/crates/core/machine/src/operations/aes/aes_mul3.rs @@ -0,0 +1,47 @@ +use crate::operations::aes_mul2::MulBy2InAES; +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord}, + ByteOpcode, +}; +use zkm_derive::AlignedBorrow; +use zkm_stark::air::ZKMAirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct MulBy3InAES { + pub mul_by_2: MulBy2InAES, + pub xor_x: T, // also the result +} + +impl MulBy3InAES { + pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8) -> u8 { + let x2 = self.mul_by_2.populate(record, x); + let xor_x = x2 ^ x; + self.xor_x = F::from_canonical_u8(xor_x); + + // Byte lookup event for the final XOR + let byte_event_xor = + ByteLookupEvent { opcode: ByteOpcode::XOR, a1: xor_x as u16, a2: 0, b: x2, c: x }; + record.add_byte_lookup_event(byte_event_xor); + xor_x + } + + #[allow(unused_variables)] + pub fn eval( + builder: &mut AB, + x: AB::Var, + cols: MulBy3InAES, + is_real: AB::Var, + ) { + MulBy2InAES::::eval(builder, x, cols.mul_by_2, is_real); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.xor_x, + cols.mul_by_2.xor_0x1b, + x, + is_real, + ); + } +} diff --git a/crates/core/machine/src/operations/aes/mix_column.rs b/crates/core/machine/src/operations/aes/mix_column.rs new file mode 100644 index 000000000..0fae83ec8 --- /dev/null +++ 
b/crates/core/machine/src/operations/aes/mix_column.rs @@ -0,0 +1,132 @@ +use crate::operations::aes::xor_byte_4::XorByte4; +use crate::operations::aes_mul2::MulBy2InAES; +use crate::operations::aes_mul3::MulBy3InAES; +use p3_field::Field; +use zkm_core_executor::events::ByteRecord; +use 
zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct MixColumn { + pub mul_by_2s: [MulBy2InAES; 16], + pub mul_by_3s: [MulBy3InAES; 16], + pub xor_byte4s: [XorByte4; 16], +} + +impl MixColumn { + pub fn populate(&mut self, record: &mut impl ByteRecord, shifted_state: &[u8; 16]) -> [u8; 16] { + let mut mixed = [0u8; 16]; + for col in 0..4 { + let col_start = col * 4; + let s0 = shifted_state[col_start]; + let s1 = shifted_state[col_start + 1]; + let s2 = shifted_state[col_start + 2]; + let s3 = shifted_state[col_start + 3]; + + // mixed[col_start] = mul_md5(s0, 2) ^ mul_md5(s1, 3) ^ mul_md5(s2, 1) ^ mul_md5(s3, 1); + mixed[col_start] = { + let s0x2 = self.mul_by_2s[col_start].populate(record, s0); + let s1x3 = self.mul_by_3s[col_start].populate(record, s1); + self.xor_byte4s[col_start].populate(record, s0x2, s1x3, s2, s3) + }; + + // mixed[col_start + 1] = mul_md5(s0, 1) ^ mul_md5(s1, 2) ^ mul_md5(s2, 3) ^ mul_md5(s3, 1); + mixed[col_start + 1] = { + let s1x2 = self.mul_by_2s[col_start + 1].populate(record, s1); + let s2x3 = self.mul_by_3s[col_start + 1].populate(record, s2); + self.xor_byte4s[col_start + 1].populate(record, s0, s1x2, s2x3, s3) + }; + + // mixed[col_start + 2] = mul_md5(s0, 1) ^ mul_md5(s1, 1) ^ mul_md5(s2, 2) ^ mul_md5(s3, 3); + mixed[col_start + 2] = { + let s2x2 = self.mul_by_2s[col_start + 2].populate(record, s2); + let s3x3 = self.mul_by_3s[col_start + 2].populate(record, s3); + self.xor_byte4s[col_start + 2].populate(record, s0, s1, s2x2, s3x3) + }; + + // mixed[col_start + 3] = mul_md5(s0, 3) ^ mul_md5(s1, 1) ^ mul_md5(s2, 1) ^ mul_md5(s3, 2); + mixed[col_start + 3] = { + let s0x3 = self.mul_by_3s[col_start + 3].populate(record, s0); + let s3x2 = self.mul_by_2s[col_start + 3].populate(record, s3); + self.xor_byte4s[col_start + 3].populate(record, s0x3, s1, s2, s3x2) + } + } + mixed + } + + pub fn eval( + builder: &mut AB, + shifted_state: [AB::Var; 16], + 
cols: MixColumn, + is_real: AB::Var, + ) { + for col in 0..4 { + let col_start = col * 4; + let s0 = shifted_state[col_start]; + let s1 = shifted_state[col_start + 1]; + let s2 = shifted_state[col_start + 2]; + let s3 = shifted_state[col_start + 3]; + + // col_start + { + MulBy2InAES::::eval(builder, s0, cols.mul_by_2s[col_start], is_real); + MulBy3InAES::::eval(builder, s1, cols.mul_by_3s[col_start], is_real); + XorByte4::::eval( + builder, + cols.mul_by_2s[col_start].xor_0x1b, + cols.mul_by_3s[col_start].xor_x, + s2, + s3, + cols.xor_byte4s[col_start], + is_real, + ); + } + + // col_start + 1 + { + MulBy2InAES::::eval(builder, s1, cols.mul_by_2s[col_start + 1], is_real); + MulBy3InAES::::eval(builder, s2, cols.mul_by_3s[col_start + 1], is_real); + XorByte4::::eval( + builder, + s0, + cols.mul_by_2s[col_start + 1].xor_0x1b, + cols.mul_by_3s[col_start + 1].xor_x, + s3, + cols.xor_byte4s[col_start + 1], + is_real, + ) + } + + // col_start + 2 + { + MulBy2InAES::::eval(builder, s2, cols.mul_by_2s[col_start + 2], is_real); + MulBy3InAES::::eval(builder, s3, cols.mul_by_3s[col_start + 2], is_real); + XorByte4::::eval( + builder, + s0, + s1, + cols.mul_by_2s[col_start + 2].xor_0x1b, + cols.mul_by_3s[col_start + 2].xor_x, + cols.xor_byte4s[col_start + 2], + is_real, + ) + } + + // col_start + 3 + { + MulBy3InAES::::eval(builder, s0, cols.mul_by_3s[col_start + 3], is_real); + MulBy2InAES::::eval(builder, s3, cols.mul_by_2s[col_start + 3], is_real); + XorByte4::::eval( + builder, + cols.mul_by_3s[col_start + 3].xor_x, + s1, + s2, + cols.mul_by_2s[col_start + 3].xor_0x1b, + cols.xor_byte4s[col_start + 3], + is_real, + ) + } + } + } +} diff --git a/crates/core/machine/src/operations/aes/mod.rs b/crates/core/machine/src/operations/aes/mod.rs new file mode 100644 index 000000000..1ee16f771 --- /dev/null +++ b/crates/core/machine/src/operations/aes/mod.rs @@ -0,0 +1,6 @@ +pub mod aes_mul2; +pub mod aes_mul3; +pub mod mix_column; +pub mod round_key; +pub mod subs_byte; +pub mod 
xor_byte_4; diff --git a/crates/core/machine/src/operations/aes/round_key.rs b/crates/core/machine/src/operations/aes/round_key.rs new file mode 100644 index 000000000..113b2650f --- /dev/null +++ b/crates/core/machine/src/operations/aes/round_key.rs @@ -0,0 +1,196 @@ +use crate::operations::subs_byte::SubsByte; +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord}; +use zkm_core_executor::ByteOpcode; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; + +pub const ROUND_CONST: [u8; 11] = + [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, 0x00]; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct NextRoundKey { + pub add_round_const: T, // XOR + pub w3_subs_byte: [SubsByte; 4], // for round key + // new round key + pub w4: [T; 4], + pub w5: [T; 4], + pub w6: [T; 4], + pub w7: [T; 4], +} + +impl NextRoundKey { + pub fn populate( + &mut self, + records: &mut impl ByteRecord, + prev_round_key: &[u8; 16], + round: u8, + ) -> [u8; 16] { + // check sbox values + let mut sub_rot_w3 = [0u8; 4]; + let shifted_w3 = + [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; + for i in 0..4 { + sub_rot_w3[i] = self.w3_subs_byte[i].populate(shifted_w3[i]); + } + + // previous round key + let w0 = &prev_round_key[0..4]; + let w1 = &prev_round_key[4..8]; + let w2 = &prev_round_key[8..12]; + let w3 = &prev_round_key[12..16]; + + let rcon = ROUND_CONST[round as usize]; + let first_byte = sub_rot_w3[0] ^ rcon; + self.add_round_const = F::from_canonical_u8(first_byte); + let first_byte_xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: first_byte as u16, + a2: 0, + b: sub_rot_w3[0], + c: rcon, + }; + records.add_byte_lookup_event(first_byte_xor_event); + // add constant + sub_rot_w3[0] = first_byte; + + // Compute new words + let mut new_key = [0u8; 16]; + // w4 = w0 ^ SubWord(RotWord(w3)) + for i in 0..4 { + new_key[i] = w0[i] ^ sub_rot_w3[i]; + 
self.w4[i] = F::from_canonical_u8(new_key[i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[i] as u16, + a2: 0, + b: w0[i], + c: sub_rot_w3[i], + }; + records.add_byte_lookup_event(xor_event); + } + + // w5 = w4 ^ w1 + for i in 0..4 { + new_key[4 + i] = new_key[i] ^ w1[i]; + self.w5[i] = F::from_canonical_u8(new_key[4 + i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[4 + i] as u16, + a2: 0, + b: new_key[i], + c: w1[i], + }; + records.add_byte_lookup_event(xor_event); + } + + // w6 = w5 ^ w2 + for i in 0..4 { + new_key[8 + i] = new_key[4 + i] ^ w2[i]; + self.w6[i] = F::from_canonical_u8(new_key[8 + i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[8 + i] as u16, + a2: 0, + b: new_key[4 + i], + c: w2[i], + }; + records.add_byte_lookup_event(xor_event); + } + + // w7 = w6 ^ w3 + for i in 0..4 { + new_key[12 + i] = new_key[8 + i] ^ w3[i]; + self.w7[i] = F::from_canonical_u8(new_key[12 + i]); + let xor_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: new_key[12 + i] as u16, + a2: 0, + b: new_key[8 + i], + c: w3[i], + }; + records.add_byte_lookup_event(xor_event); + } + + new_key + } + + pub fn eval( + builder: &mut AB, + cols: NextRoundKey, + prev_round_key: [AB::Var; 16], + rcon: AB::Var, + is_real: AB::Expr, + ) { + let w0 = &prev_round_key[0..4]; + let w1 = &prev_round_key[4..8]; + let w2 = &prev_round_key[8..12]; + let w3 = &prev_round_key[12..16]; + + let shifted_w3 = + [prev_round_key[13], prev_round_key[14], prev_round_key[15], prev_round_key[12]]; + for i in 0..4 { + SubsByte::::eval(builder, cols.w3_subs_byte[i], shifted_w3[i], is_real.clone()); + } + + // sbox substitution bytes. 
+ let subs_byte_values: [AB::Var; 4] = cols.w3_subs_byte.map(|m| m.value); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.add_round_const, + subs_byte_values[0], + rcon, + is_real.clone(), + ); + + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w4[0], + w0[0], + cols.add_round_const, + is_real.clone(), + ); + + for i in 1..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w4[i], + w0[i], + subs_byte_values[i], + is_real.clone(), + ) + } + + for i in 0..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w5[i], + cols.w4[i], + w1[i], + is_real.clone(), + ) + } + + for i in 0..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w6[i], + cols.w5[i], + w2[i], + is_real.clone(), + ) + } + + for i in 0..4 { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.w7[i], + cols.w6[i], + w3[i], + is_real.clone(), + ) + } + } +} diff --git a/crates/core/machine/src/operations/aes/subs_byte.rs b/crates/core/machine/src/operations/aes/subs_byte.rs new file mode 100644 index 000000000..27509603c --- /dev/null +++ b/crates/core/machine/src/operations/aes/subs_byte.rs @@ -0,0 +1,115 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; +pub const AES_SBOX: [u8; 256] = [ + 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, + 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, + 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, + 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, + 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, + 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 
0xcf, + 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, + 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, + 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, + 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, +]; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct SubsByte { + pub positions: [[T; 32]; 4], + // if this byte is in range [0..=127] + pub is_left: T, + // the substituted byte + pub value: T, +} + +impl SubsByte { + pub fn populate(&mut self, byte: u8) -> u8 { + let mut pos = byte; + if byte > 127 { + self.is_left = F::ZERO; + pos -= 128; + } else { + self.is_left = F::ONE; + } + + for i in 0..128 { + let row = i / 32; + let col = i % 32; + if i as u8 == pos { + self.positions[row][col] = F::ONE; + } else { + self.positions[row][col] = F::ZERO; + } + } + + let substituted = AES_SBOX[byte as usize]; + self.value = F::from_canonical_u8(substituted); + substituted + } + + pub fn eval( + builder: &mut AB, + cols: SubsByte, + byte: AB::Var, + is_real: AB::Expr, + ) { + builder.assert_bool(cols.is_left); + builder.assert_bool(is_real.clone()); + // if is_real = 0 then is_left must be 0 + builder.assert_eq((AB::Expr::ONE - is_real.clone()) * cols.is_left, 
AB::Expr::ZERO); + + // exactly one position is selected + let sum_positions = cols + .positions + .iter() + .map(|&pos| pos.iter().map(|&p| p.into()).sum::()) + .sum::(); + builder.assert_eq(sum_positions, is_real.clone()); + + for i in 0..128 { + // positions are boolean + let row = i / 32; + let col = i % 32; + builder.assert_bool(cols.positions[row][col]); + builder.assert_eq( + (AB::Expr::ONE - is_real.clone()) * cols.positions[row][col], + AB::Expr::ZERO, + ); + // if is_left = 1 then byte = i else byte = i+128 + builder.assert_eq( + cols.is_left + * (AB::Expr::from_canonical_usize(i) - byte) + * cols.positions[row][col], + AB::Expr::ZERO, + ); + builder.assert_eq( + (AB::Expr::ONE - cols.is_left) + * (AB::Expr::from_canonical_usize(i + 128) - byte) + * cols.positions[row][col], + AB::Expr::ZERO, + ); + + // value = SBOX[byte] + builder.assert_eq( + cols.is_left + * (AB::Expr::from_canonical_u8(AES_SBOX[i]) - cols.value) + * cols.positions[row][col], + AB::Expr::ZERO, + ); + builder.assert_eq( + (AB::Expr::ONE - cols.is_left) + * (AB::Expr::from_canonical_u8(AES_SBOX[i + 128]) - cols.value) + * cols.positions[row][col], + AB::Expr::ZERO, + ); + } + } +} diff --git a/crates/core/machine/src/operations/aes/xor_byte_4.rs b/crates/core/machine/src/operations/aes/xor_byte_4.rs new file mode 100644 index 000000000..f1b18bc95 --- /dev/null +++ b/crates/core/machine/src/operations/aes/xor_byte_4.rs @@ -0,0 +1,79 @@ +use p3_field::{Field, FieldAlgebra}; +use zkm_core_executor::events::{ByteLookupEvent, ByteRecord}; +use zkm_core_executor::ByteOpcode; +use zkm_derive::AlignedBorrow; +use zkm_stark::ZKMAirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +// x ^ y ^ z ^ w +pub struct XorByte4 { + pub interm1: T, + pub interm2: T, + pub value: T, +} + +impl XorByte4 { + pub fn populate(&mut self, record: &mut impl ByteRecord, x: u8, y: u8, z: u8, w: u8) -> u8 { + let xor_inter1 = x ^ y; + self.interm1 = F::from_canonical_u8(xor_inter1); + let 
byte_event = + ByteLookupEvent { opcode: ByteOpcode::XOR, a1: xor_inter1 as u16, a2: 0, b: x, c: y }; + record.add_byte_lookup_event(byte_event); + + let xor_inter2 = xor_inter1 ^ z; + self.interm2 = F::from_canonical_u8(xor_inter2); + let byte_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: xor_inter2 as u16, + a2: 0, + b: xor_inter1, + c: z, + }; + record.add_byte_lookup_event(byte_event); + + let result = xor_inter2 ^ w; + self.value = F::from_canonical_u8(result); + let byte_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: result as u16, + a2: 0, + b: xor_inter2, + c: w, + }; + record.add_byte_lookup_event(byte_event); + result + } + + pub fn eval( + builder: &mut AB, + x: AB::Var, + y: AB::Var, + z: AB::Var, + w: AB::Var, + cols: XorByte4, + is_real: AB::Var, + ) { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.interm1, + x, + y, + is_real, + ); + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.interm2, + cols.interm1, + z, + is_real, + ); + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + cols.value, + cols.interm2, + w, + is_real, + ); + } +} diff --git a/crates/core/machine/src/operations/mod.rs b/crates/core/machine/src/operations/mod.rs index 31a9537c7..667aed4de 100644 --- a/crates/core/machine/src/operations/mod.rs +++ b/crates/core/machine/src/operations/mod.rs @@ -8,6 +8,7 @@ mod add; mod add4; mod add5; mod adddouble; +mod aes; mod and; mod cmp; pub mod field; @@ -24,11 +25,11 @@ mod not; mod or; pub mod poseidon2; mod xor; - pub use add::*; pub use add4::*; pub use add5::*; pub use adddouble::*; +pub use aes::*; pub use and::*; pub use cmp::*; pub use fixed_rotate_right::*; diff --git a/crates/core/machine/src/runtime/syscall.rs b/crates/core/machine/src/runtime/syscall.rs index 902697912..8ac1b85ea 100644 --- a/crates/core/machine/src/runtime/syscall.rs +++ b/crates/core/machine/src/runtime/syscall.rs @@ -157,6 +157,9 @@ pub enum 
SyscallCode { /// Executes the `POSEIDON2_PERMUTE` precompile. POSEIDON2_PERMUTE = 0x00_00_01_30, + + /// Executes the `AES128_ENCRYPT` precompile. + AES128_ENCRYPT = 0x00_01_01_31, } impl SyscallCode { @@ -201,6 +204,7 @@ impl SyscallCode { 0x00_01_01_29 => SyscallCode::BN254_FP2_ADD, 0x00_01_01_2A => SyscallCode::BN254_FP2_SUB, 0x00_01_01_2B => SyscallCode::BN254_FP2_MUL, + 0x00_01_01_31 => SyscallCode::AES128_ENCRYPT, 0x00_00_01_1C => SyscallCode::BLS12381_DECOMPRESS, 0x00_00_01_30 => SyscallCode::POSEIDON2_PERMUTE, _ => panic!("invalid syscall number: {}", value), diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs new file mode 100644 index 000000000..fa968a6eb --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/air.rs @@ -0,0 +1,301 @@ +use crate::air::{MemoryAirBuilder, WordAirBuilder}; +use crate::memory::MemoryCols; +use crate::operations::mix_column::MixColumn; +use crate::operations::round_key::{NextRoundKey, ROUND_CONST}; +use crate::operations::subs_byte::SubsByte; +use crate::syscall::precompiles::aes128_encrypt::columns::{ + AES128EncryptionCols, NUM_AES128_ENCRYPTION_COLS, +}; +use crate::syscall::precompiles::aes128_encrypt::AES128EncryptChip; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::FieldAlgebra; +use p3_matrix::Matrix; +use std::borrow::Borrow; +use zkm_core_executor::events::AES_128_BLOCK_BYTES; +use zkm_core_executor::syscalls::SyscallCode; +use zkm_core_executor::ByteOpcode; +use zkm_stark::{LookupScope, ZKMAirBuilder}; + +impl BaseAir for AES128EncryptChip { + fn width(&self) -> usize { + NUM_AES128_ENCRYPTION_COLS + } +} + +impl Air for AES128EncryptChip +where + AB: ZKMAirBuilder, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local: &AES128EncryptionCols = (*local).borrow(); + let next: &AES128EncryptionCols = 
(*next).borrow(); + + builder.receive_syscall( + local.shard, + local.clk, + AB::F::from_canonical_u32(SyscallCode::AES128_ENCRYPT.syscall_id()), + local.block_address, + local.key_address, + local.receive_syscall, + LookupScope::Local, + ); + + self.eval_flags(builder, local); + self.eval_memory_access(builder, local); + self.eval_subs_byte(builder, local); + self.eval_mix_column(builder, local); + self.eval_add_round_key(builder, local); + self.eval_compute_round_key(builder, local); + self.eval_input_output(builder, local); + self.eval_transition(builder, local, next); + } +} + +impl AES128EncryptChip { + fn eval_flags( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let first_round = local.round[0]; + let last_round = local.round[10]; + for i in 0..11 { + builder.assert_bool(local.round[i]); + } + builder.assert_bool(local.round_1to9); + let mut computed_1to9 = AB::Expr::ZERO; + for i in 1..10 { + computed_1to9 = computed_1to9 + local.round[i]; + } + builder.assert_eq(computed_1to9, local.round_1to9); + builder.assert_eq(first_round * local.is_real, local.receive_syscall); + builder.assert_eq(last_round + first_round + local.round_1to9, local.is_real); + } + + fn eval_memory_access( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let mut round = AB::Expr::ZERO; + for i in 0..11 { + round = round + local.round[i] * AB::F::from_canonical_u32(i as u32); + } + + // if this is the first row, populate reading key + for i in 0..4 { + builder.eval_memory_access( + local.shard, + local.clk, + local.key_address + AB::F::from_canonical_u32(i as u32 * 4), + &local.key[i], + local.round[0], + ); + } + + // if this is the first row, populate reading input + for i in 0..4 { + builder.eval_memory_access( + local.shard, + local.clk, + local.block_address + AB::F::from_canonical_u32(i as u32 * 4), + &local.block[i], + local.round[0], + ); + } + + // if this is the last row, populate writing output + for i in 0..4 { + 
builder.eval_memory_access( + local.shard, + local.clk + AB::Expr::ONE, + local.block_address + AB::F::from_canonical_u32((i * 4) as u32), + &local.block[i], + local.round[10], + ); + } + } + + fn eval_subs_byte( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let round_1to10 = local.round_1to9 + local.round[10]; + for i in 0..16 { + SubsByte::::eval( + builder, + local.state_subs_byte[i], + local.state_matrix[i], + round_1to10.clone(), + ); + } + } + + fn eval_mix_column( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let shifted_state = [ + local.state_subs_byte[0].value, + local.state_subs_byte[5].value, + local.state_subs_byte[10].value, + local.state_subs_byte[15].value, + local.state_subs_byte[4].value, + local.state_subs_byte[9].value, + local.state_subs_byte[14].value, + local.state_subs_byte[3].value, + local.state_subs_byte[8].value, + local.state_subs_byte[13].value, + local.state_subs_byte[2].value, + local.state_subs_byte[7].value, + local.state_subs_byte[12].value, + local.state_subs_byte[1].value, + local.state_subs_byte[6].value, + local.state_subs_byte[11].value, + ]; + MixColumn::::eval(builder, shifted_state, local.mix_column, local.round_1to9); + } + + fn eval_add_round_key( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + for i in 0..AES_128_BLOCK_BYTES { + builder.send_byte( + AB::F::from_canonical_u32(ByteOpcode::XOR as u32), + local.add_round_key[i], + local.mix_column.xor_byte4s[i].value, + local.round_key_matrix[i], + local.is_real, + ) + } + } + + fn eval_compute_round_key( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + let round_0to9 = local.round_1to9 + local.round[0]; + NextRoundKey::::eval( + builder, + local.next_round_key, + local.round_key_matrix, + local.round_const, + round_0to9, + ); + + for i in 0..11 { + builder + .when(local.round[i]) + .assert_eq(local.round_const, AB::F::from_canonical_u8(ROUND_CONST[i])); + } + } + + fn eval_input_output( 
+ &self, + builder: &mut AB, + local: &AES128EncryptionCols, + ) { + // In round 0, all state matrix values and key values are in [0, 255] + for i in 0..AES_128_BLOCK_BYTES { + builder.send_byte( + AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8), + AB::Expr::ZERO, + AB::Expr::ZERO, + local.state_matrix[i], + local.round[0], + ); + + builder.send_byte( + AB::Expr::from_canonical_u8(ByteOpcode::U8Range as u8), + AB::Expr::ZERO, + AB::Expr::ZERO, + local.round_key_matrix[i], + local.round[0], + ); + } + + // In round 0, state and key matrix should be derived from cols.block and cols.key + for i in 0..4 { + for j in 0..4 { + let idx = i * 4 + j; + builder + .when(local.round[0]) + .assert_eq(local.state_matrix[idx], local.block[i].access.value[j]); + builder + .when(local.round[0]) + .assert_eq(local.round_key_matrix[idx], local.key[i].access.value[j]); + } + } + + // In round 1-9, block should remain the same + for i in 0..4 { + builder + .when(local.round_1to9) + .assert_word_eq(*local.block[i].prev_value(), *local.block[i].value()); + } + + // In round 10, output block should be derived from state matrix + for i in 0..4 { + for j in 0..4 { + let idx = i * 4 + j; + builder + .when(local.round[10]) + .assert_eq(local.block[i].access.value[j], local.add_round_key[idx]); + } + } + } + + fn eval_transition( + &self, + builder: &mut AB, + local: &AES128EncryptionCols, + next: &AES128EncryptionCols, + ) { + // if it's not the last round, shard, clk remain the same + let round_0to9 = local.round_1to9 + local.round[0]; + builder.when(round_0to9.clone()).assert_eq(next.shard, local.shard); + builder.when(round_0to9.clone()).assert_eq(next.clk, local.clk); + + // key_address, block_address, sbox_address remain the same + builder.when(round_0to9.clone()).assert_eq(next.key_address, local.key_address); + builder.when(round_0to9.clone()).assert_eq(next.block_address, local.block_address); + + // round transition + for i in 0..10 { + 
builder.when(round_0to9.clone()).assert_eq(local.round[i], next.round[i + 1]); + } + + // state transition + for i in 0..AES_128_BLOCK_BYTES { + builder + .when(round_0to9.clone()) + .assert_eq(local.add_round_key[i], next.state_matrix[i]); + } + + // round key transition + for i in 0..4 { + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w4[i], next.round_key_matrix[i]); + + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w5[i], next.round_key_matrix[i + 4]); + + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w6[i], next.round_key_matrix[i + 8]); + + builder + .when(round_0to9.clone()) + .assert_eq(local.next_round_key.w7[i], next.round_key_matrix[i + 12]); + } + } +} diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs new file mode 100644 index 000000000..4acf6c7a8 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/columns.rs @@ -0,0 +1,31 @@ +use crate::memory::{MemoryReadCols, MemoryReadWriteCols}; +use crate::operations::mix_column::MixColumn; +use crate::operations::round_key::NextRoundKey; +use crate::operations::subs_byte::SubsByte; +use zkm_derive::AlignedBorrow; + +/// AES128EncryptCols is the column layout for the AES128 encryption. +/// The number of rows equal to the number of block. 
+#[derive(AlignedBorrow)] +#[repr(C)] +pub struct AES128EncryptionCols { + pub shard: T, + pub clk: T, + pub is_real: T, + pub key_address: T, + pub block_address: T, + pub receive_syscall: T, + pub key: [MemoryReadCols; 4], + pub block: [MemoryReadWriteCols; 4], + pub round: [T; 11], // [0,..,10] + pub round_1to9: T, // 1 to 9 + pub round_const: T, + pub state_matrix: [T; 16], + pub round_key_matrix: [T; 16], + pub state_subs_byte: [SubsByte; 16], + pub next_round_key: NextRoundKey, + pub mix_column: MixColumn, + pub add_round_key: [T; 16], // result of this round +} + +pub const NUM_AES128_ENCRYPTION_COLS: usize = size_of::>(); diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs new file mode 100644 index 000000000..f7d54bc12 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/mod.rs @@ -0,0 +1,26 @@ +mod air; +mod columns; +mod trace; + +#[derive(Default)] +pub struct AES128EncryptChip; + +impl AES128EncryptChip { + pub const fn new() -> Self { + Self {} + } +} + +#[cfg(test)] +pub mod tests { + use crate::utils::{self, run_test}; + use test_artifacts::AES128_ENCRYPT_ELF; + use zkm_core_executor::Program; + use zkm_stark::CpuProver; + #[test] + fn test_aes128_encrypt_program_prove() { + utils::setup_logger(); + let program = Program::from(AES128_ENCRYPT_ELF).unwrap(); + run_test::>(program).unwrap(); + } +} diff --git a/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs new file mode 100644 index 000000000..b90ff14d2 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/aes128_encrypt/trace.rs @@ -0,0 +1,228 @@ +use std::borrow::BorrowMut; + +use super::{columns::NUM_AES128_ENCRYPTION_COLS, AES128EncryptChip}; +use crate::operations::round_key::ROUND_CONST; +use crate::syscall::precompiles::aes128_encrypt::columns::AES128EncryptionCols; +use 
hashbrown::HashMap; +use itertools::Itertools; +use p3_field::PrimeField32; +use p3_matrix::dense::RowMajorMatrix; +use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; +use zkm_core_executor::events::{ + AES128EncryptEvent, MemoryRecordEnum, AES_128_BLOCK_BYTES, AES_128_BLOCK_U32S, +}; +use zkm_core_executor::{ + events::{ByteLookupEvent, ByteRecord, PrecompileEvent}, + syscalls::SyscallCode, + ByteOpcode, ExecutionRecord, Program, +}; +use zkm_stark::air::MachineAir; + +impl MachineAir for AES128EncryptChip { + type Record = ExecutionRecord; + type Program = Program; + + fn name(&self) -> String { + "Aes128Encrypt".to_string() + } + + fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { + let events = input.get_precompile_events(SyscallCode::AES128_ENCRYPT); + let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1); + + let blu_batches = events + .par_chunks(chunk_size) + .map(|events| { + let mut blu: HashMap = HashMap::new(); + events.iter().for_each(|(_, event)| { + let event = if let PrecompileEvent::Aes128Encrypt(event) = event { + event + } else { + unreachable!(); + }; + + self.event_to_rows::(event, &mut None, &mut blu); + }); + blu + }) + .collect::>(); + + output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec()); + } + + fn generate_trace( + &self, + input: &Self::Record, + _output: &mut Self::Record, + ) -> RowMajorMatrix { + let rows = Vec::new(); + log::info!("generate trace"); + + let mut wrapped_rows = Some(rows); + for (_, event) in input.get_precompile_events(SyscallCode::AES128_ENCRYPT) { + let event = if let PrecompileEvent::Aes128Encrypt(event) = event { + event + } else { + unreachable!(); + }; + self.event_to_rows(event, &mut wrapped_rows, &mut Vec::new()); + } + let mut rows = wrapped_rows.unwrap(); + let num_real_rows = rows.len(); + let padded_num_rows = num_real_rows.next_power_of_two(); + for _ in num_real_rows..padded_num_rows { + let row = [F::ZERO; 
NUM_AES128_ENCRYPTION_COLS]; + rows.push(row); + } + RowMajorMatrix::new( + rows.into_iter().flatten().collect::>(), + NUM_AES128_ENCRYPTION_COLS, + ) + } + + fn included(&self, shard: &Self::Record) -> bool { + if let Some(shape) = shard.shape.as_ref() { + shape.included::(self) + } else { + !shard.get_precompile_events(SyscallCode::AES128_ENCRYPT).is_empty() + } + } +} + +impl AES128EncryptChip { + pub fn event_to_rows( + &self, + event: &AES128EncryptEvent, + rows: &mut Option>, + blu: &mut impl ByteRecord, + ) { + let num_round = 11; + let mut state = [0_u8; AES_128_BLOCK_BYTES]; + let mut round_key = [0_u8; AES_128_BLOCK_BYTES]; + for i in 0..AES_128_BLOCK_U32S { + state[i * 4..i * 4 + 4].copy_from_slice(&event.input[i].to_le_bytes()); + round_key[i * 4..i * 4 + 4].copy_from_slice(&event.key[i].to_le_bytes()); + } + for round in 0..num_round { + let mut row = [F::ZERO; NUM_AES128_ENCRYPTION_COLS]; + let cols: &mut AES128EncryptionCols = row.as_mut_slice().borrow_mut(); + cols.shard = F::from_canonical_u32(event.shard); + cols.clk = F::from_canonical_u32(event.clk); + cols.is_real = F::ONE; + cols.key_address = F::from_canonical_u32(event.key_addr); + cols.block_address = F::from_canonical_u32(event.block_addr); + cols.round = [F::ZERO; 11]; + cols.round[round] = F::ONE; + cols.receive_syscall = F::from_bool(round == 0); + cols.round_1to9 = F::from_bool((1..=9).contains(&round)); + cols.round_const = F::from_canonical_u8(ROUND_CONST[round]); + + for i in 0..AES_128_BLOCK_BYTES { + cols.state_matrix[i] = F::from_canonical_u8(state[i]); + cols.round_key_matrix[i] = F::from_canonical_u8(round_key[i]); + } + // all the state matrix values and key should be in [0, 255] in round 0 + if round == 0 { + for i in 0..AES_128_BLOCK_BYTES { + blu.add_u8_range_check(0, state[i]); + blu.add_u8_range_check(0, round_key[i]); + } + } + + if round == 0 { + // read the input + for i in 0..AES_128_BLOCK_U32S { + cols.block[i] + 
.populate(MemoryRecordEnum::Read(event.input_read_records[i]), blu); + } + // read the key + for i in 0..AES_128_BLOCK_U32S { + cols.key[i].populate(event.key_read_records[i], blu); + } + + // the mix column value should be the state + for i in 0..AES_128_BLOCK_BYTES { + cols.mix_column.xor_byte4s[i].value = F::from_canonical_u8(state[i]); + } + + // add_round_key + for i in 0..AES_128_BLOCK_BYTES { + let tmp = state[i] ^ round_key[i]; + cols.add_round_key[i] = F::from_canonical_u8(tmp); + let byte_lookup_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: tmp as u16, + a2: 0, + b: state[i], + c: round_key[i], + }; + blu.add_byte_lookup_event(byte_lookup_event); + state[i] = tmp; + } + } else { + // subs_bytes + for i in 0..AES_128_BLOCK_BYTES { + let subs_value = cols.state_subs_byte[i].populate(state[i]); + state[i] = subs_value; + } + + // shift_rows + let shifted_row = [ + state[0], state[5], state[10], state[15], state[4], state[9], state[14], + state[3], state[8], state[13], state[2], state[7], state[12], state[1], + state[6], state[11], + ]; + + // Mix columns + let mixed_columns = if round != 10 { + cols.mix_column.populate(blu, &shifted_row) + } else { + for i in 0..AES_128_BLOCK_BYTES { + cols.mix_column.xor_byte4s[i].value = F::from_canonical_u8(shifted_row[i]); + } + shifted_row + }; + + // Add round key + for i in 0..AES_128_BLOCK_BYTES { + state[i] = mixed_columns[i] ^ round_key[i]; + cols.add_round_key[i] = F::from_canonical_u8(state[i]); + let byte_lookup_event = ByteLookupEvent { + opcode: ByteOpcode::XOR, + a1: state[i] as u16, + a2: 0, + b: mixed_columns[i], + c: round_key[i], + }; + blu.add_byte_lookup_event(byte_lookup_event); + } + } + + if round != 10 { + // compute next round key + let next_round_key = cols.next_round_key.populate(blu, &round_key, round as u8); + round_key = next_round_key; + } else { + for i in 0..4 { + // check output + let tmp = event.output_write_records[i].value.to_le_bytes(); + for _ in 0..4 { + 
assert_eq!(state[i * 4], tmp[0]); + assert_eq!(state[i * 4 + 1], tmp[1]); + assert_eq!(state[i * 4 + 2], tmp[2]); + assert_eq!(state[i * 4 + 3], tmp[3]); + } + } + // write output + for i in 0..AES_128_BLOCK_U32S { + cols.block[i] + .populate(MemoryRecordEnum::Write(event.output_write_records[i]), blu); + } + } + + if rows.as_ref().is_some() { + rows.as_mut().unwrap().push(row); + } + } + } +} diff --git a/crates/core/machine/src/syscall/precompiles/mod.rs b/crates/core/machine/src/syscall/precompiles/mod.rs index 6a5c37835..e011c11fe 100644 --- a/crates/core/machine/src/syscall/precompiles/mod.rs +++ b/crates/core/machine/src/syscall/precompiles/mod.rs @@ -1,3 +1,4 @@ +pub mod aes128_encrypt; pub mod edwards; pub mod fptower; pub mod keccak_sponge; diff --git a/crates/stark/src/opts.rs b/crates/stark/src/opts.rs index 78f284dcf..3d380cd9d 100644 --- a/crates/stark/src/opts.rs +++ b/crates/stark/src/opts.rs @@ -204,6 +204,8 @@ pub struct SplitOpts { pub sha_extend: usize, /// The threshold for sha compress events. pub sha_compress: usize, + /// The threshold for aes128 encrypt events. + pub aes128_encrypt: usize, /// The threshold for memory events. 
pub memory: usize, } @@ -217,6 +219,7 @@ impl SplitOpts { keccak: 8 * deferred_split_threshold / 24, sha_extend: 32 * deferred_split_threshold / 48, sha_compress: 32 * deferred_split_threshold / 80, + aes128_encrypt: 32 * deferred_split_threshold / 11, memory: 64 * deferred_split_threshold, } } diff --git a/crates/test-artifacts/guests/Cargo.toml b/crates/test-artifacts/guests/Cargo.toml index 79b8ba081..e2868abc6 100644 --- a/crates/test-artifacts/guests/Cargo.toml +++ b/crates/test-artifacts/guests/Cargo.toml @@ -44,6 +44,7 @@ members = [ "verify-proof", "u256x2048-mul", "unconstrained", + "aes128", ] resolver = "2" diff --git a/crates/test-artifacts/guests/aes128/Cargo.toml b/crates/test-artifacts/guests/aes128/Cargo.toml new file mode 100644 index 000000000..335358d5c --- /dev/null +++ b/crates/test-artifacts/guests/aes128/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "aes128" +version = "1.1.0" +edition = "2021" +publish = false + +[dependencies] +zkm-zkvm = { path = "../../../../crates/zkvm/entrypoint" } diff --git a/crates/test-artifacts/guests/aes128/src/main.rs b/crates/test-artifacts/guests/aes128/src/main.rs new file mode 100644 index 000000000..d90db5ad3 --- /dev/null +++ b/crates/test-artifacts/guests/aes128/src/main.rs @@ -0,0 +1,13 @@ +#![no_std] +#![no_main] +zkm_zkvm::entrypoint!(main); + +use zkm_zkvm::lib::aes128::aes128_encrypt; + +pub fn main() { + for _ in 0..1 { + let mut state = [0u8; 16]; + let key = [0u8; 16]; + aes128_encrypt(&mut state, &key); + } +} diff --git a/crates/test-artifacts/src/lib.rs b/crates/test-artifacts/src/lib.rs index 2f8a21c9d..309202176 100644 --- a/crates/test-artifacts/src/lib.rs +++ b/crates/test-artifacts/src/lib.rs @@ -8,6 +8,7 @@ pub const HELLO_WORLD_ELF: &[u8] = include_elf!("hello-world"); pub const POSEIDON2_PERMUTE_ELF: &[u8] = include_elf!("poseidon2-permute-test"); +pub const AES128_ENCRYPT_ELF: &[u8] = include_elf!("aes128"); pub const SHA2_ELF: &[u8] = include_elf!("sha2-test"); pub const 
SHA_EXTEND_ELF: &[u8] = include_elf!("sha-extend-test"); pub const SHA_COMPRESS_ELF: &[u8] = include_elf!("sha-compress-test"); diff --git a/crates/zkvm/entrypoint/src/syscalls/aes128.rs b/crates/zkvm/entrypoint/src/syscalls/aes128.rs new file mode 100644 index 000000000..97b9fda17 --- /dev/null +++ b/crates/zkvm/entrypoint/src/syscalls/aes128.rs @@ -0,0 +1,25 @@ +#[cfg(target_os = "zkvm")] +use core::arch::asm; + +/// Executes the AES-128 encryption on the given block with the given key. +/// +/// ### Safety +/// +/// The caller must ensure that `state` and `key` are valid pointers to data that are aligned along +/// a four byte boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "syscall", + in("$2") crate::syscalls::AES128_ENCRYPT, + in("$4") state, + in("$5") key, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/crates/zkvm/entrypoint/src/syscalls/mod.rs b/crates/zkvm/entrypoint/src/syscalls/mod.rs index b1aa1d61a..46838f440 100644 --- a/crates/zkvm/entrypoint/src/syscalls/mod.rs +++ b/crates/zkvm/entrypoint/src/syscalls/mod.rs @@ -1,3 +1,4 @@ +mod aes128; mod bigint; mod bls12381; mod bn254; @@ -19,6 +20,7 @@ mod unconstrained; #[cfg(feature = "verify")] mod verify; +pub use aes128::*; pub use bigint::*; pub use bls12381::*; pub use bn254::*; @@ -162,3 +164,6 @@ pub const BN254_FP2_MUL: u32 = 0x01_01_00_2B; /// Executes the `POSEIDON2_PERMUTE` precompile. pub const POSEIDON2_PERMUTE: u32 = 0x00_01_00_30; + +/// Executes the `AES128_ENCRYPT` precompile. 
+pub const AES128_ENCRYPT: u32 = 0x01_01_00_31; diff --git a/crates/zkvm/lib/src/aes128.rs b/crates/zkvm/lib/src/aes128.rs new file mode 100644 index 000000000..648209d27 --- /dev/null +++ b/crates/zkvm/lib/src/aes128.rs @@ -0,0 +1,22 @@ +use crate::syscall_aes128_encrypt; + +pub fn aes128_encrypt(state: &mut [u8; 16], key: &[u8; 16]) { + // convert to u32 to align the memory + let mut state_u32 = [0u32; 4]; + let mut key_u32 = [0u32; 4]; + + for i in 0..4 { + state_u32[i] = u32::from_le_bytes([ + state[i * 4], + state[i * 4 + 1], + state[i * 4 + 2], + state[i * 4 + 3], + ]); + key_u32[i] = + u32::from_le_bytes([key[i * 4], key[i * 4 + 1], key[i * 4 + 2], key[i * 4 + 3]]); + } + unsafe { syscall_aes128_encrypt(&mut state_u32, &key_u32) } + for i in 0..4 { + state[4 * i..4 * i + 4].copy_from_slice(&state_u32[i].to_le_bytes()); + } +} diff --git a/crates/zkvm/lib/src/lib.rs b/crates/zkvm/lib/src/lib.rs index 35dd71a0d..dda6e537e 100644 --- a/crates/zkvm/lib/src/lib.rs +++ b/crates/zkvm/lib/src/lib.rs @@ -3,11 +3,11 @@ //! Documentation for these syscalls can be found in the zkVM entrypoint //! `zkm_zkvm::syscalls` module. +pub mod aes128; pub mod bls12381; pub mod bn254; #[cfg(feature = "ecdsa")] pub mod ecdsa; - pub mod ed25519; pub mod io; pub mod keccak256; @@ -78,6 +78,9 @@ extern "C" { /// Executes the Poseidon2 permutation pub fn syscall_poseidon2_permute(state: *mut [u32; 16]); + /// Executes the AES-128 encryption on the given state with the given key. + pub fn syscall_aes128_encrypt(state: *mut [u32; 4], key: *const [u32; 4]); + /// Executes an uint256 multiplication on the given inputs. 
pub fn syscall_uint256_mulmod(x: *mut [u32; 8], y: *const [u32; 8]); diff --git a/examples/Cargo.lock b/examples/Cargo.lock index ba80a1e3e..7e53311da 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -49,6 +49,24 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "aes128" +version = "1.1.0" +dependencies = [ + "zkm-zkvm", +] + +[[package]] +name = "aes128-host" +version = "1.1.0" +dependencies = [ + "hex", + "log", + "tracing", + "zkm-build", + "zkm-sdk", +] + [[package]] name = "aggregation" version = "1.1.0" @@ -8721,6 +8739,7 @@ dependencies = [ name = "zkm-core-executor" version = "1.2.0" dependencies = [ + "aes", "anyhow", "bincode", "bytemuck", diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 504354292..6cb0ee51a 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -1,5 +1,7 @@ [workspace] members = [ + "aes128/guest", + "aes128/host", "aggregation/guest", "aggregation/host", "bn254/guest", diff --git a/examples/aes128/guest/Cargo.toml b/examples/aes128/guest/Cargo.toml new file mode 100644 index 000000000..602f6a47e --- /dev/null +++ b/examples/aes128/guest/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "aes128" +version = "1.1.0" +edition = "2021" +publish = false + +[dependencies] +zkm-zkvm = { path = "../../../crates/zkvm/entrypoint" } + diff --git a/examples/aes128/guest/src/main.rs b/examples/aes128/guest/src/main.rs new file mode 100644 index 000000000..97e499337 --- /dev/null +++ b/examples/aes128/guest/src/main.rs @@ -0,0 +1,28 @@ +#![no_std] +#![no_main] + +extern crate alloc; + +use alloc::vec::Vec; +use core::cmp::min; +use zkm_zkvm::lib::aes128::aes128_encrypt; +zkm_zkvm::entrypoint!(main); + +pub fn main() { + let plain_text: Vec = zkm_zkvm::io::read(); + let key: [u8; 16] = zkm_zkvm::io::read(); + let iv: [u8; 16] = zkm_zkvm::io::read(); + let output = cipher_block_chaining(&plain_text, &key, &iv); + zkm_zkvm::io::commit::<[u8; 16]>(&output); +} + +fn cipher_block_chaining(input: &[u8], key: &[u8; 16], 
iv: &[u8; 16]) -> [u8; 16] { + let mut block = iv.clone(); + for chunk in input.chunks(16) { + for i in 0..min(chunk.len(), 16) { + block[i] ^= chunk[i]; + } + aes128_encrypt(&mut block, key); + } + block +} diff --git a/examples/aes128/host/Cargo.toml b/examples/aes128/host/Cargo.toml new file mode 100644 index 000000000..dc0fe3c83 --- /dev/null +++ b/examples/aes128/host/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "aes128-host" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[dependencies] +zkm-sdk = { workspace = true } +tracing = { workspace = true } +hex = "0.4.3" +log = "0.4.27" + +[build-dependencies] +zkm-build = { workspace = true } diff --git a/examples/aes128/host/build.rs b/examples/aes128/host/build.rs new file mode 100644 index 000000000..e7ea36ace --- /dev/null +++ b/examples/aes128/host/build.rs @@ -0,0 +1,3 @@ +fn main() { + zkm_build::build_program("../guest"); +} diff --git a/examples/aes128/host/src/main.rs b/examples/aes128/host/src/main.rs new file mode 100644 index 000000000..544d6a697 --- /dev/null +++ b/examples/aes128/host/src/main.rs @@ -0,0 +1,51 @@ +use std::env; +use zkm_sdk::{include_elf, utils, ProverClient, ZKMStdin}; + +/// The ELF we want to execute inside the zkVM. +const ELF: &[u8] = include_elf!("aes128"); +fn prove_aes128_rust() { + let mut stdin = ZKMStdin::new(); + + // load input + let plain_text = vec![ + 21_u8, 2, 23, 21, 1, 1, 2, 2, 2, 7, 128, 21, 25, 57, 247, 26, 35, 97, 244, 57, 25, 124, + 234, 234, 234, 214, 134, 135, 246, 17, 29, 7, + ]; + let key = [0_u8; 16]; + let iv = [0_u8; 16]; + + let expected_output = + vec![97_u8, 203, 140, 117, 36, 211, 41, 97, 177, 36, 93, 148, 107, 228, 201, 129]; + + stdin.write(&plain_text); + stdin.write(&key); + stdin.write(&iv); + + // Create a `ProverClient` method. + let client = ProverClient::new(); + + // Execute the program using the `ProverClient.execute` method, without generating a proof. 
+ let (_, report) = client.execute(ELF, stdin.clone()).run().unwrap(); + println!("executed program with {} cycles", report.total_instruction_count()); + + // // Generate the proof for the given program and input. + let (pk, vk) = client.setup(ELF); + let mut proof = client.prove(&pk, stdin).run().unwrap(); + println!("generated proof"); + + // Read and verify the output. + // + // Note that this output is read from values committed to in the program using + // `zkm_zkvm::io::commit`. + let public_input = proof.public_values.read::<[u8; 16]>(); + assert_eq!(expected_output, public_input); + + // Verify proof and public values + client.verify(&proof, &vk).expect("verification failed"); + println!("successfully generated and verified proof for the program!"); +} + +fn main() { + utils::setup_logger(); + prove_aes128_rust(); +} diff --git a/examples/fibonacci_c_lib/host/build.rs b/examples/fibonacci_c_lib/host/build.rs index 6c6031ecb..eff897a02 100644 --- a/examples/fibonacci_c_lib/host/build.rs +++ b/examples/fibonacci_c_lib/host/build.rs @@ -3,4 +3,4 @@ use zkm_build::{build_program_with_args, BuildArgs}; fn main() { zkm_build::build_program("../guest"); -} \ No newline at end of file +} diff --git a/examples/large-sum/host/Cargo.toml b/examples/large-sum/host/Cargo.toml index 970134117..ca2f0cdab 100644 --- a/examples/large-sum/host/Cargo.toml +++ b/examples/large-sum/host/Cargo.toml @@ -13,21 +13,21 @@ zkm-sdk = { workspace = true } [build-dependencies] zkm-build = { workspace = true } -[[bin]] -name = "plonk_bn254" -path = "bin/plonk_bn254.rs" - -[[bin]] -name = "groth16_bn254" -path = "bin/groth16_bn254.rs" - -[[bin]] -name = "compressed" -path = "bin/compressed.rs" - -[[bin]] -name = "execute" -path = "bin/execute.rs" +#[[bin]] +#name = "plonk_bn254" +#path = "bin/plonk_bn254.rs" +# +#[[bin]] +#name = "groth16_bn254" +#path = "bin/groth16_bn254.rs" +# +#[[bin]] +#name = "compressed" +#path = "bin/compressed.rs" +# +#[[bin]] +#name = "execute" +#path = 
"bin/execute.rs" [[bin]] name = "large-sum-host"