mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(zk): impl CanonicalSerialize/Deserialize
This is to allow specifying whether data should be compressed as compression and validation adds a very signigicant overhead especially in wasm where deserialization goes from 6 min to 450ms
This commit is contained in:
@@ -137,8 +137,8 @@ pub trait PairingGroupOps<Zp, G1, G2>:
|
||||
|
||||
pub trait Curve {
|
||||
type Zp: FieldOps;
|
||||
type G1: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
|
||||
type G2: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
|
||||
type G1: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
|
||||
type G2: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
|
||||
type Gt: PairingGroupOps<Self::Zp, Self::G1, Self::G2>;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
|
||||
mod g1 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G1 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
@@ -169,7 +179,17 @@ mod g1 {
|
||||
mod g2 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G2 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
|
||||
@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
|
||||
mod g1 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G1 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
@@ -168,7 +178,17 @@ mod g1 {
|
||||
mod g2 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G2 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
|
||||
|
||||
pub mod curve_446;
|
||||
pub mod curve_api;
|
||||
pub mod proofs;
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps};
|
||||
|
||||
use ark_serialize::{
|
||||
CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
|
||||
};
|
||||
use core::ops::{Index, IndexMut};
|
||||
use rand::RngCore;
|
||||
|
||||
@@ -7,6 +10,36 @@ use rand::RngCore;
|
||||
#[repr(transparent)]
|
||||
struct OneBased<T: ?Sized>(T);
|
||||
|
||||
impl<T: Valid> Valid for OneBased<T> {
|
||||
fn check(&self) -> Result<(), SerializationError> {
|
||||
self.0.check()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CanonicalDeserialize> CanonicalDeserialize for OneBased<T> {
|
||||
fn deserialize_with_mode<R: ark_serialize::Read>(
|
||||
reader: R,
|
||||
compress: Compress,
|
||||
validate: Validate,
|
||||
) -> Result<Self, SerializationError> {
|
||||
T::deserialize_with_mode(reader, compress, validate).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CanonicalSerialize> CanonicalSerialize for OneBased<T> {
|
||||
fn serialize_with_mode<W: ark_serialize::Write>(
|
||||
&self,
|
||||
writer: W,
|
||||
compress: Compress,
|
||||
) -> Result<(), SerializationError> {
|
||||
self.0.serialize_with_mode(writer, compress)
|
||||
}
|
||||
|
||||
fn serialized_size(&self, compress: Compress) -> usize {
|
||||
self.0.serialized_size(compress)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ComputeLoad {
|
||||
Proof,
|
||||
@@ -42,7 +75,13 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(
|
||||
Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize,
|
||||
)]
|
||||
#[serde(bound(
|
||||
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
|
||||
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
|
||||
))]
|
||||
struct GroupElements<G: Curve> {
|
||||
g_list: OneBased<Vec<G::G1>>,
|
||||
g_hat_list: OneBased<Vec<G::G2>>,
|
||||
|
||||
@@ -6,7 +6,7 @@ fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
|
||||
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)]
|
||||
pub struct PublicParams<G: Curve> {
|
||||
g_lists: GroupElements<G>,
|
||||
big_d: usize,
|
||||
@@ -52,6 +52,10 @@ impl<G: Curve> PublicParams<G> {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(bound(
|
||||
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
|
||||
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
|
||||
))]
|
||||
pub struct Proof<G: Curve> {
|
||||
c_hat: G::G2,
|
||||
c_y: G::G1,
|
||||
@@ -992,49 +996,75 @@ mod tests {
|
||||
m_roundtrip[i] = result;
|
||||
}
|
||||
|
||||
let public_param = crs_gen::<crate::curve_api::Bls12_446>(d, k, b_i, q, t, rng);
|
||||
type Curve = crate::curve_api::Bls12_446;
|
||||
|
||||
for use_fake_e1 in [false, true] {
|
||||
for use_fake_e2 in [false, true] {
|
||||
for use_fake_m in [false, true] {
|
||||
for use_fake_r in [false, true] {
|
||||
let (public_commit, private_commit) = commit(
|
||||
a.clone(),
|
||||
b.clone(),
|
||||
c1.clone(),
|
||||
c2.clone(),
|
||||
if use_fake_r {
|
||||
fake_r.clone()
|
||||
} else {
|
||||
r.clone()
|
||||
},
|
||||
if use_fake_e1 {
|
||||
fake_e1.clone()
|
||||
} else {
|
||||
e1.clone()
|
||||
},
|
||||
if use_fake_m {
|
||||
fake_m.clone()
|
||||
} else {
|
||||
m.clone()
|
||||
},
|
||||
if use_fake_e2 {
|
||||
fake_e2.clone()
|
||||
} else {
|
||||
e2.clone()
|
||||
},
|
||||
&public_param,
|
||||
rng,
|
||||
);
|
||||
let serialize_then_deserialize =
|
||||
|public_param: &PublicParams<Curve>,
|
||||
compress: Compress|
|
||||
-> Result<PublicParams<Curve>, SerializationError> {
|
||||
let mut data = Vec::new();
|
||||
public_param.serialize_with_mode(&mut data, compress)?;
|
||||
|
||||
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
|
||||
let proof =
|
||||
prove((&public_param, &public_commit), &private_commit, load, rng);
|
||||
PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No)
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
verify(&proof, (&public_param, &public_commit)).is_err(),
|
||||
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
|
||||
let original_public_param = crs_gen::<Curve>(d, k, b_i, q, t, rng);
|
||||
let public_param_that_was_compressed =
|
||||
serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
|
||||
let public_param_that_was_not_compressed =
|
||||
serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
|
||||
|
||||
for public_param in [
|
||||
original_public_param,
|
||||
public_param_that_was_compressed,
|
||||
public_param_that_was_not_compressed,
|
||||
] {
|
||||
for use_fake_e1 in [false, true] {
|
||||
for use_fake_e2 in [false, true] {
|
||||
for use_fake_m in [false, true] {
|
||||
for use_fake_r in [false, true] {
|
||||
let (public_commit, private_commit) = commit(
|
||||
a.clone(),
|
||||
b.clone(),
|
||||
c1.clone(),
|
||||
c2.clone(),
|
||||
if use_fake_r {
|
||||
fake_r.clone()
|
||||
} else {
|
||||
r.clone()
|
||||
},
|
||||
if use_fake_e1 {
|
||||
fake_e1.clone()
|
||||
} else {
|
||||
e1.clone()
|
||||
},
|
||||
if use_fake_m {
|
||||
fake_m.clone()
|
||||
} else {
|
||||
m.clone()
|
||||
},
|
||||
if use_fake_e2 {
|
||||
fake_e2.clone()
|
||||
} else {
|
||||
e2.clone()
|
||||
},
|
||||
&public_param,
|
||||
rng,
|
||||
);
|
||||
|
||||
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
|
||||
let proof = prove(
|
||||
(&public_param, &public_commit),
|
||||
&private_commit,
|
||||
load,
|
||||
rng,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
verify(&proof, (&public_param, &public_commit)).is_err(),
|
||||
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user