diff --git a/plonkish_backend/Cargo.toml b/plonkish_backend/Cargo.toml index e22ceb94..3670971d 100644 --- a/plonkish_backend/Cargo.toml +++ b/plonkish_backend/Cargo.toml @@ -16,6 +16,10 @@ serde = { version = "1.0", features = ["derive"] } bincode = "1.3.3" sha3 = "0.10.6" poseidon = { git = "https://github.com/han0110/poseidon", branch = "feature/with-spec" } +ff = { version = "0.13", features = ["bits"] } +group = "0.13" +proptest = { version = "1.0.0", optional = true } +uint = "0.9.2" # timer ark-std = { version = "^0.4.0", default-features = false, optional = true } @@ -30,6 +34,7 @@ halo2_proofs = { git = "https://github.com/han0110/halo2.git", branch = "feature paste = "1.0.11" criterion = "0.4.0" pprof = { version = "0.11.0", features = ["criterion", "flamegraph"] } +proptest = "1.0.0" [features] default = ["parallel", "frontend-halo2"] diff --git a/plonkish_backend/src/backend/hyperplonk/verifier.rs b/plonkish_backend/src/backend/hyperplonk/verifier.rs index 4a20ef67..f9f6049e 100644 --- a/plonkish_backend/src/backend/hyperplonk/verifier.rs +++ b/plonkish_backend/src/backend/hyperplonk/verifier.rs @@ -48,6 +48,8 @@ pub(crate) fn verify_sum_check( y: &[F], transcript: &mut impl FieldTranscriptRead, ) -> Result<(Vec>, Vec>), Error> { + // In description of the sum check protocol in https://eprint.iacr.org/2022/1355.pdf, + // x_eval corresponds to the v in the final check, x is \alpha_1, ..., \alpha_{\mu} the challenges during the sum check protocol let (x_eval, x) = ClassicSumCheck::, BinaryField>::verify( &(), num_vars, @@ -56,6 +58,7 @@ pub(crate) fn verify_sum_check( transcript, )?; + // Check that f(\alpha_1, ..., \alpha_{\mu}) = v holds, where f is the expression let pcs_query = pcs_query(expression, instances.len()); let (evals_for_rotation, evals) = pcs_query .iter() @@ -79,6 +82,7 @@ pub(crate) fn verify_sum_check( )); } + // Obtain points and evaluations for which we need to verify the opening proofs let point_offset = point_offset(&pcs_query); let evals = pcs_query .iter() diff --git a/plonkish_backend/src/circuits.rs b/plonkish_backend/src/circuits.rs new file mode 100644 index 00000000..5eaed0c2 --- /dev/null +++ b/plonkish_backend/src/circuits.rs @@ -0,0 +1 @@ +pub mod poseidongadget; diff --git a/plonkish_backend/src/circuits/poseidongadget.rs b/plonkish_backend/src/circuits/poseidongadget.rs new file mode 100644 index 00000000..430d0ec4 --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget.rs @@ -0,0 +1,14 @@ +// This crate been taking from privacy scaling explorations crate: https://github.com/privacy-scaling-explorations/poseidon-gadget/tree/main/src +//! This crate provides the poseidon gadget for use with `halo2_proofs`. +//! This gadget has been extracted from zcash's halo2_gadgets: +//! https://github.com/zcash/halo2/tree/main/halo2_gadgets + +#![cfg_attr(docsrs, feature(doc_cfg))] +// Catch documentation errors caused by code changes. +#![deny(rustdoc::broken_intra_doc_links)] +#![deny(missing_debug_implementations)] +#![deny(missing_docs)] +#![deny(unsafe_code)] + +pub mod poseidon; +pub mod utilities; diff --git a/plonkish_backend/src/circuits/poseidongadget/poseidon.rs b/plonkish_backend/src/circuits/poseidongadget/poseidon.rs new file mode 100644 index 00000000..e93e7a2c --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/poseidon.rs @@ -0,0 +1,298 @@ +//! The Poseidon algebraic hash function. + +use std::convert::TryInto; +use std::fmt; +use std::marker::PhantomData; + +use ff::PrimeField; +use group::ff::Field; +use halo2_proofs::{ + circuit::{AssignedCell, Chip, Layouter}, + plonk::Error, +}; + +mod pow5; +pub use pow5::{Pow5Chip, Pow5Config, StateWord}; + +pub mod primitives; +use primitives::{Absorbing, ConstantLength, Domain, Spec, SpongeMode, Squeezing, State}; + +/// A word from the padded input to a Poseidon sponge. +#[derive(Clone, Debug)] +pub enum PaddedWord { + /// A message word provided by the prover. + Message(AssignedCell), + /// A padding word, that will be fixed in the circuit parameters. + Padding(F), +} + +/// The set of circuit instructions required to use the Poseidon permutation. +pub trait PoseidonInstructions, const T: usize, const RATE: usize>: + Chip +{ + /// Variable representing the word over which the Poseidon permutation operates. + type Word: Clone + fmt::Debug + From> + Into>; + + /// Applies the Poseidon permutation to the given state. + fn permute( + &self, + layouter: &mut impl Layouter, + initial_state: &State, + ) -> Result, Error>; +} + +/// The set of circuit instructions required to use the [`Sponge`] and [`Hash`] gadgets. +/// +/// [`Hash`]: self::Hash +pub trait PoseidonSpongeInstructions< + F: Field, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, +>: PoseidonInstructions +{ + /// Returns the initial empty state for the given domain. + fn initial_state(&self, layouter: &mut impl Layouter) + -> Result, Error>; + + /// Adds the given input to the state. + fn add_input( + &self, + layouter: &mut impl Layouter, + initial_state: &State, + input: &Absorbing, RATE>, + ) -> Result, Error>; + + /// Extracts sponge output from the given state. + fn get_output(state: &State) -> Squeezing; +} + +/// A word over which the Poseidon permutation operates. +#[derive(Debug)] +pub struct Word< + F: Field, + PoseidonChip: PoseidonInstructions, + S: Spec, + // Width + const T: usize, + const RATE: usize, +> { + inner: PoseidonChip::Word, +} + +impl< + F: Field, + PoseidonChip: PoseidonInstructions, + S: Spec, + const T: usize, + const RATE: usize, + > Word +{ + /// The word contained in this gadget. + pub fn inner(&self) -> PoseidonChip::Word { + self.inner.clone() + } + + /// Construct a [`Word`] gadget from the inner word. + pub fn from_inner(inner: PoseidonChip::Word) -> Self { + Self { inner } + } +} + +fn poseidon_sponge< + F: Field, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, +>( + chip: &PoseidonChip, + mut layouter: impl Layouter, + state: &mut State, + input: Option<&Absorbing, RATE>>, +) -> Result, Error> { + if let Some(input) = input { + *state = chip.add_input(&mut layouter, state, input)?; + } + *state = chip.permute(&mut layouter, state)?; + Ok(PoseidonChip::get_output(state)) +} + +/// A Poseidon sponge. +#[derive(Debug)] +pub struct Sponge< + F: Field, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + M: SpongeMode, + D: Domain, + const T: usize, + const RATE: usize, +> { + chip: PoseidonChip, + mode: M, + state: State, + _marker: PhantomData, +} + +impl< + F: Field, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, + > Sponge, RATE>, D, T, RATE> +{ + /// Constructs a new duplex sponge for the given Poseidon specification. + pub fn new(chip: PoseidonChip, mut layouter: impl Layouter) -> Result { + chip.initial_state(&mut layouter).map(|state| Sponge { + chip, + mode: Absorbing( + (0..RATE) + .map(|_| None) + .collect::>() + .try_into() + .unwrap(), + ), + state, + _marker: PhantomData, + }) + } + + /// Absorbs an element into the sponge. + pub fn absorb( + &mut self, + mut layouter: impl Layouter, + value: PaddedWord, + ) -> Result<(), Error> { + for entry in self.mode.0.iter_mut() { + if entry.is_none() { + *entry = Some(value); + return Ok(()); + } + } + + // We've already absorbed as many elements as we can + let _ = poseidon_sponge( + &self.chip, + layouter.namespace(|| "PoseidonSponge"), + &mut self.state, + Some(&self.mode), + )?; + self.mode = Absorbing::init_with(value); + + Ok(()) + } + + /// Transitions the sponge into its squeezing state. + #[allow(clippy::type_complexity)] + pub fn finish_absorbing( + mut self, + mut layouter: impl Layouter, + ) -> Result, D, T, RATE>, Error> + { + let mode = poseidon_sponge( + &self.chip, + layouter.namespace(|| "PoseidonSponge"), + &mut self.state, + Some(&self.mode), + )?; + + Ok(Sponge { + chip: self.chip, + mode, + state: self.state, + _marker: PhantomData, + }) + } +} + +impl< + F: Field, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, + > Sponge, D, T, RATE> +{ + /// Squeezes an element from the sponge. + pub fn squeeze(&mut self, mut layouter: impl Layouter) -> Result, Error> { + loop { + for entry in self.mode.0.iter_mut() { + if let Some(inner) = entry.take() { + return Ok(inner.into()); + } + } + + // We've already squeezed out all available elements + self.mode = poseidon_sponge( + &self.chip, + layouter.namespace(|| "PoseidonSponge"), + &mut self.state, + None, + )?; + } + } +} + +/// A Poseidon hash function, built around a sponge. +#[derive(Debug)] +pub struct Hash< + F: Field, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, +> { + sponge: Sponge, RATE>, D, T, RATE>, +} + +impl< + F: Field, + PoseidonChip: PoseidonSpongeInstructions, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, + > Hash +{ + /// Initializes a new hasher. + pub fn init(chip: PoseidonChip, layouter: impl Layouter) -> Result { + Sponge::new(chip, layouter).map(|sponge| Hash { sponge }) + } +} + +impl< + F: PrimeField, + PoseidonChip: PoseidonSpongeInstructions, T, RATE>, + S: Spec, + const T: usize, + const RATE: usize, + const L: usize, + > Hash, T, RATE> +{ + /// Hashes the given input. + pub fn hash( + mut self, + mut layouter: impl Layouter, + message: [AssignedCell; L], + ) -> Result, Error> { + for (i, value) in message + .into_iter() + .map(PaddedWord::Message) + .chain( as Domain>::padding(L).map(PaddedWord::Padding)) + .enumerate() + { + self.sponge + .absorb(layouter.namespace(|| format!("absorb_{i}")), value)?; + } + self.sponge + .finish_absorbing(layouter.namespace(|| "finish absorbing"))? + .squeeze(layouter.namespace(|| "squeeze")) + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/poseidon/pow5.rs b/plonkish_backend/src/circuits/poseidongadget/poseidon/pow5.rs new file mode 100644 index 00000000..68669357 --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/poseidon/pow5.rs @@ -0,0 +1,952 @@ +use std::convert::TryInto; +use std::iter; + +use halo2_proofs::{ + arithmetic::Field, + circuit::{AssignedCell, Cell, Chip, Layouter, Region, Value}, + plonk::{ + Advice, Any, Column, ConstraintSystem, Constraints, Error, Expression, Fixed, Selector, + }, + poly::Rotation, +}; + +use super::{ + primitives::{Absorbing, Domain, Mds, Spec, Squeezing, State}, + PaddedWord, PoseidonInstructions, PoseidonSpongeInstructions, +}; +use crate::circuits::poseidongadget::utilities::Var; + +/// Configuration for a [`Pow5Chip`]. +#[derive(Clone, Debug)] +pub struct Pow5Config { + pub(crate) state: [Column; WIDTH], + partial_sbox: Column, + rc_a: [Column; WIDTH], + rc_b: [Column; WIDTH], + s_full: Selector, + s_partial: Selector, + s_pad_and_add: Selector, + + half_full_rounds: usize, + half_partial_rounds: usize, + // alpha represents exponent in s-box + alpha: [u64; 4], + round_constants: Vec<[F; WIDTH]>, + m_reg: Mds, + #[allow(dead_code)] + m_inv: Mds, +} + +/// A Poseidon chip using an $x^5$ S-Box. +/// +/// The chip is implemented using a single round per row for full rounds, and two rounds +/// per row for partial rounds. +#[derive(Debug)] +pub struct Pow5Chip { + config: Pow5Config, +} + +impl Pow5Chip { + /// Configures this chip for use in a circuit. + /// + /// # Side-effects + /// + /// All columns in `state` will be equality-enabled. + // + // TODO: Does the rate need to be hard-coded here, or only the width? It probably + // needs to be known wherever we implement the hashing gadget, but it isn't strictly + // necessary for the permutation. + pub fn configure>( + meta: &mut ConstraintSystem, + state: [Column; WIDTH], + partial_sbox: Column, + rc_a: [Column; WIDTH], + rc_b: [Column; WIDTH], + ) -> Pow5Config { + assert_eq!(RATE, WIDTH - 1); + // Generate constants for the Poseidon permutation. + // This gadget requires R_F and R_P to be even. + assert!(S::full_rounds() & 1 == 0); + assert!(S::partial_rounds() & 1 == 0); + let half_full_rounds = S::full_rounds() / 2; + let half_partial_rounds = S::partial_rounds() / 2; + let (round_constants, m_reg, m_inv) = S::constants(); + + // This allows state words to be initialized (by constraining them equal to fixed + // values), and used in a permutation from an arbitrary region. rc_a is used in + // every permutation round, while rc_b is empty in the initial and final full + // rounds, so we use rc_b as "scratch space" for fixed values (enabling potential + // layouter optimisations). + for column in iter::empty() + .chain(state.iter().cloned().map(Column::::from)) + .chain(rc_b.iter().cloned().map(Column::::from)) + { + meta.enable_equality(column); + } + + let s_full = meta.selector(); + let s_partial = meta.selector(); + let s_pad_and_add = meta.selector(); + + let alpha = [5, 0, 0, 0]; + let pow_5 = |v: Expression| { + let v2 = v.clone() * v.clone(); + v2.clone() * v2 * v + }; + + meta.create_gate("full round", |meta| { + let s_full = meta.query_selector(s_full); + + Constraints::with_selector( + s_full, + (0..WIDTH) + .map(|next_idx| { + let state_next = meta.query_advice(state[next_idx], Rotation::next()); + let expr = (0..WIDTH) + .map(|idx| { + let state_cur = meta.query_advice(state[idx], Rotation::cur()); + let rc_a = meta.query_fixed(rc_a[idx], Rotation::cur()); + pow_5(state_cur + rc_a) * m_reg[next_idx][idx] + }) + .reduce(|acc, term| acc + term) + .expect("WIDTH > 0"); + expr - state_next + }) + .collect::>(), + ) + }); + + meta.create_gate("partial rounds", |meta| { + let cur_0 = meta.query_advice(state[0], Rotation::cur()); + let mid_0 = meta.query_advice(partial_sbox, Rotation::cur()); + + let rc_a0 = meta.query_fixed(rc_a[0], Rotation::cur()); + let rc_b0 = meta.query_fixed(rc_b[0], Rotation::cur()); + + let s_partial = meta.query_selector(s_partial); + + use halo2_proofs::plonk::VirtualCells; + let mid = |idx: usize, meta: &mut VirtualCells| { + let mid = mid_0.clone() * m_reg[idx][0]; + (1..WIDTH).fold(mid, |acc, cur_idx| { + let cur = meta.query_advice(state[cur_idx], Rotation::cur()); + let rc_a = meta.query_fixed(rc_a[cur_idx], Rotation::cur()); + acc + (cur + rc_a) * m_reg[idx][cur_idx] + }) + }; + + let next = |idx: usize, meta: &mut VirtualCells| { + (0..WIDTH) + .map(|next_idx| { + let next = meta.query_advice(state[next_idx], Rotation::next()); + next * m_inv[idx][next_idx] + }) + .reduce(|acc, next| acc + next) + .expect("WIDTH > 0") + }; + + let partial_round_linear = |idx: usize, meta: &mut VirtualCells| { + let rc_b = meta.query_fixed(rc_b[idx], Rotation::cur()); + mid(idx, meta) + rc_b - next(idx, meta) + }; + + Constraints::with_selector( + s_partial, + std::iter::empty() + // state[0] round a + .chain(Some(pow_5(cur_0 + rc_a0) - mid_0.clone())) + // state[0] round b + .chain(Some(pow_5(mid(0, meta) + rc_b0) - next(0, meta))) + .chain((1..WIDTH).map(|idx| partial_round_linear(idx, meta))) + .collect::>(), + ) + }); + + meta.create_gate("pad-and-add", |meta| { + let initial_state_rate = meta.query_advice(state[RATE], Rotation::prev()); + let output_state_rate = meta.query_advice(state[RATE], Rotation::next()); + + let s_pad_and_add = meta.query_selector(s_pad_and_add); + + let pad_and_add = |idx: usize| { + let initial_state = meta.query_advice(state[idx], Rotation::prev()); + let input = meta.query_advice(state[idx], Rotation::cur()); + let output_state = meta.query_advice(state[idx], Rotation::next()); + + // We pad the input by storing the required padding in fixed columns and + // then constraining the corresponding input columns to be equal to it. + initial_state + input - output_state + }; + + Constraints::with_selector( + s_pad_and_add, + (0..RATE) + .map(pad_and_add) + // The capacity element is never altered by the input. + .chain(Some(initial_state_rate - output_state_rate)) + .collect::>(), + ) + }); + + Pow5Config { + state, + partial_sbox, + rc_a, + rc_b, + s_full, + s_partial, + s_pad_and_add, + half_full_rounds, + half_partial_rounds, + alpha, + round_constants, + m_reg, + m_inv, + } + } + + /// Construct a [`Pow5Chip`]. + pub fn construct(config: Pow5Config) -> Self { + Pow5Chip { config } + } +} + +impl Chip for Pow5Chip { + type Config = Pow5Config; + type Loaded = (); + + fn config(&self) -> &Self::Config { + &self.config + } + + fn loaded(&self) -> &Self::Loaded { + &() + } +} + +impl, const WIDTH: usize, const RATE: usize> + PoseidonInstructions for Pow5Chip +{ + type Word = StateWord; + + fn permute( + &self, + layouter: &mut impl Layouter, + initial_state: &State, + ) -> Result, Error> { + let config = self.config(); + + layouter.assign_region( + || "permute state", + |mut region| { + // Load the initial state into this region. + let state = Pow5State::load(&mut region, config, initial_state)?; + + let state = (0..config.half_full_rounds) + .try_fold(state, |res, r| res.full_round(&mut region, config, r, r))?; + + let state = (0..config.half_partial_rounds).try_fold(state, |res, r| { + res.partial_round( + &mut region, + config, + config.half_full_rounds + 2 * r, + config.half_full_rounds + r, + ) + })?; + + let state = (0..config.half_full_rounds).try_fold(state, |res, r| { + res.full_round( + &mut region, + config, + config.half_full_rounds + 2 * config.half_partial_rounds + r, + config.half_full_rounds + config.half_partial_rounds + r, + ) + })?; + + Ok(state.0) + }, + ) + } +} + +impl< + F: Field, + S: Spec, + D: Domain, + const WIDTH: usize, + const RATE: usize, + > PoseidonSpongeInstructions for Pow5Chip +{ + fn initial_state( + &self, + layouter: &mut impl Layouter, + ) -> Result, Error> { + let config = self.config(); + let state = layouter.assign_region( + || format!("initial state for domain {}", D::name()), + |mut region| { + let mut state = Vec::with_capacity(WIDTH); + let mut load_state_word = |i: usize, value: F| -> Result<_, Error> { + let var = region.assign_advice_from_constant( + || format!("state_{i}"), + config.state[i], + 0, + value, + )?; + state.push(StateWord(var)); + + Ok(()) + }; + + for i in 0..RATE { + load_state_word(i, F::ZERO)?; + } + load_state_word(RATE, D::initial_capacity_element())?; + + Ok(state) + }, + )?; + + Ok(state.try_into().unwrap()) + } + + fn add_input( + &self, + layouter: &mut impl Layouter, + initial_state: &State, + input: &Absorbing, RATE>, + ) -> Result, Error> { + let config = self.config(); + layouter.assign_region( + || format!("add input for domain {}", D::name()), + |mut region| { + config.s_pad_and_add.enable(&mut region, 1)?; + + // Load the initial state into this region. + let load_state_word = |i: usize| { + initial_state[i] + .0 + .copy_advice( + || format!("load state_{i}"), + &mut region, + config.state[i], + 0, + ) + .map(StateWord) + }; + let initial_state: Result, Error> = + (0..WIDTH).map(load_state_word).collect(); + let initial_state = initial_state?; + + // Load the input into this region. + let load_input_word = |i: usize| { + let constraint_var = match input.0[i].clone() { + Some(PaddedWord::Message(word)) => word, + Some(PaddedWord::Padding(padding_value)) => region.assign_fixed( + || format!("load pad_{i}"), + config.rc_b[i], + 1, + || Value::known(padding_value), + )?, + _ => panic!("Input is not padded"), + }; + // TO DO: the Synthesis error in poseidon_hash_longer_input occurs here when i =1 (the input is padding) + constraint_var + .copy_advice( + || format!("load input_{i}"), + &mut region, + config.state[i], + 1, + ) + .map(StateWord) + }; + let input: Result, Error> = (0..RATE).map(load_input_word).collect(); + let input = input?; + + // Constrain the output. + let constrain_output_word = |i: usize| { + let value = initial_state[i].0.value().copied() + + input + .get(i) + .map(|word| word.0.value().cloned()) + // The capacity element is never altered by the input. + .unwrap_or_else(|| Value::known(F::ZERO)); + region + .assign_advice(|| format!("load output_{i}"), config.state[i], 2, || value) + .map(StateWord) + }; + + let output: Result, Error> = (0..WIDTH).map(constrain_output_word).collect(); + output.map(|output| output.try_into().unwrap()) + }, + ) + } + + fn get_output(state: &State) -> Squeezing { + Squeezing( + state[..RATE] + .iter() + .map(|word| Some(word.clone())) + .collect::>() + .try_into() + .unwrap(), + ) + } +} + +/// A word in the Poseidon state. +#[derive(Clone, Debug)] +pub struct StateWord(AssignedCell); + +impl From> for AssignedCell { + fn from(state_word: StateWord) -> AssignedCell { + state_word.0 + } +} + +impl From> for StateWord { + fn from(cell_value: AssignedCell) -> StateWord { + StateWord(cell_value) + } +} + +impl Var for StateWord { + fn cell(&self) -> Cell { + self.0.cell() + } + + fn value(&self) -> Value { + self.0.value().cloned() + } +} + +#[derive(Debug)] +struct Pow5State([StateWord; WIDTH]); + +impl Pow5State { + fn full_round( + self, + region: &mut Region, + config: &Pow5Config, + round: usize, + offset: usize, + ) -> Result { + Self::round(region, config, round, offset, config.s_full, |_| { + let q = self.0.iter().enumerate().map(|(idx, word)| { + word.0 + .value() + .map(|v| *v + config.round_constants[round][idx]) + }); + let r: Value> = q.map(|q| q.map(|q| q.pow(config.alpha))).collect(); + let m = &config.m_reg; + let state = m.iter().map(|m_i| { + r.as_ref().map(|r| { + r.iter() + .enumerate() + .fold(F::ZERO, |acc, (j, r_j)| acc + m_i[j] * r_j) + }) + }); + + Ok((round + 1, state.collect::>().try_into().unwrap())) + }) + } + + fn partial_round( + self, + region: &mut Region, + config: &Pow5Config, + round: usize, + offset: usize, + ) -> Result { + Self::round(region, config, round, offset, config.s_partial, |region| { + let m = &config.m_reg; + let p: Value> = self.0.iter().map(|word| word.0.value().cloned()).collect(); + + let r: Value> = p.map(|p| { + let r_0 = (p[0] + config.round_constants[round][0]).pow(config.alpha); + let r_i = p[1..] + .iter() + .enumerate() + .map(|(i, p_i)| *p_i + config.round_constants[round][i + 1]); + std::iter::empty().chain(Some(r_0)).chain(r_i).collect() + }); + + region.assign_advice( + || format!("round_{round} partial_sbox"), + config.partial_sbox, + offset, + || r.as_ref().map(|r| r[0]), + )?; + + let p_mid: Value> = m + .iter() + .map(|m_i| { + r.as_ref().map(|r| { + m_i.iter() + .zip(r.iter()) + .fold(F::ZERO, |acc, (m_ij, r_j)| acc + *m_ij * r_j) + }) + }) + .collect(); + + // Load the second round constants. + let mut load_round_constant = |i: usize| { + region.assign_fixed( + || format!("round_{} rc_{}", round + 1, i), + config.rc_b[i], + offset, + || Value::known(config.round_constants[round + 1][i]), + ) + }; + for i in 0..WIDTH { + load_round_constant(i)?; + } + + let r_mid: Value> = p_mid.map(|p| { + let r_0 = (p[0] + config.round_constants[round + 1][0]).pow(config.alpha); + let r_i = p[1..] + .iter() + .enumerate() + .map(|(i, p_i)| *p_i + config.round_constants[round + 1][i + 1]); + std::iter::empty().chain(Some(r_0)).chain(r_i).collect() + }); + + let state: Vec> = m + .iter() + .map(|m_i| { + r_mid.as_ref().map(|r| { + m_i.iter() + .zip(r.iter()) + .fold(F::ZERO, |acc, (m_ij, r_j)| acc + *m_ij * r_j) + }) + }) + .collect(); + + Ok((round + 2, state.try_into().unwrap())) + }) + } + + fn load( + region: &mut Region, + config: &Pow5Config, + initial_state: &State, WIDTH>, + ) -> Result { + let load_state_word = |i: usize| { + initial_state[i] + .0 + .copy_advice(|| format!("load state_{i}"), region, config.state[i], 0) + .map(StateWord) + }; + + let state: Result, _> = (0..WIDTH).map(load_state_word).collect(); + state.map(|state| Pow5State(state.try_into().unwrap())) + } + + fn round( + region: &mut Region, + config: &Pow5Config, + round: usize, + offset: usize, + round_gate: Selector, + round_fn: impl FnOnce(&mut Region) -> Result<(usize, [Value; WIDTH]), Error>, + ) -> Result { + // Enable the required gate. + round_gate.enable(region, offset)?; + + // Load the round constants. + let mut load_round_constant = |i: usize| { + region.assign_fixed( + || format!("round_{round} rc_{i}"), + config.rc_a[i], + offset, + || Value::known(config.round_constants[round][i]), + ) + }; + for i in 0..WIDTH { + load_round_constant(i)?; + } + + // Compute the next round's state. + let (next_round, next_state) = round_fn(region)?; + + let next_state_word = |i: usize| { + let value = next_state[i]; + let var = region.assign_advice( + || format!("round_{next_round} state_{i}"), + config.state[i], + offset + 1, + || value, + )?; + Ok(StateWord(var)) + }; + + let next_state: Result, _> = (0..WIDTH).map(next_state_word).collect(); + next_state.map(|next_state| Pow5State(next_state.try_into().unwrap())) + } +} + +#[cfg(test)] +mod tests { + use ff::Field; + use halo2_curves::bn256::{Bn256, Fr}; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Circuit, ConstraintSystem, Error}, + }; + + use super::{PoseidonInstructions, Pow5Chip, Pow5Config, StateWord}; + use crate::circuits::poseidongadget::poseidon::{ + primitives::{self as poseidon, BN256param as newParam, ConstantLength, Spec}, + Hash, + }; + use rand::rngs::OsRng; + use std::convert::TryInto; + use std::marker::PhantomData; + + use crate::backend::{hyperplonk::HyperPlonk, PlonkishBackend, PlonkishCircuit}; + + use crate::{ + frontend::halo2::{CircuitExt, Halo2Circuit}, + pcs::{multilinear::Zeromorph, univariate::UnivariateKzg}, + util::{ + test::seeded_std_rng, + transcript::{InMemoryTranscript, Keccak256Transcript}, + }, + }; + + struct PermuteCircuit, const WIDTH: usize, const RATE: usize>( + PhantomData, + ); + + impl, const WIDTH: usize, const RATE: usize> Circuit + for PermuteCircuit + { + type Config = Pow5Config; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + PermuteCircuit::(PhantomData) + } + + fn configure(meta: &mut ConstraintSystem) -> Pow5Config { + let state = (0..WIDTH).map(|_| meta.advice_column()).collect::>(); + let partial_sbox = meta.advice_column(); + + let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + + Pow5Chip::configure::( + meta, + state.try_into().unwrap(), + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + ) + } + + fn synthesize( + &self, + config: Pow5Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let initial_state = layouter.assign_region( + || "prepare initial state", + |mut region| { + let state_word = |i: usize| { + let value = Value::known(Fr::from(i as u64)); + let var = region.assign_advice( + || format!("load state_{}", i), + config.state[i], + 0, + || value, + )?; + Ok(StateWord(var)) + }; + + let state: Result, Error> = (0..WIDTH).map(state_word).collect(); + Ok(state?.try_into().unwrap()) + }, + )?; + + let chip = Pow5Chip::construct(config.clone()); + let final_state = as PoseidonInstructions< + Fr, + S, + WIDTH, + RATE, + >>::permute(&chip, &mut layouter, &initial_state)?; + + // For the purpose of this test, compute the real final state inline. + let mut expected_final_state = (0..WIDTH) + .map(|idx| Fr::from(idx as u64)) + .collect::>() + .try_into() + .unwrap(); + let (round_constants, mds, _) = S::constants(); + poseidon::permute::<_, S, WIDTH, RATE>( + &mut expected_final_state, + &mds, + &round_constants, + ); + layouter.assign_region( + || "constrain final state", + |mut region| { + let mut final_state_word = |i: usize| { + let var = region.assign_advice( + || format!("load final_state_{}", i), + config.state[i], + 0, + || Value::known(expected_final_state[i]), + )?; + region.constrain_equal(final_state[i].0.cell(), var.cell()) + }; + + for i in 0..(WIDTH) { + final_state_word(i)?; + } + + Ok(()) + }, + ) + } + } + + macro_rules! impl_circuit_ext { + ($($n:expr, $m:expr),*) => { + $( + impl CircuitExt for PermuteCircuit, $n, $m> { + fn instances(&self) -> Vec> { + Vec::new() + } + } + )* + } + } + + impl_circuit_ext!(2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6); + + #[test] + fn poseidon_permute() { + type Pb = HyperPlonk>>; + macro_rules! test_poseidon_permute { + ($i:expr, $j:expr) => {{ + let circuit = Halo2Circuit::new::( + 6, + PermuteCircuit::, $j, $i>(PhantomData), + ); + let param = Pb::setup(&circuit.circuit_info().unwrap(), seeded_std_rng()).unwrap(); + let (pp, vp) = Pb::preprocess(¶m, &circuit.circuit_info().unwrap()).unwrap(); + let proof = { + let mut transcript = Keccak256Transcript::new(()); + Pb::prove(&pp, &circuit, &mut transcript, seeded_std_rng()).unwrap(); + transcript.into_proof() + }; + let result = { + let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + Pb::verify(&vp, circuit.instances(), &mut transcript, seeded_std_rng()) + }; + assert_eq!(result, Ok(())) + }}; + } + for i in 2..7 { + match i { + 1 => test_poseidon_permute!(1, 2), + 2 => test_poseidon_permute!(2, 3), + 3 => test_poseidon_permute!(3, 4), + 4 => test_poseidon_permute!(4, 5), + 5 => test_poseidon_permute!(5, 6), + 6 => test_poseidon_permute!(6, 7), + _ => unreachable!(), + } + } + } + + struct HashCircuit< + S: Spec, + const WIDTH: usize, + const RATE: usize, + const L: usize, + > { + message: Value<[Fr; L]>, + // For the purpose of this test, witness the result. + // TODO: Move this into an instance column. + output: Value, + _spec: PhantomData, + } + + impl, const WIDTH: usize, const RATE: usize, const L: usize> + Circuit for HashCircuit + { + type Config = Pow5Config; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + Self { + message: Value::unknown(), + output: Value::unknown(), + _spec: PhantomData, + } + } + + fn configure(meta: &mut ConstraintSystem) -> Pow5Config { + let state = (0..WIDTH).map(|_| meta.advice_column()).collect::>(); + let partial_sbox = meta.advice_column(); + + let rc_a = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + let rc_b = (0..WIDTH).map(|_| meta.fixed_column()).collect::>(); + + meta.enable_constant(rc_b[0]); + + Pow5Chip::configure::( + meta, + state.try_into().unwrap(), + partial_sbox, + rc_a.try_into().unwrap(), + rc_b.try_into().unwrap(), + ) + } + + fn synthesize( + &self, + config: Pow5Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let chip = Pow5Chip::construct(config.clone()); + let message = layouter.assign_region( + || "load message", + |mut region| { + let message_word = |i: usize| { + let value = self.message.map(|message_vals| message_vals[i]); + region.assign_advice( + || format!("load message_{}", i), + config.state[i], + 0, + || value, + ) + }; + + let message: Result, Error> = (0..L).map(message_word).collect(); + Ok(message?.try_into().unwrap()) + }, + )?; + + let hasher = Hash::<_, _, S, ConstantLength, WIDTH, RATE>::init( + chip, + layouter.namespace(|| "init"), + )?; + //TO DO: add_input called here leading to error in poseidon_hash_longer_input + let output = hasher.hash(layouter.namespace(|| "hash"), message)?; + + layouter.assign_region( + || "constrain output", + |mut region| { + let expected_var = region.assign_advice( + || "load output", + config.state[0], + 0, + || self.output, + )?; + region.constrain_equal(output.cell(), expected_var.cell()) + }, + ) + } + } + + macro_rules! impl_circuit_ext { + ($($n:expr, $m:expr),*) => { + $( + impl CircuitExt for HashCircuit, $n, $m, $m> { + fn instances(&self) -> Vec> { + Vec::new() + } + } + )* + } + } + + impl_circuit_ext!(2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6); + + impl CircuitExt for HashCircuit, 3, 2, 3> { + fn instances(&self) -> Vec> { + Vec::new() + } + } + + #[test] + fn poseidon_hash() { + type Pb = HyperPlonk>>; + macro_rules! test_poseidon_hash { + ($i:expr, $j:expr) => {{ + let message: [Fr; $i] = [Fr::random(OsRng); $i]; + let output = + poseidon::Hash::<_, newParam<$j, $i, 0>, ConstantLength<$i>, $j, $i>::init() + .hash(message); + let circuit = Halo2Circuit::new::( + 6, + HashCircuit::, $j, $i, $i> { + message: Value::known(message), + output: Value::known(output), + _spec: PhantomData, + }, + ); + let param = Pb::setup(&circuit.circuit_info().unwrap(), seeded_std_rng()).unwrap(); + let (pp, vp) = Pb::preprocess(¶m, &circuit.circuit_info().unwrap()).unwrap(); + let proof = { + let mut transcript = Keccak256Transcript::new(()); + Pb::prove(&pp, &circuit, &mut transcript, seeded_std_rng()).unwrap(); + transcript.into_proof() + }; + let result = { + let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + Pb::verify(&vp, circuit.instances(), &mut transcript, seeded_std_rng()) + }; + assert_eq!(result, Ok(())) + }}; + } + for i in 2..7 { + match i { + 1 => test_poseidon_hash!(1, 2), + 2 => test_poseidon_hash!(2, 3), + 3 => test_poseidon_hash!(3, 4), + 4 => test_poseidon_hash!(4, 5), + 5 => test_poseidon_hash!(5, 6), + 6 => test_poseidon_hash!(6, 7), + _ => unreachable!(), + } + } + } + + // This test is ignored because there is an error that should be fixed. The error is on line 353 of this file, in the add_input function. + #[ignore] + #[test] + fn poseidon_hash_longer_input() { + let message = [Fr::random(OsRng), Fr::random(OsRng), Fr::random(OsRng)]; + let output = + poseidon::Hash::<_, newParam<3, 2, 0>, ConstantLength<3>, 3, 2>::init().hash(message); + type Pb = HyperPlonk>>; + let circuit = Halo2Circuit::new::( + 7, + HashCircuit::, 3, 2, 3> { + message: Value::known(message), + output: Value::known(output), + _spec: PhantomData, + }, + ); + let param = Pb::setup(&circuit.circuit_info().unwrap(), seeded_std_rng()).unwrap(); + let (pp, vp) = Pb::preprocess(¶m, &circuit.circuit_info().unwrap()).unwrap(); + let proof = { + let mut transcript = Keccak256Transcript::new(()); + Pb::prove(&pp, &circuit, &mut transcript, seeded_std_rng()).unwrap(); + transcript.into_proof() + }; + let result = { + let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + Pb::verify(&vp, circuit.instances(), &mut transcript, seeded_std_rng()) + }; + assert_eq!(result, Ok(())) + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives.rs b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives.rs new file mode 100644 index 00000000..23d1b92c --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives.rs @@ -0,0 +1,403 @@ +//! The Poseidon algebraic hash function. + +use std::convert::TryInto; +use std::fmt; +use std::iter; +use std::marker::PhantomData; + +use ff::FromUniformBytes; +use ff::PrimeField; +use halo2_proofs::arithmetic::Field; + +pub(crate) mod grain; +pub(crate) mod mds; + +mod bn256param; +pub use bn256param::BN256param; + +use grain::SboxType; + +/// The type used to hold permutation state. +pub(crate) type State = [F; T]; + +/// The type used to hold sponge rate. +pub(crate) type SpongeRate = [Option; RATE]; + +/// The type used to hold the MDS matrix and its inverse. +pub type Mds = [[F; T]; T]; + +/// A specification for a Poseidon permutation. +pub trait Spec: fmt::Debug { + /// The number of full rounds for this specification. + /// + /// This must be an even number. + fn full_rounds() -> usize; + + /// The number of partial rounds for this specification. + fn partial_rounds() -> usize; + + /// The S-box for this specification. + fn sbox(val: F) -> F; + + /// Side-loaded index of the first correct and secure MDS that will be generated by + /// the reference implementation. + /// + /// This is used by the default implementation of [`Spec::constants`]. If you are + /// hard-coding the constants, you may leave this unimplemented. + fn secure_mds() -> usize; + + /// Generates `(round_constants, mds, mds^-1)` corresponding to this specification. + fn constants() -> (Vec<[F; T]>, Mds, Mds); +} + +/// Generates `(round_constants, mds, mds^-1)` corresponding to this specification. +pub fn generate_constants< + F: FromUniformBytes<64> + Ord, + S: Spec, + const T: usize, + const RATE: usize, +>() -> (Vec<[F; T]>, Mds, Mds) { + let r_f = S::full_rounds(); + let r_p = S::partial_rounds(); + + let mut grain = grain::Grain::new(SboxType::Pow, T as u16, r_f as u16, r_p as u16); + + let round_constants = (0..(r_f + r_p)) + .map(|_| { + let mut rc_row = [F::ZERO; T]; + for (rc, value) in rc_row + .iter_mut() + .zip((0..T).map(|_| grain.next_field_element())) + { + *rc = value; + } + rc_row + }) + .collect(); + + let (mds, mds_inv) = mds::generate_mds::(&mut grain, S::secure_mds()); + + (round_constants, mds, mds_inv) +} + +/// Runs the Poseidon permutation on the given state. +pub(crate) fn permute, const T: usize, const RATE: usize>( + state: &mut State, + mds: &Mds, + round_constants: &[[F; T]], +) { + let r_f = S::full_rounds() / 2; + let r_p = S::partial_rounds(); + + let apply_mds = |state: &mut State| { + let mut new_state = [F::ZERO; T]; + // Matrix multiplication + #[allow(clippy::needless_range_loop)] + for i in 0..T { + for j in 0..T { + new_state[i] += mds[i][j] * state[j]; + } + } + *state = new_state; + }; + + let full_round = |state: &mut State, rcs: &[F; T]| { + for (word, rc) in state.iter_mut().zip(rcs.iter()) { + *word = S::sbox(*word + rc); + } + apply_mds(state); + }; + + let part_round = |state: &mut State, rcs: &[F; T]| { + for (word, rc) in state.iter_mut().zip(rcs.iter()) { + *word += rc; + } + // In a partial round, the S-box is only applied to the first state word. + state[0] = S::sbox(state[0]); + apply_mds(state); + }; + + iter::empty() + .chain(iter::repeat(&full_round as &dyn Fn(&mut State, &[F; T])).take(r_f)) + .chain(iter::repeat(&part_round as &dyn Fn(&mut State, &[F; T])).take(r_p)) + .chain(iter::repeat(&full_round as &dyn Fn(&mut State, &[F; T])).take(r_f)) + .zip(round_constants.iter()) + .fold(state, |state, (round, rcs)| { + round(state, rcs); + state + }); +} + +fn poseidon_sponge, const T: usize, const RATE: usize>( + state: &mut State, + input: Option<&Absorbing>, + mds_matrix: &Mds, + round_constants: &[[F; T]], +) -> Squeezing { + if let Some(Absorbing(input)) = input { + // `Iterator::zip` short-circuits when one iterator completes, so this will only + // mutate the rate portion of the state. + for (word, value) in state.iter_mut().zip(input.iter()) { + *word += value.expect("poseidon_sponge is called with a padded input"); + } + } + + permute::(state, mds_matrix, round_constants); + + let mut output = [None; RATE]; + for (word, value) in output.iter_mut().zip(state.iter()) { + *word = Some(*value); + } + Squeezing(output) +} + +mod private { + pub trait SealedSpongeMode {} + impl SealedSpongeMode for super::Absorbing {} + impl SealedSpongeMode for super::Squeezing {} +} + +/// The state of the `Sponge`. +pub trait SpongeMode: private::SealedSpongeMode {} + +/// The absorbing state of the `Sponge`. +#[derive(Debug, Clone)] +pub struct Absorbing(pub(crate) SpongeRate); + +/// The squeezing state of the `Sponge`. +#[derive(Debug)] +pub struct Squeezing(pub(crate) SpongeRate); + +impl SpongeMode for Absorbing {} +impl SpongeMode for Squeezing {} + +impl Absorbing { + pub(crate) fn init_with(val: F) -> Self { + Self( + iter::once(Some(val)) + .chain((1..RATE).map(|_| None)) + .collect::>() + .try_into() + .unwrap(), + ) + } +} + +#[derive(Clone)] +/// A Poseidon sponge. +pub(crate) struct Sponge< + F: Field, + S: Spec, + M: SpongeMode, + const T: usize, + const RATE: usize, +> { + mode: M, + state: State, + mds_matrix: Mds, + round_constants: Vec<[F; T]>, + _marker: PhantomData, +} + +impl, const T: usize, const RATE: usize> + Sponge, T, RATE> +{ + /// Constructs a new sponge for the given Poseidon specification. + pub(crate) fn new(initial_capacity_element: F) -> Self { + let (round_constants, mds_matrix, _) = S::constants(); + + let mode = Absorbing([None; RATE]); + let mut state = [F::ZERO; T]; + state[RATE] = initial_capacity_element; + + Sponge { + mode, + state, + mds_matrix, + round_constants, + _marker: PhantomData, + } + } + + /// Absorbs an element into the sponge. + pub(crate) fn absorb(&mut self, value: F) { + for entry in self.mode.0.iter_mut() { + if entry.is_none() { + *entry = Some(value); + return; + } + } + + // We've already absorbed as many elements as we can + let _ = poseidon_sponge::( + &mut self.state, + Some(&self.mode), + &self.mds_matrix, + &self.round_constants, + ); + self.mode = Absorbing::init_with(value); + } + + /// Transitions the sponge into its squeezing state. + pub(crate) fn finish_absorbing(mut self) -> Sponge, T, RATE> { + let mode = poseidon_sponge::( + &mut self.state, + Some(&self.mode), + &self.mds_matrix, + &self.round_constants, + ); + + Sponge { + mode, + state: self.state, + mds_matrix: self.mds_matrix, + round_constants: self.round_constants, + _marker: PhantomData, + } + } +} + +impl, const T: usize, const RATE: usize> + Sponge, T, RATE> +{ + /// Squeezes an element from the sponge. + pub(crate) fn squeeze(&mut self) -> F { + loop { + for entry in self.mode.0.iter_mut() { + if let Some(e) = entry.take() { + return e; + } + } + + // We've already squeezed out all available elements + self.mode = poseidon_sponge::( + &mut self.state, + None, + &self.mds_matrix, + &self.round_constants, + ); + } + } +} + +/// A domain in which a Poseidon hash function is being used. +pub trait Domain { + /// Iterator that outputs padding field elements. + type Padding: IntoIterator; + + /// The name of this domain, for debug formatting purposes. + fn name() -> String; + + /// The initial capacity element, encoding this domain. + fn initial_capacity_element() -> F; + + /// Returns the padding to be appended to the input. + fn padding(input_len: usize) -> Self::Padding; +} + +/// A Poseidon hash function used with constant input length. +/// +/// Domain specified in [ePrint 2019/458 section 4.2](https://eprint.iacr.org/2019/458.pdf). +#[derive(Clone, Copy, Debug)] +pub struct ConstantLength; + +impl Domain for ConstantLength { + type Padding = iter::Take>; + + fn name() -> String { + format!("ConstantLength<{L}>") + } + + fn initial_capacity_element() -> F { + // Capacity value is $length \cdot 2^64 + (o-1)$ where o is the output length. + // We hard-code an output length of 1. + F::from_u128((L as u128) << 64) + } + + fn padding(input_len: usize) -> Self::Padding { + assert_eq!(input_len, L); + // For constant-input-length hashing, we pad the input with zeroes to a multiple + // of RATE. On its own this would not be sponge-compliant padding, but the + // Poseidon authors encode the constant length into the capacity element, ensuring + // that inputs of different lengths do not share the same permutation. + let k = (L + RATE - 1) / RATE; + iter::repeat(F::ZERO).take(k * RATE - L) + } +} + +#[derive(Clone)] +/// A Poseidon hash function, built around a sponge. +pub struct Hash< + F: Field, + S: Spec, + D: Domain, + const T: usize, + const RATE: usize, +> { + sponge: Sponge, T, RATE>, + _domain: PhantomData, +} + +impl, D: Domain, const T: usize, const RATE: usize> + fmt::Debug for Hash +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Hash") + .field("width", &T) + .field("rate", &RATE) + .field("R_F", &S::full_rounds()) + .field("R_P", &S::partial_rounds()) + .field("domain", &D::name()) + .finish() + } +} + +impl, D: Domain, const T: usize, const RATE: usize> + Hash +{ + /// Initializes a new hasher. + pub fn init() -> Self { + Hash { + sponge: Sponge::new(D::initial_capacity_element()), + _domain: PhantomData, + } + } +} + +impl, const T: usize, const RATE: usize, const L: usize> + Hash, T, RATE> +{ + /// Hashes the given input. + pub fn hash(mut self, message: [F; L]) -> F { + for value in message + .into_iter() + .chain( as Domain>::padding(L)) + { + self.sponge.absorb(value); + } + self.sponge.finish_absorbing().squeeze() + } +} + +#[cfg(test)] +mod tests { + use super::{permute, BN256param, ConstantLength, Hash, Spec}; + use ff::PrimeField; + use halo2_curves::bn256::Fr; + + #[test] + fn bn256_spec_equivalence() { + let message = [Fr::from(6), Fr::from(42)]; + + let (round_constants, mds, _) = BN256param::<3, 2, 0>::constants(); + + let hasher = Hash::<_, BN256param<3, 2, 0>, ConstantLength<2>, 3, 2>::init(); + let result = hasher.hash(message); + + // The result should be equivalent to just directly applying the permutation and + // taking the first state element as the output. + let mut state = [message[0], message[1], Fr::from_u128(2 << 64)]; + permute::<_, BN256param<3, 2, 0>, 3, 2>(&mut state, &mds, &round_constants); + assert_eq!(state[0], result); + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/bn256param.rs b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/bn256param.rs new file mode 100644 index 00000000..cac79fb3 --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/bn256param.rs @@ -0,0 +1,118 @@ +// TO DO: don't want to use two lines below +use super::{Mds, Spec}; +use crate::circuits::poseidongadget::poseidon::primitives::generate_constants; +use halo2_curves::bn256::Fr; +use halo2_proofs::arithmetic::Field; + +// To do rewrite this below +/// Poseidon-128 using the $x^5$ S-box, with a width of 3 field elements, and the +/// standard number of rounds for 128-bit security "with margin". +/// +/// The standard specification for this set of parameters (on either of the Pasta +/// fields) uses $R_F = 8, R_P = 56$. This is conveniently an even number of +/// partial rounds, making it easier to construct a Halo 2 circuit. +#[derive(Debug)] +// Do we have to specify width and rate generically? +pub struct BN256param; + +impl Spec + for BN256param +{ + fn full_rounds() -> usize { + 8 + } + + fn partial_rounds() -> usize { + //TO DO: we need an even number of partial rounds - can we round up + match T { + 2 => 56, + // Rounded up from 57 + 3 => 58, + 4 => 56, + 5 => 60, + 6 => 60, + // Rounded up from 63 + 7 => 64, + _ => unimplemented!(), + } + } + + fn sbox(val: Fr) -> Fr { + val.pow_vartime([5]) + } + + fn secure_mds() -> usize { + SECURE_MDS + } + + fn constants() -> (Vec<[Fr; T]>, Mds, Mds) { + // TO DO: manually generate the constants here + generate_constants::<_, Self, T, R>() + } +} + +// TO DO Remove both + +#[cfg(test)] +mod tests { + #![allow(dead_code)] + use crate::circuits::poseidongadget::poseidon::primitives::{generate_constants, Mds, Spec}; + use ff::{Field, FromUniformBytes}; + use std::marker::PhantomData; + + /// The same Poseidon specification as poseidon::P128Pow5T3, but constructed + /// such that its constants will be generated at runtime. + #[derive(Debug)] + // to do change Field to Fr? + pub struct BN256paramGen( + PhantomData, + ); + + impl + BN256paramGen + { + pub fn new() -> Self { + BN256paramGen(PhantomData) + } + } + + impl< + const T: usize, + const R: usize, + F: FromUniformBytes<64> + Ord, + const SECURE_MDS: usize, + > Spec for BN256paramGen + { + fn full_rounds() -> usize { + 8 + } + + fn partial_rounds() -> usize { + //TO DO: we need an even number of partial rounds - can we round up + match T { + 2 => 56, + // Rounded up from 57 + 3 => 58, + 4 => 56, + 5 => 60, + 6 => 60, + // Rounded up from 63 + 7 => 64, + _ => unimplemented!(), + } + } + + fn sbox(val: F) -> F { + val.pow_vartime([5]) + } + + fn secure_mds() -> usize { + SECURE_MDS + } + + fn constants() -> (Vec<[F; T]>, Mds, Mds) { + // TO DO: manually generate the constants here + generate_constants::<_, Self, T, R>() + } + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/grain.rs b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/grain.rs new file mode 100644 index 00000000..6c0ebf80 --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/grain.rs @@ -0,0 +1,196 @@ +//! The Grain LFSR in self-shrinking mode, as used by Poseidon. + +use std::marker::PhantomData; + +use bitvec::prelude::*; +use ff::{FromUniformBytes, PrimeField}; +use halo2_proofs::arithmetic::Field; + +const STATE: usize = 80; + +#[derive(Debug, Clone, Copy)] +pub(super) enum FieldType { + /// GF(2^n) + #[allow(dead_code)] + Binary, + /// GF(p) + PrimeOrder, +} + +impl FieldType { + fn tag(&self) -> u8 { + match self { + FieldType::Binary => 0, + FieldType::PrimeOrder => 1, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub(super) enum SboxType { + /// x^alpha + Pow, + /// x^(-1) + #[allow(dead_code)] + Inv, +} + +impl SboxType { + fn tag(&self) -> u8 { + match self { + SboxType::Pow => 0, + SboxType::Inv => 1, + } + } +} + +pub(super) struct Grain { + state: BitArr!(for 80, in u8, Msb0), + next_bit: usize, + _field: PhantomData, +} + +impl Grain { + pub(super) fn new(sbox: SboxType, t: u16, r_f: u16, r_p: u16) -> Self { + // Initialize the LFSR state. + let mut state = bitarr![u8, Msb0; 1; STATE]; + let mut set_bits = |offset: usize, len, value| { + // Poseidon reference impl sets initial state bits in MSB order. + for i in 0..len { + *state.get_mut(offset + len - 1 - i).unwrap() = (value >> i) & 1 != 0; + } + }; + set_bits(0, 2, FieldType::PrimeOrder.tag() as u16); + set_bits(2, 4, sbox.tag() as u16); + set_bits(6, 12, F::NUM_BITS as u16); + set_bits(18, 12, t); + set_bits(30, 10, r_f); + set_bits(40, 10, r_p); + + let mut grain = Grain { + state, + next_bit: STATE, + _field: PhantomData, + }; + + // Discard the first 160 bits. + for _ in 0..20 { + grain.load_next_8_bits(); + grain.next_bit = STATE; + } + + grain + } + + fn load_next_8_bits(&mut self) { + let mut new_bits = 0u8; + for i in 0..8 { + new_bits |= ((self.state[i + 62] + ^ self.state[i + 51] + ^ self.state[i + 38] + ^ self.state[i + 23] + ^ self.state[i + 13] + ^ self.state[i]) as u8) + << i; + } + self.state.rotate_left(8); + self.next_bit -= 8; + for i in 0..8 { + *self.state.get_mut(self.next_bit + i).unwrap() = (new_bits >> i) & 1 != 0; + } + } + + fn get_next_bit(&mut self) -> bool { + if self.next_bit == STATE { + self.load_next_8_bits(); + } + let ret = self.state[self.next_bit]; + self.next_bit += 1; + ret + } + + /// Returns the next field element from this Grain instantiation. + pub(super) fn next_field_element(&mut self) -> F { + // Loop until we get an element in the field. + loop { + let mut bytes = F::Repr::default(); + + // Poseidon reference impl interprets the bits as a repr in MSB order, because + // it's easy to do that in Python. Meanwhile, our field elements all use LSB + // order. There's little motivation to diverge from the reference impl; these + // are all constants, so we aren't introducing big-endianness into the rest of + // the circuit (assuming unkeyed Poseidon, but we probably wouldn't want to + // implement Grain inside a circuit, so we'd use a different round constant + // derivation function there). + let view = bytes.as_mut(); + for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() { + // If we diverged from the reference impl and interpreted the bits in LSB + // order, we would remove this line. + let i = F::NUM_BITS as usize - 1 - i; + + view[i / 8] |= if bit { 1 << (i % 8) } else { 0 }; + } + + if let Some(f) = F::from_repr_vartime(bytes) { + break f; + } + } + } +} + +impl> Grain { + /// Returns the next field element from this Grain instantiation, without using + /// rejection sampling. + pub(super) fn next_field_element_without_rejection(&mut self) -> F { + let mut bytes = [0u8; 64]; + + // Poseidon reference impl interprets the bits as a repr in MSB order, because + // it's easy to do that in Python. Additionally, it does not use rejection + // sampling in cases where the constants don't specifically need to be uniformly + // random for security. We do not provide APIs that take a field-element-sized + // array and reduce it modulo the field order, because those are unsafe APIs to + // offer generally (accidentally using them can lead to divergence in consensus + // systems due to not rejecting canonical forms). + // + // Given that we don't want to diverge from the reference implementation, we hack + // around this restriction by serializing the bits into a 64-byte array and then + // calling F::from_bytes_wide. PLEASE DO NOT COPY THIS INTO YOUR OWN CODE! + let view = bytes.as_mut(); + for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() { + // If we diverged from the reference impl and interpreted the bits in LSB + // order, we would remove this line. + let i = F::NUM_BITS as usize - 1 - i; + + view[i / 8] |= if bit { 1 << (i % 8) } else { 0 }; + } + + F::from_uniform_bytes(&bytes) + } +} + +impl Iterator for Grain { + type Item = bool; + + fn next(&mut self) -> Option { + // Evaluate bits in pairs: + // - If the first bit is a 1, output the second bit. + // - If the first bit is a 0, discard the second bit. + while !self.get_next_bit() { + self.get_next_bit(); + } + Some(self.get_next_bit()) + } +} + +#[cfg(test)] +mod tests { + use halo2_curves::pasta::Fp; + + use super::{Grain, SboxType}; + + #[test] + fn grain() { + let mut grain = Grain::::new(SboxType::Pow, 3, 8, 56); + let _f = grain.next_field_element(); + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/mds.rs b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/mds.rs new file mode 100644 index 00000000..2a6f33f5 --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/poseidon/primitives/mds.rs @@ -0,0 +1,124 @@ +use ff::FromUniformBytes; + +use super::{grain::Grain, Mds}; + +pub(super) fn generate_mds + Ord, const T: usize>( + grain: &mut Grain, + mut select: usize, +) -> (Mds, Mds) { + let (xs, ys, mds) = loop { + // Generate two [F; T] arrays of unique field elements. + let (xs, ys) = loop { + let mut vals: Vec<_> = (0..2 * T) + .map(|_| grain.next_field_element_without_rejection()) + .collect(); + + // Check that we have unique field elements. + let mut unique = vals.clone(); + unique.sort_unstable(); + unique.dedup(); + if vals.len() == unique.len() { + let rhs = vals.split_off(T); + break (vals, rhs); + } + }; + + // We need to ensure that the MDS is secure. Instead of checking the MDS against + // the relevant algorithms directly, we witness a fixed number of MDS matrices + // that we need to sample from the given Grain state before obtaining a secure + // matrix. This can be determined out-of-band via the reference implementation in + // Sage. + if select != 0 { + select -= 1; + continue; + } + + // Generate a Cauchy matrix, with elements a_ij in the form: + // a_ij = 1/(x_i + y_j); x_i + y_j != 0 + // + // It would be much easier to use the alternate definition: + // a_ij = 1/(x_i - y_j); x_i - y_j != 0 + // + // These are clearly equivalent on `y <- -y`, but it is easier to work with the + // negative formulation, because ensuring that xs ∪ ys is unique implies that + // x_i - y_j != 0 by construction (whereas the positive case does not hold). It + // also makes computation of the matrix inverse simpler below (the theorem used + // was formulated for the negative definition). + // + // However, the Poseidon paper and reference impl use the positive formulation, + // and we want to rely on the reference impl for MDS security, so we use the same + // formulation. + let mut mds = [[F::ZERO; T]; T]; + #[allow(clippy::needless_range_loop)] + for i in 0..T { + for j in 0..T { + let sum = xs[i] + ys[j]; + // We leverage the secure MDS selection counter to also check this. + assert!(!sum.is_zero_vartime()); + mds[i][j] = sum.invert().unwrap(); + } + } + + break (xs, ys, mds); + }; + + // Compute the inverse. All square Cauchy matrices have a non-zero determinant and + // thus are invertible. The inverse for a Cauchy matrix of the form: + // + // a_ij = 1/(x_i - y_j); x_i - y_j != 0 + // + // has elements b_ij given by: + // + // b_ij = (x_j - y_i) A_j(y_i) B_i(x_j) (Schechter 1959, Theorem 1) + // + // where A_i(x) and B_i(x) are the Lagrange polynomials for xs and ys respectively. + // + // We adapt this to the positive Cauchy formulation by negating ys. + let mut mds_inv = [[F::ZERO; T]; T]; + let l = |xs: &[F], j, x: F| { + let x_j = xs[j]; + xs.iter().enumerate().fold(F::ONE, |acc, (m, x_m)| { + if m == j { + acc + } else { + // We can invert freely; by construction, the elements of xs are distinct. + let diff: F = x_j - *x_m; + acc * (x - x_m) * diff.invert().unwrap() + } + }) + }; + let neg_ys: Vec<_> = ys.iter().map(|y| -*y).collect(); + for i in 0..T { + for j in 0..T { + mds_inv[i][j] = (xs[j] - neg_ys[i]) * l(&xs, j, neg_ys[i]) * l(&neg_ys, i, xs[j]); + } + } + + (mds, mds_inv) +} + +#[cfg(test)] +mod tests { + use halo2_curves::pasta::Fp; + + use super::{generate_mds, Grain}; + + #[test] + fn poseidon_mds() { + const T: usize = 3; + let mut grain = Grain::new(super::super::grain::SboxType::Pow, T as u16, 8, 56); + let (mds, mds_inv) = generate_mds::(&mut grain, 0); + + // Verify that MDS * MDS^-1 = I. + #[allow(clippy::needless_range_loop)] + for i in 0..T { + for j in 0..T { + let expected = if i == j { Fp::one() } else { Fp::zero() }; + assert_eq!( + (0..T).fold(Fp::zero(), |acc, k| acc + (mds[i][k] * mds_inv[k][j])), + expected + ); + } + } + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/utilities.rs b/plonkish_backend/src/circuits/poseidongadget/utilities.rs new file mode 100644 index 00000000..eeeddd86 --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/utilities.rs @@ -0,0 +1,496 @@ +//! Utility gadgets. + +use ff::{Field, PrimeField, PrimeFieldBits}; +use halo2_proofs::{ + circuit::{AssignedCell, Cell, Layouter, Value}, + plonk::{Advice, Column, Error, Expression}, +}; +use std::marker::PhantomData; +use std::ops::Range; + +pub mod cond_swap; + +/// A type that has a value at either keygen or proving time. +pub trait FieldValue { + /// Returns the value of this type. + fn value(&self) -> Value<&F>; +} + +impl FieldValue for Value { + fn value(&self) -> Value<&F> { + self.as_ref() + } +} + +impl FieldValue for AssignedCell { + fn value(&self) -> Value<&F> { + self.value() + } +} + +/// Trait for a variable in the circuit. +pub trait Var: Clone + std::fmt::Debug + From> { + /// The cell at which this variable was allocated. + fn cell(&self) -> Cell; + + /// The value allocated to this variable. + fn value(&self) -> Value; +} + +impl Var for AssignedCell { + fn cell(&self) -> Cell { + self.cell() + } + + fn value(&self) -> Value { + self.value().cloned() + } +} + +/// Trait for utilities used across circuits. +pub trait UtilitiesInstructions { + /// Variable in the circuit. + type Var: Var; + + /// Load a variable. + fn load_private( + &self, + mut layouter: impl Layouter, + column: Column, + value: Value, + ) -> Result { + layouter.assign_region( + || "load private", + |mut region| { + region + .assign_advice(|| "load private", column, 0, || value) + .map(Self::Var::from) + }, + ) + } +} + +/// A type representing a range-constrained field element. +#[derive(Clone, Copy, Debug)] +pub struct RangeConstrained> { + inner: T, + num_bits: usize, + _phantom: PhantomData, +} + +impl> RangeConstrained { + /// Returns the range-constrained inner type. + pub fn inner(&self) -> &T { + &self.inner + } + + /// Returns the number of bits to which this cell is constrained. + pub fn num_bits(&self) -> usize { + self.num_bits + } +} + +impl RangeConstrained> { + /// Constructs a `RangeConstrained>` as a bitrange of the given value. + pub fn bitrange_of(value: Value<&F>, bitrange: Range) -> Self { + let num_bits = bitrange.len(); + Self { + inner: value.map(|value| bitrange_subset(value, bitrange)), + num_bits, + _phantom: PhantomData, + } + } +} + +impl RangeConstrained> { + /// Constructs a `RangeConstrained>` without verifying that the + /// cell is correctly range constrained. + /// + /// This API only exists to ease with integrating this type into existing circuits, + /// and will likely be removed in future. + pub fn unsound_unchecked(cell: AssignedCell, num_bits: usize) -> Self { + Self { + inner: cell, + num_bits, + _phantom: PhantomData, + } + } + + /// Extracts the range-constrained value from this range-constrained cell. + pub fn value(&self) -> RangeConstrained> { + RangeConstrained { + inner: self.inner.value().copied(), + num_bits: self.num_bits, + _phantom: PhantomData, + } + } +} + +/// Checks that an expression is either 1 or 0. +pub fn bool_check(value: Expression) -> Expression { + range_check(value, 2) +} + +/// If `a` then `b`, else `c`. Returns (a * b) + (1 - a) * c. +/// +/// `a` must be a boolean-constrained expression. +pub fn ternary(a: Expression, b: Expression, c: Expression) -> Expression { + let one_minus_a = Expression::Constant(F::ONE) - a.clone(); + a * b + one_minus_a * c +} + +/// Takes a specified subsequence of the little-endian bit representation of a field element. +/// The bits are numbered from 0 for the LSB. +pub fn bitrange_subset(field_elem: &F, bitrange: Range) -> F { + // We can allow a subsequence of length NUM_BITS, because + // field_elem.to_le_bits() returns canonical bitstrings. + assert!(bitrange.end <= F::NUM_BITS as usize); + + field_elem + .to_le_bits() + .iter() + .by_vals() + .skip(bitrange.start) + .take(bitrange.end - bitrange.start) + .rev() + .fold(F::ZERO, |acc, bit| { + if bit { + acc.double() + F::ONE + } else { + acc.double() + } + }) +} + +/// Check that an expression is in the small range [0..range), +/// i.e. 0 ≤ word < range. +pub fn range_check(word: Expression, range: usize) -> Expression { + (1..range).fold(word.clone(), |acc, i| { + acc * (Expression::Constant(F::from(i as u64)) - word.clone()) + }) +} + +/// Decompose a word `alpha` into `window_num_bits` bits (little-endian) +/// For a window size of `w`, this returns [k_0, ..., k_n] where each `k_i` +/// is a `w`-bit value, and `scalar = k_0 + k_1 * w + k_n * w^n`. +/// +/// # Panics +/// +/// We are returning a `Vec` which means the window size is limited to +/// <= 8 bits. +pub fn decompose_word( + word: &F, + word_num_bits: usize, + window_num_bits: usize, +) -> Vec { + assert!(window_num_bits <= 8); + + // Pad bits to multiple of window_num_bits + let padding = (window_num_bits - (word_num_bits % window_num_bits)) % window_num_bits; + let bits: Vec = word + .to_le_bits() + .into_iter() + .take(word_num_bits) + .chain(std::iter::repeat(false).take(padding)) + .collect(); + assert_eq!(bits.len(), word_num_bits + padding); + + bits.chunks_exact(window_num_bits) + .map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8))) + .collect() +} + +/// The u64 integer represented by an L-bit little-endian bitstring. +/// +/// # Panics +/// +/// Panics if the bitstring is longer than 64 bits. +pub fn lebs2ip(bits: &[bool; L]) -> u64 { + assert!(L <= 64); + bits.iter() + .enumerate() + .fold(0u64, |acc, (i, b)| acc + if *b { 1 << i } else { 0 }) +} + +/// The sequence of bits representing a u64 in little-endian order. +/// +/// # Panics +/// +/// Panics if the expected length of the sequence `NUM_BITS` exceeds +/// 64. +pub fn i2lebsp(int: u64) -> [bool; NUM_BITS] { + /// Takes in an FnMut closure and returns a constant-length array with elements of + /// type `Output`. + fn gen_const_array( + closure: impl FnMut(usize) -> Output, + ) -> [Output; LEN] { + let mut ret: [Output; LEN] = [Default::default(); LEN]; + for (bit, val) in ret.iter_mut().zip((0..LEN).map(closure)) { + *bit = val; + } + ret + } + assert!(NUM_BITS <= 64); + gen_const_array(|mask: usize| (int & (1 << mask)) != 0) +} + +#[cfg(test)] +mod tests { + use super::*; + use ff::FromUniformBytes; + use group::ff::{Field, PrimeField}; + use halo2_curves::pasta::pallas; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + dev::{FailureLocation, MockProver, VerifyFailure}, + plonk::{Any, Circuit, ConstraintSystem, Constraints, Error, Selector}, + poly::Rotation, + }; + use proptest::prelude::*; + use rand::rngs::OsRng; + use std::convert::TryInto; + use std::iter; + use uint::construct_uint; + + #[test] + fn test_range_check() { + struct MyCircuit(u8); + + impl UtilitiesInstructions for MyCircuit { + type Var = AssignedCell; + } + + #[derive(Clone)] + struct Config { + selector: Selector, + advice: Column, + } + + impl Circuit for MyCircuit { + type Config = Config; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + MyCircuit(self.0) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let selector = meta.selector(); + let advice = meta.advice_column(); + + meta.create_gate("range check", |meta| { + let selector = meta.query_selector(selector); + let advice = meta.query_advice(advice, Rotation::cur()); + + Constraints::with_selector(selector, Some(range_check(advice, RANGE))) + }); + + Config { selector, advice } + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "range constrain", + |mut region| { + config.selector.enable(&mut region, 0)?; + region.assign_advice( + || format!("witness {}", self.0), + config.advice, + 0, + || Value::known(pallas::Base::from(self.0 as u64)), + )?; + + Ok(()) + }, + ) + } + } + + for i in 0..8 { + let circuit: MyCircuit<8> = MyCircuit(i); + let prover = MockProver::::run::<_, true>(3, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + + { + let circuit: MyCircuit<8> = MyCircuit(8); + let prover = MockProver::::run::<_, true>(3, &circuit, vec![]).unwrap(); + assert_eq!( + prover.verify(), + Err(vec![VerifyFailure::ConstraintNotSatisfied { + constraint: ((0, "range check").into(), 0, "").into(), + location: FailureLocation::InRegion { + region: (0, "range constrain").into(), + offset: 0, + }, + cell_values: vec![(((Any::advice(), 0).into(), 0).into(), "0x8".to_string())], + }]) + ); + } + } + + #[allow(clippy::assign_op_pattern)] + #[allow(clippy::ptr_offset_with_cast)] + #[allow(clippy::single_range_in_vec_init)] + #[test] + fn test_bitrange_subset() { + let rng = OsRng; + + construct_uint! { + struct U256(4); + } + + // Subset full range. + { + let field_elem = pallas::Base::random(rng); + let bitrange = 0..(pallas::Base::NUM_BITS as usize); + let subset = bitrange_subset(&field_elem, bitrange); + assert_eq!(field_elem, subset); + } + + // Subset zero bits + { + let field_elem = pallas::Base::random(rng); + let bitrange = 0..0; + let subset = bitrange_subset(&field_elem, bitrange); + assert_eq!(pallas::Base::zero(), subset); + } + + // Closure to decompose field element into pieces using consecutive ranges, + // and check that we recover the original. + let decompose = |field_elem: pallas::Base, ranges: &[Range]| { + assert_eq!( + ranges.iter().map(|range| range.len()).sum::(), + pallas::Base::NUM_BITS as usize + ); + assert_eq!(ranges[0].start, 0); + assert_eq!(ranges.last().unwrap().end, pallas::Base::NUM_BITS as usize); + + // Check ranges are contiguous + #[allow(unused_assignments)] + { + let mut ranges = ranges.iter(); + let mut range = ranges.next().unwrap(); + if let Some(next_range) = ranges.next() { + assert_eq!(range.end, next_range.start); + range = next_range; + } + } + + let subsets = ranges + .iter() + .map(|range| bitrange_subset(&field_elem, range.clone())) + .collect::>(); + + let mut sum = subsets[0]; + let mut num_bits = 0; + for (idx, subset) in subsets.iter().skip(1).enumerate() { + // 2^num_bits + let range_shift: [u8; 32] = { + num_bits += ranges[idx].len(); + let mut range_shift = [0u8; 32]; + U256([2, 0, 0, 0]) + .pow(U256([num_bits as u64, 0, 0, 0])) + .to_little_endian(&mut range_shift); + range_shift + }; + sum += subset * pallas::Base::from_repr(range_shift).unwrap(); + } + assert_eq!(field_elem, sum); + }; + decompose(pallas::Base::random(rng), &[0..255]); + decompose(pallas::Base::random(rng), &[0..1, 1..255]); + decompose(pallas::Base::random(rng), &[0..254, 254..255]); + decompose(pallas::Base::random(rng), &[0..127, 127..255]); + decompose(pallas::Base::random(rng), &[0..128, 128..255]); + decompose( + pallas::Base::random(rng), + &[0..50, 50..100, 100..150, 150..200, 200..255], + ); + } + + prop_compose! { + fn arb_scalar()(bytes in prop::array::uniform32(0u8..)) -> pallas::Scalar { + // Instead of rejecting out-of-range bytes, let's reduce them. + let mut buf = [0; 64]; + buf[..32].copy_from_slice(&bytes); + pallas::Scalar::from_uniform_bytes(&buf) + } + } + + proptest! { + #[test] + fn test_decompose_word( + scalar in arb_scalar(), + window_num_bits in 1u8..9 + ) { + // Get decomposition into `window_num_bits` bits + let decomposed = decompose_word(&scalar, pallas::Scalar::NUM_BITS as usize, window_num_bits as usize); + + // Flatten bits + let bits = decomposed + .iter() + .flat_map(|window| (0..window_num_bits).map(move |mask| (window & (1 << mask)) != 0)); + + // Ensure this decomposition contains 256 or fewer set bits. + assert!(!bits.clone().skip(32*8).any(|b| b)); + + // Pad or truncate bits to 32 bytes + let bits: Vec = bits.chain(iter::repeat(false)).take(32*8).collect(); + + let bytes: Vec = bits.chunks_exact(8).map(|chunk| chunk.iter().rev().fold(0, |acc, b| (acc << 1) + (*b as u8))).collect(); + + // Check that original scalar is recovered from decomposition + assert_eq!(scalar, pallas::Scalar::from_repr(bytes.try_into().unwrap()).unwrap()); + } + } + + #[test] + fn lebs2ip_round_trip() { + use rand::rngs::OsRng; + + let mut rng = OsRng; + { + let int = rng.next_u64(); + assert_eq!(lebs2ip::<64>(&i2lebsp(int)), int); + } + + assert_eq!(lebs2ip::<64>(&i2lebsp(0)), 0); + assert_eq!( + lebs2ip::<64>(&i2lebsp(0xFFFFFFFFFFFFFFFF)), + 0xFFFFFFFFFFFFFFFF + ); + } + + #[test] + fn i2lebsp_round_trip() { + { + let bitstring = (0..64).map(|_| rand::random()).collect::>(); + assert_eq!( + i2lebsp::<64>(lebs2ip::<64>(&bitstring.clone().try_into().unwrap())).to_vec(), + bitstring + ); + } + + { + let bitstring = [false; 64]; + assert_eq!(i2lebsp(lebs2ip(&bitstring)), bitstring); + } + + { + let bitstring = [true; 64]; + assert_eq!(i2lebsp(lebs2ip(&bitstring)), bitstring); + } + + { + let bitstring = []; + assert_eq!(i2lebsp(lebs2ip(&bitstring)), bitstring); + } + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/utilities/cond_swap.rs b/plonkish_backend/src/circuits/poseidongadget/utilities/cond_swap.rs new file mode 100644 index 00000000..0198bf0c --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/utilities/cond_swap.rs @@ -0,0 +1,289 @@ +//! Gadget and chip for a conditional swap utility. + +use super::{bool_check, ternary, UtilitiesInstructions}; +use ff::{Field, PrimeField}; +use halo2_proofs::{ + circuit::{AssignedCell, Chip, Layouter, Value}, + plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Selector}, + poly::Rotation, +}; +use std::marker::PhantomData; + +/// Instructions for a conditional swap gadget. +pub trait CondSwapInstructions: UtilitiesInstructions { + #[allow(clippy::type_complexity)] + /// Given an input pair (a,b) and a `swap` boolean flag, returns + /// (b,a) if `swap` is set, else (a,b) if `swap` is not set. + /// + /// The second element of the pair is required to be a witnessed + /// value, not a variable that already exists in the circuit. + fn swap( + &self, + layouter: impl Layouter, + pair: (Self::Var, Value), + swap: Value, + ) -> Result<(Self::Var, Self::Var), Error>; +} + +/// A chip implementing a conditional swap. +#[derive(Clone, Debug)] +pub struct CondSwapChip { + config: CondSwapConfig, + _marker: PhantomData, +} + +impl Chip for CondSwapChip { + type Config = CondSwapConfig; + type Loaded = (); + + fn config(&self) -> &Self::Config { + &self.config + } + + fn loaded(&self) -> &Self::Loaded { + &() + } +} + +/// Configuration for the [`CondSwapChip`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CondSwapConfig { + q_swap: Selector, + a: Column, + b: Column, + a_swapped: Column, + b_swapped: Column, + swap: Column, +} + +impl UtilitiesInstructions for CondSwapChip { + type Var = AssignedCell; +} + +impl CondSwapInstructions for CondSwapChip { + #[allow(clippy::type_complexity)] + fn swap( + &self, + mut layouter: impl Layouter, + pair: (Self::Var, Value), + swap: Value, + ) -> Result<(Self::Var, Self::Var), Error> { + let config = self.config(); + + layouter.assign_region( + || "swap", + |mut region| { + // Enable `q_swap` selector + config.q_swap.enable(&mut region, 0)?; + + // Copy in `a` value + let a = pair.0.copy_advice(|| "copy a", &mut region, config.a, 0)?; + + // Witness `b` value + let b = region.assign_advice(|| "witness b", config.b, 0, || pair.1)?; + + // Witness `swap` value + let swap_val = swap.map(|swap| F::from(swap as u64)); + region.assign_advice(|| "swap", config.swap, 0, || swap_val)?; + + // Conditionally swap a + let a_swapped = { + let a_swapped = a + .value() + .zip(b.value()) + .zip(swap) + .map(|((a, b), swap)| if swap { b } else { a }) + .cloned(); + region.assign_advice(|| "a_swapped", config.a_swapped, 0, || a_swapped)? + }; + + // Conditionally swap b + let b_swapped = { + let b_swapped = a + .value() + .zip(b.value()) + .zip(swap) + .map(|((a, b), swap)| if swap { a } else { b }) + .cloned(); + region.assign_advice(|| "b_swapped", config.b_swapped, 0, || b_swapped)? + }; + + // Return swapped pair + Ok((a_swapped, b_swapped)) + }, + ) + } +} + +impl CondSwapChip { + /// Configures this chip for use in a circuit. + /// + /// # Side-effects + /// + /// `advices[0]` will be equality-enabled. + pub fn configure( + meta: &mut ConstraintSystem, + advices: [Column; 5], + ) -> CondSwapConfig { + let a = advices[0]; + // Only column a is used in an equality constraint directly by this chip. + meta.enable_equality(a); + + let q_swap = meta.selector(); + + let config = CondSwapConfig { + q_swap, + a, + b: advices[1], + a_swapped: advices[2], + b_swapped: advices[3], + swap: advices[4], + }; + + // TODO: optimise shape of gate for Merkle path validation + + meta.create_gate("a' = b ⋅ swap + a ⋅ (1-swap)", |meta| { + let q_swap = meta.query_selector(q_swap); + + let a = meta.query_advice(config.a, Rotation::cur()); + let b = meta.query_advice(config.b, Rotation::cur()); + let a_swapped = meta.query_advice(config.a_swapped, Rotation::cur()); + let b_swapped = meta.query_advice(config.b_swapped, Rotation::cur()); + let swap = meta.query_advice(config.swap, Rotation::cur()); + + // This checks that `a_swapped` is equal to `b` when `swap` is set, + // but remains as `a` when `swap` is not set. + let a_check = a_swapped - ternary(swap.clone(), b.clone(), a.clone()); + + // This checks that `b_swapped` is equal to `a` when `swap` is set, + // but remains as `b` when `swap` is not set. + let b_check = b_swapped - ternary(swap.clone(), a, b); + + // Check `swap` is boolean. + let bool_check = bool_check(swap); + + Constraints::with_selector( + q_swap, + [ + ("a check", a_check), + ("b check", b_check), + ("swap is bool", bool_check), + ], + ) + }); + + config + } + + /// Constructs a [`CondSwapChip`] given a [`CondSwapConfig`]. + pub fn construct(config: CondSwapConfig) -> Self { + CondSwapChip { + config, + _marker: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use super::super::UtilitiesInstructions; + use super::{CondSwapChip, CondSwapConfig, CondSwapInstructions}; + use ff::PrimeField; + use group::ff::Field; + use halo2_curves::pasta::pallas::Base; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{Circuit, ConstraintSystem, Error}, + }; + use rand::rngs::OsRng; + + #[test] + fn cond_swap() { + #[derive(Default)] + struct MyCircuit { + a: Value, + b: Value, + swap: Value, + } + + impl Circuit for MyCircuit { + type Config = CondSwapConfig; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let advices = [ + meta.advice_column(), + meta.advice_column(), + meta.advice_column(), + meta.advice_column(), + meta.advice_column(), + ]; + + CondSwapChip::::configure(meta, advices) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let chip = CondSwapChip::::construct(config.clone()); + + // Load the pair and the swap flag into the circuit. + let a = chip.load_private(layouter.namespace(|| "a"), config.a, self.a)?; + // Return the swapped pair. + let swapped_pair = chip.swap( + layouter.namespace(|| "swap"), + (a.clone(), self.b), + self.swap, + )?; + + self.swap + .zip(a.value().zip(self.b.as_ref())) + .zip(swapped_pair.0.value().zip(swapped_pair.1.value())) + .assert_if_known(|((swap, (a, b)), (a_swapped, b_swapped))| { + if *swap { + // Check that `a` and `b` have been swapped + (a_swapped == b) && (b_swapped == a) + } else { + // Check that `a` and `b` have not been swapped + (a_swapped == a) && (b_swapped == b) + } + }); + + Ok(()) + } + } + + let rng = OsRng; + + // Test swap case + { + let circuit: MyCircuit = MyCircuit { + a: Value::known(Base::random(rng)), + b: Value::known(Base::random(rng)), + swap: Value::known(true), + }; + let prover = MockProver::::run::<_, true>(3, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + + // Test non-swap case + { + let circuit: MyCircuit = MyCircuit { + a: Value::known(Base::random(rng)), + b: Value::known(Base::random(rng)), + swap: Value::known(false), + }; + let prover = MockProver::::run::<_, true>(3, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + } +} diff --git a/plonkish_backend/src/circuits/poseidongadget/utilities/rm_lookup_range_check.rs b/plonkish_backend/src/circuits/poseidongadget/utilities/rm_lookup_range_check.rs new file mode 100644 index 00000000..a4a48bff --- /dev/null +++ b/plonkish_backend/src/circuits/poseidongadget/utilities/rm_lookup_range_check.rs @@ -0,0 +1,658 @@ +//! Make use of a K-bit lookup table to decompose a field element into K-bit +//! words. + +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, Region}, + plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Selector, TableColumn}, + poly::Rotation, +}; +use std::{convert::TryInto, marker::PhantomData}; + +use ff::PrimeFieldBits; + +use super::*; + +/// The running sum $[z_0, ..., z_W]$. If created in strict mode, $z_W = 0$. +#[derive(Debug)] +pub struct RunningSum(Vec>); +impl std::ops::Deref for RunningSum { + type Target = Vec>; + + fn deref(&self) -> &Vec> { + &self.0 + } +} + +impl RangeConstrained> { + /// Witnesses a subset of the bits in `value` and constrains them to be the correct + /// number of bits. + /// + /// # Panics + /// + /// Panics if `bitrange.len() >= K`. + pub fn witness_short( + lookup_config: &LookupRangeCheckConfig, + layouter: impl Layouter, + value: Value<&F>, + bitrange: Range, + ) -> Result { + let num_bits = bitrange.len(); + assert!(num_bits < K); + + // Witness the subset and constrain it to be the correct number of bits. + lookup_config + .witness_short_check( + layouter, + value.map(|value| bitrange_subset(value, bitrange)), + num_bits, + ) + .map(|inner| Self { + inner, + num_bits, + _phantom: PhantomData, + }) + } +} + +/// Configuration that provides methods for a lookup range check. +#[derive(Eq, PartialEq, Debug, Clone, Copy)] +pub struct LookupRangeCheckConfig { + q_lookup: Selector, + q_running: Selector, + q_bitshift: Selector, + running_sum: Column, + table_idx: TableColumn, + _marker: PhantomData, +} + +impl LookupRangeCheckConfig { + /// The `running_sum` advice column breaks the field element into `K`-bit + /// words. It is used to construct the input expression to the lookup + /// argument. + /// + /// The `table_idx` fixed column contains values from [0..2^K). Looking up + /// a value in `table_idx` constrains it to be within this range. The table + /// can be loaded outside this helper. + /// + /// # Side-effects + /// + /// Both the `running_sum` and `constants` columns will be equality-enabled. + pub fn configure( + meta: &mut ConstraintSystem, + running_sum: Column, + table_idx: TableColumn, + ) -> Self { + meta.enable_equality(running_sum); + + let q_lookup = meta.complex_selector(); + let q_running = meta.complex_selector(); + let q_bitshift = meta.selector(); + let config = LookupRangeCheckConfig { + q_lookup, + q_running, + q_bitshift, + running_sum, + table_idx, + _marker: PhantomData, + }; + + // https://p.z.cash/halo2-0.1:decompose-combined-lookup + meta.lookup("lookup", |meta| { + let q_lookup = meta.query_selector(config.q_lookup); + let q_running = meta.query_selector(config.q_running); + let z_cur = meta.query_advice(config.running_sum, Rotation::cur()); + + // In the case of a running sum decomposition, we recover the word from + // the difference of the running sums: + // z_i = 2^{K}⋅z_{i + 1} + a_i + // => a_i = z_i - 2^{K}⋅z_{i + 1} + let running_sum_lookup = { + let running_sum_word = { + let z_next = meta.query_advice(config.running_sum, Rotation::next()); + z_cur.clone() - z_next * F::from(1 << K) + }; + + q_running.clone() * running_sum_word + }; + + // In the short range check, the word is directly witnessed. + let short_lookup = { + let short_word = z_cur; + let q_short = Expression::Constant(F::ONE) - q_running; + + q_short * short_word + }; + + // Combine the running sum and short lookups: + vec![( + q_lookup * (running_sum_lookup + short_lookup), + config.table_idx, + )] + }); + + // For short lookups, check that the word has been shifted by the correct number of bits. + // https://p.z.cash/halo2-0.1:decompose-short-lookup + meta.create_gate("Short lookup bitshift", |meta| { + let q_bitshift = meta.query_selector(config.q_bitshift); + let word = meta.query_advice(config.running_sum, Rotation::prev()); + let shifted_word = meta.query_advice(config.running_sum, Rotation::cur()); + let inv_two_pow_s = meta.query_advice(config.running_sum, Rotation::next()); + + let two_pow_k = F::from(1 << K); + + // shifted_word = word * 2^{K-s} + // = word * 2^K * inv_two_pow_s + Constraints::with_selector( + q_bitshift, + Some(word * two_pow_k * inv_two_pow_s - shifted_word), + ) + }); + + config + } + + #[cfg(test)] + // Loads the values [0..2^K) into `table_idx`. This is only used in testing + // for now, since the Sinsemilla chip provides a pre-loaded table in the + // Orchard context. + pub fn load(&self, layouter: &mut impl Layouter) -> Result<(), Error> { + layouter.assign_table( + || "table_idx", + |mut table| { + // We generate the row values lazily (we only need them during keygen). + for index in 0..(1 << K) { + table.assign_cell( + || "table_idx", + self.table_idx, + index, + || Value::known(F::from(index as u64)), + )?; + } + Ok(()) + }, + ) + } + + /// Range check on an existing cell that is copied into this helper. + /// + /// Returns an error if `element` is not in a column that was passed to + /// [`ConstraintSystem::enable_equality`] during circuit configuration. + pub fn copy_check( + &self, + mut layouter: impl Layouter, + element: AssignedCell, + num_words: usize, + strict: bool, + ) -> Result, Error> { + layouter.assign_region( + || format!("{:?} words range check", num_words), + |mut region| { + // Copy `element` and initialize running sum `z_0 = element` to decompose it. + let z_0 = element.copy_advice(|| "z_0", &mut region, self.running_sum, 0)?; + self.range_check(&mut region, z_0, num_words, strict) + }, + ) + } + + /// Range check on a value that is witnessed in this helper. + pub fn witness_check( + &self, + mut layouter: impl Layouter, + value: Value, + num_words: usize, + strict: bool, + ) -> Result, Error> { + layouter.assign_region( + || "Witness element", + |mut region| { + let z_0 = + region.assign_advice(|| "Witness element", self.running_sum, 0, || value)?; + self.range_check(&mut region, z_0, num_words, strict) + }, + ) + } + + /// If `strict` is set to "true", the field element must fit into + /// `num_words * K` bits. In other words, the final cumulative sum `z_{num_words}` + /// must be zero. + /// + /// If `strict` is set to "false", the final `z_{num_words}` is not constrained. + /// + /// `element` must have been assigned to `self.running_sum` at offset 0. + fn range_check( + &self, + region: &mut Region<'_, F>, + element: AssignedCell, + num_words: usize, + strict: bool, + ) -> Result, Error> { + // `num_words` must fit into a single field element. + assert!(num_words * K <= F::CAPACITY as usize); + let num_bits = num_words * K; + + // Chunk the first num_bits bits into K-bit words. + let words = { + // Take first num_bits bits of `element`. + let bits = element.value().map(|element| { + element + .to_le_bits() + .into_iter() + .take(num_bits) + .collect::>() + }); + + bits.map(|bits| { + bits.chunks_exact(K) + .map(|word| F::from(lebs2ip::(&(word.try_into().unwrap())))) + .collect::>() + }) + .transpose_vec(num_words) + }; + + let mut zs = vec![element.clone()]; + + // Assign cumulative sum such that + // z_i = 2^{K}⋅z_{i + 1} + a_i + // => z_{i + 1} = (z_i - a_i) / (2^K) + // + // For `element` = a_0 + 2^10 a_1 + ... + 2^{120} a_{12}}, initialize z_0 = `element`. + // If `element` fits in 130 bits, we end up with z_{13} = 0. + let mut z = element; + let inv_two_pow_k = F::from(1u64 << K).invert().unwrap(); + for (idx, word) in words.iter().enumerate() { + // Enable q_lookup on this row + self.q_lookup.enable(region, idx)?; + // Enable q_running on this row + self.q_running.enable(region, idx)?; + + // z_next = (z_cur - m_cur) / 2^K + z = { + let z_val = z + .value() + .zip(*word) + .map(|(z, word)| (*z - word) * inv_two_pow_k); + + // Assign z_next + region.assign_advice( + || format!("z_{:?}", idx + 1), + self.running_sum, + idx + 1, + || z_val, + )? + }; + zs.push(z.clone()); + } + + if strict { + // Constrain the final `z` to be zero. + region.constrain_constant(zs.last().unwrap().cell(), F::ZERO)?; + } + + Ok(RunningSum(zs)) + } + + /// Short range check on an existing cell that is copied into this helper. + /// + /// # Panics + /// + /// Panics if NUM_BITS is equal to or larger than K. + pub fn copy_short_check( + &self, + mut layouter: impl Layouter, + element: AssignedCell, + num_bits: usize, + ) -> Result<(), Error> { + assert!(num_bits < K); + layouter.assign_region( + || format!("Range check {:?} bits", num_bits), + |mut region| { + // Copy `element` to use in the k-bit lookup. + let element = + element.copy_advice(|| "element", &mut region, self.running_sum, 0)?; + + self.short_range_check(&mut region, element, num_bits) + }, + ) + } + + /// Short range check on value that is witnessed in this helper. + /// + /// # Panics + /// + /// Panics if num_bits is larger than K. + pub fn witness_short_check( + &self, + mut layouter: impl Layouter, + element: Value, + num_bits: usize, + ) -> Result, Error> { + assert!(num_bits <= K); + layouter.assign_region( + || format!("Range check {:?} bits", num_bits), + |mut region| { + // Witness `element` to use in the k-bit lookup. + let element = + region.assign_advice(|| "Witness element", self.running_sum, 0, || element)?; + + self.short_range_check(&mut region, element.clone(), num_bits)?; + + Ok(element) + }, + ) + } + + /// Constrain `x` to be a NUM_BITS word. + /// + /// `element` must have been assigned to `self.running_sum` at offset 0. + fn short_range_check( + &self, + region: &mut Region<'_, F>, + element: AssignedCell, + num_bits: usize, + ) -> Result<(), Error> { + // Enable lookup for `element`, to constrain it to 10 bits. + self.q_lookup.enable(region, 0)?; + + // Enable lookup for shifted element, to constrain it to 10 bits. + self.q_lookup.enable(region, 1)?; + + // Check element has been shifted by the correct number of bits. + self.q_bitshift.enable(region, 1)?; + + // Assign shifted `element * 2^{K - num_bits}` + let shifted = element.value().into_field() * F::from(1 << (K - num_bits)); + + region.assign_advice( + || format!("element * 2^({}-{})", K, num_bits), + self.running_sum, + 1, + || shifted, + )?; + + // Assign 2^{-num_bits} from a fixed column. + let inv_two_pow_s = F::from(1 << num_bits).invert().unwrap(); + region.assign_advice_from_constant( + || format!("2^(-{})", num_bits), + self.running_sum, + 2, + inv_two_pow_s, + )?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::LookupRangeCheckConfig; + + use super::super::lebs2ip; + use crate::sinsemilla::primitives::K; + + use ff::{Field, PrimeFieldBits}; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::{FailureLocation, MockProver, VerifyFailure}, + plonk::{Circuit, ConstraintSystem, Error}, + }; + use halo2_curves::pasta::pallas; + + use std::{convert::TryInto, marker::PhantomData}; + + #[test] + fn lookup_range_check() { + #[derive(Clone, Copy)] + struct MyCircuit { + num_words: usize, + _marker: PhantomData, + } + + impl Circuit for MyCircuit { + type Config = LookupRangeCheckConfig; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + *self + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let running_sum = meta.advice_column(); + let table_idx = meta.lookup_table_column(); + let constants = meta.fixed_column(); + meta.enable_constant(constants); + + LookupRangeCheckConfig::::configure(meta, running_sum, table_idx) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // Load table_idx + config.load(&mut layouter)?; + + // Lookup constraining element to be no longer than num_words * K bits. + let elements_and_expected_final_zs = [ + (F::from((1 << (self.num_words * K)) - 1), F::ZERO, true), // a word that is within self.num_words * K bits long + (F::from(1 << (self.num_words * K)), F::ONE, false), // a word that is just over self.num_words * K bits long + ]; + + fn expected_zs( + element: F, + num_words: usize, + ) -> Vec { + let chunks = { + element + .to_le_bits() + .iter() + .by_vals() + .take(num_words * K) + .collect::>() + .chunks_exact(K) + .map(|chunk| F::from(lebs2ip::(chunk.try_into().unwrap()))) + .collect::>() + }; + let expected_zs = { + let inv_two_pow_k = F::from(1 << K).invert().unwrap(); + chunks.iter().fold(vec![element], |mut zs, a_i| { + // z_{i + 1} = (z_i - a_i) / 2^{K} + let z = (zs[zs.len() - 1] - a_i) * inv_two_pow_k; + zs.push(z); + zs + }) + }; + expected_zs + } + + for (element, expected_final_z, strict) in elements_and_expected_final_zs.iter() { + let expected_zs = expected_zs::(*element, self.num_words); + + let zs = config.witness_check( + layouter.namespace(|| format!("Lookup {:?}", self.num_words)), + Value::known(*element), + self.num_words, + *strict, + )?; + + assert_eq!(*expected_zs.last().unwrap(), *expected_final_z); + + for (expected_z, z) in expected_zs.into_iter().zip(zs.iter()) { + z.value().assert_if_known(|z| &&expected_z == z); + } + } + Ok(()) + } + } + + { + let circuit: MyCircuit = MyCircuit { + num_words: 6, + _marker: PhantomData, + }; + + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + } + + #[test] + fn short_range_check() { + struct MyCircuit { + element: Value, + num_bits: usize, + } + + impl Circuit for MyCircuit { + type Config = LookupRangeCheckConfig; + type FloorPlanner = SimpleFloorPlanner; + #[cfg(feature = "circuit-params")] + type Params = (); + + fn without_witnesses(&self) -> Self { + MyCircuit { + element: Value::unknown(), + num_bits: self.num_bits, + } + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let running_sum = meta.advice_column(); + let table_idx = meta.lookup_table_column(); + let constants = meta.fixed_column(); + meta.enable_constant(constants); + + LookupRangeCheckConfig::::configure(meta, running_sum, table_idx) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // Load table_idx + config.load(&mut layouter)?; + + // Lookup constraining element to be no longer than num_bits. + config.witness_short_check( + layouter.namespace(|| format!("Lookup {:?} bits", self.num_bits)), + self.element, + self.num_bits, + )?; + + Ok(()) + } + } + + // Edge case: zero bits + { + let circuit: MyCircuit = MyCircuit { + element: Value::known(pallas::Base::zero()), + num_bits: 0, + }; + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + + // Edge case: K bits + { + let circuit: MyCircuit = MyCircuit { + element: Value::known(pallas::Base::from((1 << K) - 1)), + num_bits: K, + }; + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + + // Element within `num_bits` + { + let circuit: MyCircuit = MyCircuit { + element: Value::known(pallas::Base::from((1 << 6) - 1)), + num_bits: 6, + }; + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!(prover.verify(), Ok(())); + } + + // Element larger than `num_bits` but within K bits + { + let circuit: MyCircuit = MyCircuit { + element: Value::known(pallas::Base::from(1 << 6)), + num_bits: 6, + }; + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!( + prover.verify(), + Err(vec![VerifyFailure::Lookup { + name: "lookup".to_string(), + lookup_index: 0, + location: FailureLocation::InRegion { + region: (1, "Range check 6 bits").into(), + offset: 1, + }, + }]), + ); + } + + // Element larger than K bits + { + let circuit: MyCircuit = MyCircuit { + element: Value::known(pallas::Base::from(1 << K)), + num_bits: 6, + }; + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!( + prover.verify(), + Err(vec![ + VerifyFailure::Lookup { + name: "lookup".to_string(), + lookup_index: 0, + location: FailureLocation::InRegion { + region: (1, "Range check 6 bits").into(), + offset: 0, + }, + }, + VerifyFailure::Lookup { + name: "lookup".to_string(), + lookup_index: 0, + location: FailureLocation::InRegion { + region: (1, "Range check 6 bits").into(), + offset: 1, + }, + }, + ]) + ); + } + + // Element which is not within `num_bits`, but which has a shifted value within + // num_bits + { + let num_bits = 6; + let shifted = pallas::Base::from((1 << num_bits) - 1); + // Recall that shifted = element * 2^{K-s} + // => element = shifted * 2^{s-K} + let element = shifted + * pallas::Base::from(1 << (K as u64 - num_bits)) + .invert() + .unwrap(); + let circuit: MyCircuit = MyCircuit { + element: Value::known(element), + num_bits: num_bits as usize, + }; + let prover = MockProver::::run(11, &circuit, vec![]).unwrap(); + assert_eq!( + prover.verify(), + Err(vec![VerifyFailure::Lookup { + name: "lookup".to_string(), + lookup_index: 0, + location: FailureLocation::InRegion { + region: (1, "Range check 6 bits").into(), + offset: 0, + }, + }]) + ); + } + } +} diff --git a/plonkish_backend/src/lib.rs b/plonkish_backend/src/lib.rs index 1cd59eec..ede727b2 100644 --- a/plonkish_backend/src/lib.rs +++ b/plonkish_backend/src/lib.rs @@ -2,6 +2,7 @@ pub mod accumulation; pub mod backend; +pub mod circuits; pub mod frontend; pub mod pcs; pub mod piop; diff --git a/plonkish_backend/src/pcs/multilinear/zeromorph.rs b/plonkish_backend/src/pcs/multilinear/zeromorph.rs index 32713e3b..929993b0 100644 --- a/plonkish_backend/src/pcs/multilinear/zeromorph.rs +++ b/plonkish_backend/src/pcs/multilinear/zeromorph.rs @@ -381,7 +381,6 @@ where let comm = if cfg!(feature = "sanity-check") { assert_eq!(f.evaluate(&x), C::Scalar::ZERO); - UnivariateIpa::::commit(pp, &f)? } else { Default::default() diff --git a/plonkish_backend/src/piop/sum_check/classic/eval.rs b/plonkish_backend/src/piop/sum_check/classic/eval.rs index 429c16a1..62bcfb39 100644 --- a/plonkish_backend/src/piop/sum_check/classic/eval.rs +++ b/plonkish_backend/src/piop/sum_check/classic/eval.rs @@ -24,6 +24,7 @@ impl Evaluations { Self(vec![F::ZERO; degree + 1]) } + // All points between 0 and degree +1 fn points(degree: usize) -> Vec { steps(F::ZERO).take(degree + 1).collect() } diff --git a/plonkish_backend/src/poly/multilinear.rs b/plonkish_backend/src/poly/multilinear.rs index c673eb66..efd9f67c 100644 --- a/plonkish_backend/src/poly/multilinear.rs +++ b/plonkish_backend/src/poly/multilinear.rs @@ -195,6 +195,7 @@ impl MultilinearPolynomial { Self::new(output) } + // New polynomial with first variable fixed to x_i, i.e. f(X_i, .., X_n) becomes f(x_i, X_{i+1}, .., X_n). pub fn fix_var_in_place(&mut self, x_i: &F, buf: &mut Self) { merge_into(&mut buf.evals, self.evals(), x_i, 1, 0); buf.num_vars = self.num_vars - 1; diff --git a/plonkish_backend/src/util/arithmetic/msm.rs b/plonkish_backend/src/util/arithmetic/msm.rs index 3ca3d2a2..4764863f 100644 --- a/plonkish_backend/src/util/arithmetic/msm.rs +++ b/plonkish_backend/src/util/arithmetic/msm.rs @@ -157,7 +157,6 @@ fn variable_base_msm_serial( } } } - let scalars = scalars.iter().map(|scalar| scalar.to_repr()).collect_vec(); let num_bytes = scalars[0].as_ref().len(); let num_bits = 8 * num_bytes;