From b67964f4a070b15928701c46fe4d499094a4f5d3 Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Mon, 18 Aug 2025 17:18:17 +0200 Subject: [PATCH] feat(zk): add ZeroizeZp type that is automatically zeroized on drop --- tfhe-zk-pok/src/curve_api/bls12_446.rs | 209 +++++++++++++++++++++++-- 1 file changed, 193 insertions(+), 16 deletions(-) diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index 0b6ffcac7..1ca99f78b 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -174,6 +174,12 @@ mod g1 { } } + pub fn mul_scalar_zeroize(self, scalar: &ZeroizeZp) -> Self { + Self { + inner: scalar.mul_point(self.inner), + } + } + #[track_caller] pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { use rayon::prelude::*; @@ -534,6 +540,12 @@ mod g2 { } } + pub fn mul_scalar_zeroize(self, scalar: &ZeroizeZp) -> Self { + Self { + inner: scalar.mul_point(self.inner), + } + } + pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self { use rayon::prelude::*; let n_threads = rayon::current_num_threads(); @@ -945,12 +957,13 @@ mod gt { mod zp { use super::*; + use crate::curve_446::FrConfig; use crate::serialization::InvalidArraySizeError; - use ark_ff::Fp; + use ark_ff::{Fp, FpConfig, MontBackend, PrimeField}; use tfhe_versionable::Versionize; - use zeroize::Zeroize; + use zeroize::{Zeroize, ZeroizeOnDrop}; - fn redc(n: [u64; 5], nprime: u64, mut t: [u64; 7]) -> [u64; 5] { + fn redc(n: [u64; 5], nprime: u64, t: &mut [u64; 7], out: &mut [u64; 5]) { for i in 0..2 { let mut c = 0u64; let m = u64::wrapping_mul(t[i], nprime); @@ -968,20 +981,22 @@ mod zp { } } - let mut t = [t[2], t[3], t[4], t[5], t[6]]; + out[0] = t[2]; + out[1] = t[3]; + out[2] = t[4]; + out[3] = t[5]; + out[4] = t[6]; - if t.into_iter().rev().ge(n.into_iter().rev()) { + if out.iter().rev().ge(n.iter().rev()) { let mut o = false; for i in 0..5 { - let (ti, o0) = u64::overflowing_sub(t[i], n[i]); + let (ti, o0) = u64::overflowing_sub(out[i], n[i]); let (ti, o1) = u64::overflowing_sub(ti, o as u64); o = o0 | o1; - t[i] = ti; + out[i] = ti; } } - assert!(t.into_iter().rev().lt(n.into_iter().rev())); - - t + assert!(out.iter().rev().lt(n.iter().rev())); } #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)] @@ -1066,11 +1081,10 @@ mod zp { let mut n = n; // zero the 22 leading bits, so the result is <= MODULUS * 2^128 n[6] &= (1 << 42) - 1; + let mut res = [0; 5]; + redc(MODULUS.0, MODULUS_MONTGOMERY, &mut n, &mut res); Zp { - inner: Fp( - BigInt(redc(MODULUS.0, MODULUS_MONTGOMERY, n)), - core::marker::PhantomData, - ), + inner: Fp(BigInt(res), core::marker::PhantomData), } } @@ -1195,18 +1209,131 @@ mod zp { iter.fold(Zp::ZERO, Add::add) } } + + /// This is like [`Zp`] but will automatically be zeroized on drop, at the cost of not being + /// Copy + #[derive(Clone, PartialEq, Eq, Hash, ZeroizeOnDrop)] + pub struct ZeroizeZp { + inner: crate::curve_446::Fr, + } + + #[cfg(test)] + impl From for crate::curve_446::Fr { + fn from(value: ZeroizeZp) -> Self { + value.inner + } + } + + impl Mul<&ZeroizeZp> for &ZeroizeZp { + type Output = ZeroizeZp; + + #[inline] + fn mul(self, rhs: &ZeroizeZp) -> Self::Output { + let mut result = self.clone(); + MontBackend::::mul_assign(&mut result.inner, &rhs.inner); + result + } + } + + impl Mul<&ZeroizeZp> for Zp { + type Output = Zp; + + #[inline] + fn mul(mut self, rhs: &ZeroizeZp) -> Self::Output { + MontBackend::::mul_assign(&mut self.inner, &rhs.inner); + self + } + } + + impl Add<&ZeroizeZp> for &ZeroizeZp { + type Output = ZeroizeZp; + + #[inline] + fn add(self, rhs: &ZeroizeZp) -> Self::Output { + let mut result = self.clone(); + MontBackend::::add_assign(&mut result.inner, &rhs.inner); + result + } + } + + impl Add<&ZeroizeZp> for Zp { + type Output = Zp; + + #[inline] + fn add(mut self, rhs: &ZeroizeZp) -> Self::Output { + MontBackend::::add_assign(&mut self.inner, &rhs.inner); + self + } + } + + impl ZeroizeZp { + pub const ZERO: Self = Self { + inner: MontFp!("0"), + }; + + pub const ONE: Self = Self { + inner: MontFp!("1"), + }; + + fn reduce_from_raw_u64x7(n: &mut [u64; 7], out: &mut [u64; 5]) { + const MODULUS: BigInt<5> = BigInt!( + "645383785691237230677916041525710377746967055506026847120930304831624105190538527824412673" + ); + + const MODULUS_MONTGOMERY: u64 = 272467794636046335; + + // zero the 22 leading bits, so the result is <= MODULUS * 2^128 + n[6] &= (1 << 42) - 1; + + redc(MODULUS.0, MODULUS_MONTGOMERY, n, out); + } + + /// Replace the content of the provided element with a random but valid one + pub fn rand_in_place(&mut self, rng: &mut dyn rand::RngCore) { + use rand::Rng; + + let mut values = [0; 7]; + rng.fill(&mut values); + Self::reduce_from_raw_u64x7(&mut values, &mut self.inner.0 .0); + values.zeroize(); + } + + pub fn mul_point + Group>(&self, x: T) -> T { + let zero = T::zero(); + let mut n = self.clone().inner.into_bigint(); + + if n.0 == [0; 5] { + return zero; + } + + let mut y = zero; + let mut x = x; + + for word in &n.0 { + for idx in 0..64 { + let bit = (word >> idx) & 1; + if bit == 1 { + y += x; + } + x.double_in_place(); + } + } + n.zeroize(); + y + } + } } pub use g1::{G1Affine, G1}; pub use g2::{G2Affine, G2}; pub use gt::Gt; -pub use zp::Zp; +pub use zp::{ZeroizeZp, Zp}; #[cfg(test)] mod tests { use super::*; use rand::rngs::StdRng; - use rand::SeedableRng; + use rand::{thread_rng, Rng, SeedableRng}; use std::collections::HashMap; #[test] @@ -1358,4 +1485,54 @@ mod tests { hm.insert(a_affine, 2); assert_eq!(hm.len(), 1); } + + /// Test that ZeroizeZp is equivalent to Zp + #[test] + fn test_zeroize_equivalency() { + let seed = thread_rng().gen(); + println!("zeroize_equivalency seed: {seed:x}"); + let rng = &mut StdRng::seed_from_u64(seed); + + let mut zeroize1 = ZeroizeZp::ZERO; + ZeroizeZp::rand_in_place(&mut zeroize1, rng); + let mut zeroize2 = ZeroizeZp::ZERO; + ZeroizeZp::rand_in_place(&mut zeroize2, rng); + + let rng = &mut StdRng::seed_from_u64(seed); + let zp1 = Zp::rand(rng); + let zp2 = Zp::rand(rng); + + assert_eq!(zp1.inner, zeroize1.clone().into()); + assert_eq!(zp2.inner, zeroize2.clone().into()); + + let sum_zeroize = &zeroize1 + &zeroize2; + let sum_zp = zp1 + zp2; + + assert_eq!(sum_zp.inner, sum_zeroize.clone().into()); + + let sum_zeroize_zp = zp1 + &zeroize2; + + assert_eq!(sum_zp.inner, sum_zeroize_zp.inner); + + let prod_zeroize = &zeroize1 * &zeroize2; + let prod_zp = zp1 * zp2; + + assert_eq!(prod_zp.inner, prod_zeroize.clone().into()); + + let prod_zeroize_zp = zp1 * &zeroize2; + + assert_eq!(prod_zp.inner, prod_zeroize_zp.inner); + + let g1 = G1::GENERATOR; + let g1_zeroize = g1.mul_scalar_zeroize(&zeroize1); + let g1_zp = g1.mul_scalar(zp1); + + assert_eq!(g1_zp, g1_zeroize); + + let g2 = G2::GENERATOR; + let g2_zeroize = g2.mul_scalar_zeroize(&zeroize1); + let g2_zp = g2.mul_scalar(zp1); + + assert_eq!(g2_zp, g2_zeroize); + } }