diff --git a/ff_derive/src/lib.rs b/ff_derive/src/lib.rs index 0a0b08d..17a0f3b 100644 --- a/ff_derive/src/lib.rs +++ b/ff_derive/src/lib.rs @@ -164,7 +164,7 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut gen = proc_macro2::TokenStream::new(); - let (constants_impl, sqrt_impl) = + let (constants_impl, sqrt_impl, sqrt_ratio_impl) = prime_field_constants_and_sqrt(&ast.ident, &modulus, limbs, generator); gen.extend(constants_impl); @@ -176,6 +176,7 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { &endianness, limbs, sqrt_impl, + sqrt_ratio_impl, )); // Return the generated impl @@ -462,12 +463,13 @@ fn test_exp() { ); } + fn prime_field_constants_and_sqrt( name: &syn::Ident, modulus: &BigUint, limbs: usize, generator: BigUint, -) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { +) -> (proc_macro2::TokenStream, proc_macro2::TokenStream, proc_macro2::TokenStream) { let bytes = limbs * 8; let modulus_num_bits = biguint_num_bits(modulus.clone()); @@ -498,63 +500,25 @@ fn prime_field_constants_and_sqrt( // Compute 2^s root of unity given the generator let root_of_unity = exp(generator.clone(), &t, &modulus); - let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs); - let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs); - let delta = biguint_to_u64_vec( - to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)), - limbs, - ); - let generator = biguint_to_u64_vec(to_mont(generator), limbs); - - let sqrt_impl = - if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { - // Addition chain for (r + 1) // 4 - let mod_plus_1_over_4 = pow_fixed::generate( - "e! {self}, - (modulus + BigUint::from_str("1").unwrap()) >> 2, - ); + let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs); - quote! { - use ::ff::derive::subtle::ConstantTimeEq; - - // Because r = 3 (mod 4) - // sqrt can be done with only one exponentiation, - // via the computation of self^((r + 1) // 4) (mod r) - let sqrt = { - #mod_plus_1_over_4 - }; - - ::ff::derive::subtle::CtOption::new( - sqrt, - (sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root. - ) - } - } else { - // Addition chain for (t - 1) // 2 - let t_minus_1_over_2 = if t == BigUint::one() { - quote!( #name::ONE ) - } else { - pow_fixed::generate("e! {self}, (&t - BigUint::one()) >> 1) - }; - - quote! { - // Tonelli-Shanks algorithm works for every remaining odd prime. - // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) + // Tonelli shanks logic starts here + fn generate_tonelli_shanks_loop(name: &syn::Ident) -> proc_macro2::TokenStream { + quote! { + /// The loop takes in x and the `projenator` w = x^((t - 1)/2). The function + /// the modifies x, and the final value for x is sqrt(x) (iff x is a QR). + fn tonelli_shanks_loop(x: &mut #name, w: &#name) { use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq}; - - // w = self^((t - 1) // 2) - let w = { - #t_minus_1_over_2 - }; - - let mut v = S; - let mut x = *self * &w; - let mut b = x * &w; + use ::ff::PrimeField; + + let mut v = #name::S; + *x *= w; + let mut b = *x * w; // Initialize z as the 2^S root of unity. - let mut z = ROOT_OF_UNITY; + let mut z = #name::ROOT_OF_UNITY; - for max_v in (1..=S).rev() { + for max_v in (1..=#name::S).rev() { let mut k = 1; let mut tmp = b.square(); let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into(); @@ -569,20 +533,177 @@ fn prime_field_constants_and_sqrt( z = #name::conditional_select(&z, &new_z, j_less_than_v); } - let result = x * &z; - x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE)); + let result = *x * &z; + *x = #name::conditional_select(&result, x, b.ct_eq(&#name::ONE)); z = z.square(); b *= &z; v = k; } - - ::ff::derive::subtle::CtOption::new( - x, - (x * &x).ct_eq(self), // Only return Some if it's the square root. - ) } + } + } + + // Recall p - 1 = 2^s * t + // Addition chain for (t - 1) // 2 + + let t_minus_1_over_2 = if t == BigUint::one() { + quote!( #name::ONE ) + } else { + pow_fixed::generate("e! {x}, (&t - BigUint::one()) >> 1) + }; + + // Tonelli--Shanks inner loop + let tonelli_shanks_loop = generate_tonelli_shanks_loop(&name); + + let sqrt_impl = if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { + // Addition chain for (r + 1) // 4 + let mod_plus_1_over_4 = pow_fixed::generate( + "e! {self}, + (modulus + BigUint::from_str("1").unwrap()) >> 2, + ); + + quote! { + use ::ff::derive::subtle::ConstantTimeEq; + + // Because r = 3 (mod 4) + // sqrt can be done with only one exponentiation, + // via the computation of self^((r + 1) // 4) (mod r) + let sqrt = { + #mod_plus_1_over_4 + }; + + ::ff::derive::subtle::CtOption::new( + sqrt, + (sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root. + ) + } + } else { + quote! { + // Remark: The Tonelli-Shanks algorithm works for every odd prime. + // However, leave the above 3 mod 4. + // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) + use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq}; + #tonelli_shanks_loop + + // w = self^((t - 1) // 2) + let mut x = *self; + + let w = { + #t_minus_1_over_2 + }; + + tonelli_shanks_loop(&mut x, &w); + + ::ff::derive::subtle::CtOption::new( + x, + (x * &x).ct_eq(self), // Only return Some if it's the square root. + ) + } + }; + + // Generates an implimentation of sqrt(num/div) using the merged inverse-and-sqrt + // version of the Tonelli--Shanks algorithm combining Scott's `Tricks of the trade` paper + // Section 2 and 2.1 (see https://eprint.iacr.org/2020/1497) + // This is a more general version of p.15 of https://eprint.iacr.org/2011/368.pdf + let sqrt_ratio_impl = { + // setup some compile-time constants + let zeta_2 = biguint_to_u64_vec( + to_mont(exp(root_of_unity.clone(), &BigUint::from_str("2").unwrap(), &modulus)), + limbs, + ); + let tw = exp(root_of_unity.clone(), &BigUint::from_str("3").unwrap(), &modulus); + let tw_inv = biguint_to_u64_vec( + to_mont(invert(tw.clone())), + limbs + ); + let tw_proj = biguint_to_u64_vec( + to_mont(exp(tw.clone(), &((&t - BigUint::one()) >> 1), &modulus)), + limbs, + ); + let tw = biguint_to_u64_vec( + to_mont(tw), + limbs + ); + + // Fixed exponentiation (x^2*w^4)^(2^(s-1) - 1) + let two_s_m1 = BigUint::one() << (&s - 1); + let two_s_minus_1_m1 = if two_s_m1 == BigUint::one() { + quote!( #name::ONE ) + } else { + pow_fixed::generate("e! {x2w4}, &two_s_m1 - BigUint::one()) }; + + quote! { + use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq, CtOption}; + use ::ff::PrimeField; + #tonelli_shanks_loop + + const Z2: #name = #name(#zeta_2); // #name::ROOT_OF_UNITY^2 + const TW: #name = #name(#tw); // #name::ROOT_OF_UNITY^3 + const TW_INV: #name = #name(#tw_inv); // TW^(-1) + const TW_PROJ: #name = #name(#tw_proj); // projenator of twist TW^((t - 1)/2) + + let num_is_zero = num.is_zero(); + let div_is_zero = div.is_zero(); + + let mut x = num.cube() * div; + let mut sqrtx = x.clone(); + + let mut tw_x = x.clone(); + tw_x *= TW; + let mut tw_sqrtx = tw_x.clone(); + + let w = { + #t_minus_1_over_2 + }; + let tw_w = TW_PROJ * &w; + + tonelli_shanks_loop(&mut sqrtx, &w); //sqrtx = sqrt(x) now + tonelli_shanks_loop(&mut tw_sqrtx, &tw_w); //tw_sqrtx = sqrt(tw_x) now + // Remark: One can avoid a second call to the loop when p = 3 (4) or p = 5 (8) + // since then tw_sqrtx = tw * sqrtx (cf. p15 of https://eprint.iacr.org/2011/368.pdf) + + // x <- x^(-1) = (x^2 * w^4)^(2^(s-1) - 1) * (x * w^4) + let xw4 = w.square().square() * &x; + let mut x2w4 = x * &xw4; + x2w4 = { + #two_s_minus_1_m1 + }; + x = x2w4 * xw4; + + // tx_x <- tw_x^(-1) = tw^(-1) * x^(-1) + let mut tw_x = TW_INV * &x; + + // x = sqrt(num/div) + let n2 = num.square(); + x = x * sqrtx * &n2; + + // tw_x = sqrt(zeta * num / div) + tw_x = Z2 * tw_x * tw_sqrtx * n2; + + let tw_num = #name::ROOT_OF_UNITY * num; + + let is_square = (x.square() * div).ct_eq(num); + let is_nonsquare = (tw_x.square() * div).ct_eq(&tw_num); + + assert!(bool::from( + num_is_zero | div_is_zero | (is_square ^ is_nonsquare) + )); + ( + is_square & (num_is_zero | !div_is_zero), + #name::conditional_select(&tw_x, &x, is_square), + ) + } + }; + // Some more constants + let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs); + let delta = biguint_to_u64_vec( + to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)), + limbs, + ); + let generator = biguint_to_u64_vec(to_mont(generator), limbs); + // Compute R^2 mod m let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs); @@ -649,6 +770,7 @@ fn prime_field_constants_and_sqrt( const DELTA: #name = #name(#delta); }, sqrt_impl, + sqrt_ratio_impl, ) } @@ -660,6 +782,7 @@ fn prime_field_impl( endianness: &ReprEndianness, limbs: usize, sqrt_impl: proc_macro2::TokenStream, + sqrt_ratio_impl: proc_macro2::TokenStream, ) -> proc_macro2::TokenStream { // Returns r{n} as an ident. fn get_temp(n: usize) -> syn::Ident { @@ -880,6 +1003,7 @@ fn prime_field_impl( } } + let squaring_impl = sqr_impl(quote! {self}, limbs); let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs); let invert_impl = inv_impl(quote! {self}, modulus); @@ -1317,7 +1441,7 @@ fn prime_field_impl( } fn sqrt_ratio(num: &Self, div: &Self) -> (::ff::derive::subtle::Choice, Self) { - ::ff::helpers::sqrt_ratio_generic(num, div) + #sqrt_ratio_impl } fn sqrt(&self) -> ::ff::derive::subtle::CtOption { diff --git a/tests/derive.rs b/tests/derive.rs index 5baf435..1f94a7e 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -156,3 +156,41 @@ fn sqrt() { use rand::rngs::OsRng; test(Fp::random(OsRng)); } + +#[test] +fn sqrt_ratio_test() { + use ff::{Field, PrimeField}; + + #[derive(PrimeField)] + #[PrimeFieldModulus = "357686312646216567629137"] + #[PrimeFieldGenerator = "5"] + #[PrimeFieldReprEndianness = "little"] + struct Fp([u64; 2]); + + fn test(num: Fp, div: Fp) { + let (choice, sqrt) = Fp::sqrt_ratio(&num, &div); + + if bool::from(choice) { + assert!(div != Fp::ZERO); + let div_inv = div.invert().unwrap(); + let expected = num * div_inv; + assert_eq!(sqrt.square(), expected); + } else if div != Fp::ZERO { + let div_inv = div.invert().unwrap(); + let expected = Fp::ROOT_OF_UNITY * num * div_inv; + assert_eq!(sqrt.square(), expected); + } else { + assert_eq!(sqrt.square(), Fp::ZERO); + } + } + + // Easy cases + test(Fp::ZERO, Fp::ONE); // sqrt(0/1) = (true, 0) + test(Fp::ONE, Fp::ZERO); // sqrt(1/0) = (false, 0) + + // Random case + use rand::rngs::OsRng; + let a = Fp::random(&mut OsRng); + let b = Fp::random(&mut OsRng); + test(a, b); +}