feat(zk): impl CanonicalSerialize/Deserialize

This is to allow specifying whether data should be compressed
as compression and validation adds a very signigicant overhead
especially in wasm where deserialization goes from 6 min to 450ms
This commit is contained in:
tmontaigu
2024-06-12 19:15:16 +02:00
committed by David Testé
parent 2bd9f7aab4
commit 9cc97f9ab5
10 changed files with 274 additions and 59 deletions

View File

@@ -137,8 +137,8 @@ pub trait PairingGroupOps<Zp, G1, G2>:
pub trait Curve {
type Zp: FieldOps;
type G1: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
type G2: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
type G1: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
type G2: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
type Gt: PairingGroupOps<Self::Zp, Self::G1, Self::G2>;
}

View File

@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
mod g1 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G1 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
@@ -169,7 +179,17 @@ mod g1 {
mod g2 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G2 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]

View File

@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
mod g1 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G1 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
@@ -168,7 +178,17 @@ mod g1 {
mod g2 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G2 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]

View File

@@ -1,3 +1,5 @@
pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
pub mod curve_446;
pub mod curve_api;
pub mod proofs;

View File

@@ -1,5 +1,8 @@
use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps};
use ark_serialize::{
CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
};
use core::ops::{Index, IndexMut};
use rand::RngCore;
@@ -7,6 +10,36 @@ use rand::RngCore;
#[repr(transparent)]
struct OneBased<T: ?Sized>(T);
impl<T: Valid> Valid for OneBased<T> {
fn check(&self) -> Result<(), SerializationError> {
self.0.check()
}
}
impl<T: CanonicalDeserialize> CanonicalDeserialize for OneBased<T> {
fn deserialize_with_mode<R: ark_serialize::Read>(
reader: R,
compress: Compress,
validate: Validate,
) -> Result<Self, SerializationError> {
T::deserialize_with_mode(reader, compress, validate).map(Self)
}
}
impl<T: CanonicalSerialize> CanonicalSerialize for OneBased<T> {
fn serialize_with_mode<W: ark_serialize::Write>(
&self,
writer: W,
compress: Compress,
) -> Result<(), SerializationError> {
self.0.serialize_with_mode(writer, compress)
}
fn serialized_size(&self, compress: Compress) -> usize {
self.0.serialized_size(compress)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ComputeLoad {
Proof,
@@ -42,7 +75,13 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[derive(
Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize,
)]
#[serde(bound(
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
))]
struct GroupElements<G: Curve> {
g_list: OneBased<Vec<G::G1>>,
g_hat_list: OneBased<Vec<G::G2>>,

View File

@@ -6,7 +6,7 @@ fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct PublicParams<G: Curve> {
g_lists: GroupElements<G>,
big_d: usize,
@@ -52,6 +52,10 @@ impl<G: Curve> PublicParams<G> {
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(bound(
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
))]
pub struct Proof<G: Curve> {
c_hat: G::G2,
c_y: G::G1,
@@ -992,49 +996,75 @@ mod tests {
m_roundtrip[i] = result;
}
let public_param = crs_gen::<crate::curve_api::Bls12_446>(d, k, b_i, q, t, rng);
type Curve = crate::curve_api::Bls12_446;
for use_fake_e1 in [false, true] {
for use_fake_e2 in [false, true] {
for use_fake_m in [false, true] {
for use_fake_r in [false, true] {
let (public_commit, private_commit) = commit(
a.clone(),
b.clone(),
c1.clone(),
c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
e2.clone()
},
&public_param,
rng,
);
let serialize_then_deserialize =
|public_param: &PublicParams<Curve>,
compress: Compress|
-> Result<PublicParams<Curve>, SerializationError> {
let mut data = Vec::new();
public_param.serialize_with_mode(&mut data, compress)?;
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof =
prove((&public_param, &public_commit), &private_commit, load, rng);
PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No)
};
assert_eq!(
verify(&proof, (&public_param, &public_commit)).is_err(),
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
let original_public_param = crs_gen::<Curve>(d, k, b_i, q, t, rng);
let public_param_that_was_compressed =
serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
let public_param_that_was_not_compressed =
serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
for public_param in [
original_public_param,
public_param_that_was_compressed,
public_param_that_was_not_compressed,
] {
for use_fake_e1 in [false, true] {
for use_fake_e2 in [false, true] {
for use_fake_m in [false, true] {
for use_fake_r in [false, true] {
let (public_commit, private_commit) = commit(
a.clone(),
b.clone(),
c1.clone(),
c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
e2.clone()
},
&public_param,
rng,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&public_param, &public_commit),
&private_commit,
load,
rng,
);
assert_eq!(
verify(&proof, (&public_param, &public_commit)).is_err(),
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
);
}
}
}
}