feat(zk): add ZeroizeZp type that is automatically zeroized on drop

This commit is contained in:
Nicolas Sarlin
2025-08-18 17:18:17 +02:00
committed by Nicolas Sarlin
parent 1647ec8f21
commit b67964f4a0

View File

@@ -174,6 +174,12 @@ mod g1 {
}
}
pub fn mul_scalar_zeroize(self, scalar: &ZeroizeZp) -> Self {
Self {
inner: scalar.mul_point(self.inner),
}
}
#[track_caller]
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
use rayon::prelude::*;
@@ -534,6 +540,12 @@ mod g2 {
}
}
pub fn mul_scalar_zeroize(self, scalar: &ZeroizeZp) -> Self {
Self {
inner: scalar.mul_point(self.inner),
}
}
pub fn multi_mul_scalar(bases: &[Self], scalars: &[Zp]) -> Self {
use rayon::prelude::*;
let n_threads = rayon::current_num_threads();
@@ -945,12 +957,13 @@ mod gt {
mod zp {
use super::*;
use crate::curve_446::FrConfig;
use crate::serialization::InvalidArraySizeError;
use ark_ff::Fp;
use ark_ff::{Fp, FpConfig, MontBackend, PrimeField};
use tfhe_versionable::Versionize;
use zeroize::Zeroize;
use zeroize::{Zeroize, ZeroizeOnDrop};
fn redc(n: [u64; 5], nprime: u64, mut t: [u64; 7]) -> [u64; 5] {
fn redc(n: [u64; 5], nprime: u64, t: &mut [u64; 7], out: &mut [u64; 5]) {
for i in 0..2 {
let mut c = 0u64;
let m = u64::wrapping_mul(t[i], nprime);
@@ -968,20 +981,22 @@ mod zp {
}
}
let mut t = [t[2], t[3], t[4], t[5], t[6]];
out[0] = t[2];
out[1] = t[3];
out[2] = t[4];
out[3] = t[5];
out[4] = t[6];
if t.into_iter().rev().ge(n.into_iter().rev()) {
if out.iter().rev().ge(n.iter().rev()) {
let mut o = false;
for i in 0..5 {
let (ti, o0) = u64::overflowing_sub(t[i], n[i]);
let (ti, o0) = u64::overflowing_sub(out[i], n[i]);
let (ti, o1) = u64::overflowing_sub(ti, o as u64);
o = o0 | o1;
t[i] = ti;
out[i] = ti;
}
}
assert!(t.into_iter().rev().lt(n.into_iter().rev()));
t
assert!(out.iter().rev().lt(n.iter().rev()));
}
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Versionize, Hash, Zeroize)]
@@ -1066,11 +1081,10 @@ mod zp {
let mut n = n;
// zero the 22 leading bits, so the result is <= MODULUS * 2^128
n[6] &= (1 << 42) - 1;
let mut res = [0; 5];
redc(MODULUS.0, MODULUS_MONTGOMERY, &mut n, &mut res);
Zp {
inner: Fp(
BigInt(redc(MODULUS.0, MODULUS_MONTGOMERY, n)),
core::marker::PhantomData,
),
inner: Fp(BigInt(res), core::marker::PhantomData),
}
}
@@ -1195,18 +1209,131 @@ mod zp {
iter.fold(Zp::ZERO, Add::add)
}
}
/// This is like [`Zp`] but will automatically be zeroized on drop, at the cost of not being
/// Copy
#[derive(Clone, PartialEq, Eq, Hash, ZeroizeOnDrop)]
pub struct ZeroizeZp {
inner: crate::curve_446::Fr,
}
#[cfg(test)]
impl From<ZeroizeZp> for crate::curve_446::Fr {
fn from(value: ZeroizeZp) -> Self {
value.inner
}
}
impl Mul<&ZeroizeZp> for &ZeroizeZp {
type Output = ZeroizeZp;
#[inline]
fn mul(self, rhs: &ZeroizeZp) -> Self::Output {
let mut result = self.clone();
MontBackend::<FrConfig, 5>::mul_assign(&mut result.inner, &rhs.inner);
result
}
}
impl Mul<&ZeroizeZp> for Zp {
type Output = Zp;
#[inline]
fn mul(mut self, rhs: &ZeroizeZp) -> Self::Output {
MontBackend::<FrConfig, 5>::mul_assign(&mut self.inner, &rhs.inner);
self
}
}
impl Add<&ZeroizeZp> for &ZeroizeZp {
type Output = ZeroizeZp;
#[inline]
fn add(self, rhs: &ZeroizeZp) -> Self::Output {
let mut result = self.clone();
MontBackend::<FrConfig, 5>::add_assign(&mut result.inner, &rhs.inner);
result
}
}
impl Add<&ZeroizeZp> for Zp {
type Output = Zp;
#[inline]
fn add(mut self, rhs: &ZeroizeZp) -> Self::Output {
MontBackend::<FrConfig, 5>::add_assign(&mut self.inner, &rhs.inner);
self
}
}
impl ZeroizeZp {
pub const ZERO: Self = Self {
inner: MontFp!("0"),
};
pub const ONE: Self = Self {
inner: MontFp!("1"),
};
fn reduce_from_raw_u64x7(n: &mut [u64; 7], out: &mut [u64; 5]) {
const MODULUS: BigInt<5> = BigInt!(
"645383785691237230677916041525710377746967055506026847120930304831624105190538527824412673"
);
const MODULUS_MONTGOMERY: u64 = 272467794636046335;
// zero the 22 leading bits, so the result is <= MODULUS * 2^128
n[6] &= (1 << 42) - 1;
redc(MODULUS.0, MODULUS_MONTGOMERY, n, out);
}
/// Replace the content of the provided element with a random but valid one
pub fn rand_in_place(&mut self, rng: &mut dyn rand::RngCore) {
use rand::Rng;
let mut values = [0; 7];
rng.fill(&mut values);
Self::reduce_from_raw_u64x7(&mut values, &mut self.inner.0 .0);
values.zeroize();
}
pub fn mul_point<T: Copy + Zero + Add<Output = T> + Group>(&self, x: T) -> T {
let zero = T::zero();
let mut n = self.clone().inner.into_bigint();
if n.0 == [0; 5] {
return zero;
}
let mut y = zero;
let mut x = x;
for word in &n.0 {
for idx in 0..64 {
let bit = (word >> idx) & 1;
if bit == 1 {
y += x;
}
x.double_in_place();
}
}
n.zeroize();
y
}
}
}
pub use g1::{G1Affine, G1};
pub use g2::{G2Affine, G2};
pub use gt::Gt;
pub use zp::Zp;
pub use zp::{ZeroizeZp, Zp};
#[cfg(test)]
mod tests {
use super::*;
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand::{thread_rng, Rng, SeedableRng};
use std::collections::HashMap;
#[test]
@@ -1358,4 +1485,54 @@ mod tests {
hm.insert(a_affine, 2);
assert_eq!(hm.len(), 1);
}
/// Test that ZeroizeZp is equivalent to Zp
#[test]
fn test_zeroize_equivalency() {
let seed = thread_rng().gen();
println!("zeroize_equivalency seed: {seed:x}");
let rng = &mut StdRng::seed_from_u64(seed);
let mut zeroize1 = ZeroizeZp::ZERO;
ZeroizeZp::rand_in_place(&mut zeroize1, rng);
let mut zeroize2 = ZeroizeZp::ZERO;
ZeroizeZp::rand_in_place(&mut zeroize2, rng);
let rng = &mut StdRng::seed_from_u64(seed);
let zp1 = Zp::rand(rng);
let zp2 = Zp::rand(rng);
assert_eq!(zp1.inner, zeroize1.clone().into());
assert_eq!(zp2.inner, zeroize2.clone().into());
let sum_zeroize = &zeroize1 + &zeroize2;
let sum_zp = zp1 + zp2;
assert_eq!(sum_zp.inner, sum_zeroize.clone().into());
let sum_zeroize_zp = zp1 + &zeroize2;
assert_eq!(sum_zp.inner, sum_zeroize_zp.inner);
let prod_zeroize = &zeroize1 * &zeroize2;
let prod_zp = zp1 * zp2;
assert_eq!(prod_zp.inner, prod_zeroize.clone().into());
let prod_zeroize_zp = zp1 * &zeroize2;
assert_eq!(prod_zp.inner, prod_zeroize_zp.inner);
let g1 = G1::GENERATOR;
let g1_zeroize = g1.mul_scalar_zeroize(&zeroize1);
let g1_zp = g1.mul_scalar(zp1);
assert_eq!(g1_zp, g1_zeroize);
let g2 = G2::GENERATOR;
let g2_zeroize = g2.mul_scalar_zeroize(&zeroize1);
let g2_zp = g2.mul_scalar(zp1);
assert_eq!(g2_zp, g2_zeroize);
}
}