feat(zk): impl CanonicalSerialize/Deserialize

This is to allow specifying whether data should be compressed
as compression and validation adds a very signigicant overhead
especially in wasm where deserialization goes from 6 min to 450ms
This commit is contained in:
tmontaigu
2024-06-12 19:15:16 +02:00
committed by David Testé
parent 2bd9f7aab4
commit 9cc97f9ab5
10 changed files with 274 additions and 59 deletions

View File

@@ -137,8 +137,8 @@ pub trait PairingGroupOps<Zp, G1, G2>:
pub trait Curve {
type Zp: FieldOps;
type G1: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
type G2: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
type G1: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
type G2: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
type Gt: PairingGroupOps<Self::Zp, Self::G1, Self::G2>;
}

View File

@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
mod g1 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G1 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
@@ -169,7 +179,17 @@ mod g1 {
mod g2 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G2 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]

View File

@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
mod g1 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G1 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
@@ -168,7 +178,17 @@ mod g1 {
mod g2 {
use super::*;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Hash,
CanonicalSerialize,
CanonicalDeserialize,
)]
#[repr(transparent)]
pub struct G2 {
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]

View File

@@ -1,3 +1,5 @@
pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
pub mod curve_446;
pub mod curve_api;
pub mod proofs;

View File

@@ -1,5 +1,8 @@
use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps};
use ark_serialize::{
CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
};
use core::ops::{Index, IndexMut};
use rand::RngCore;
@@ -7,6 +10,36 @@ use rand::RngCore;
#[repr(transparent)]
struct OneBased<T: ?Sized>(T);
impl<T: Valid> Valid for OneBased<T> {
fn check(&self) -> Result<(), SerializationError> {
self.0.check()
}
}
impl<T: CanonicalDeserialize> CanonicalDeserialize for OneBased<T> {
fn deserialize_with_mode<R: ark_serialize::Read>(
reader: R,
compress: Compress,
validate: Validate,
) -> Result<Self, SerializationError> {
T::deserialize_with_mode(reader, compress, validate).map(Self)
}
}
impl<T: CanonicalSerialize> CanonicalSerialize for OneBased<T> {
fn serialize_with_mode<W: ark_serialize::Write>(
&self,
writer: W,
compress: Compress,
) -> Result<(), SerializationError> {
self.0.serialize_with_mode(writer, compress)
}
fn serialized_size(&self, compress: Compress) -> usize {
self.0.serialized_size(compress)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ComputeLoad {
Proof,
@@ -42,7 +75,13 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[derive(
Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize,
)]
#[serde(bound(
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
))]
struct GroupElements<G: Curve> {
g_list: OneBased<Vec<G::G1>>,
g_hat_list: OneBased<Vec<G::G2>>,

View File

@@ -6,7 +6,7 @@ fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct PublicParams<G: Curve> {
g_lists: GroupElements<G>,
big_d: usize,
@@ -52,6 +52,10 @@ impl<G: Curve> PublicParams<G> {
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(bound(
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
))]
pub struct Proof<G: Curve> {
c_hat: G::G2,
c_y: G::G1,
@@ -992,49 +996,75 @@ mod tests {
m_roundtrip[i] = result;
}
let public_param = crs_gen::<crate::curve_api::Bls12_446>(d, k, b_i, q, t, rng);
type Curve = crate::curve_api::Bls12_446;
for use_fake_e1 in [false, true] {
for use_fake_e2 in [false, true] {
for use_fake_m in [false, true] {
for use_fake_r in [false, true] {
let (public_commit, private_commit) = commit(
a.clone(),
b.clone(),
c1.clone(),
c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
e2.clone()
},
&public_param,
rng,
);
let serialize_then_deserialize =
|public_param: &PublicParams<Curve>,
compress: Compress|
-> Result<PublicParams<Curve>, SerializationError> {
let mut data = Vec::new();
public_param.serialize_with_mode(&mut data, compress)?;
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof =
prove((&public_param, &public_commit), &private_commit, load, rng);
PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No)
};
assert_eq!(
verify(&proof, (&public_param, &public_commit)).is_err(),
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
let original_public_param = crs_gen::<Curve>(d, k, b_i, q, t, rng);
let public_param_that_was_compressed =
serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
let public_param_that_was_not_compressed =
serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
for public_param in [
original_public_param,
public_param_that_was_compressed,
public_param_that_was_not_compressed,
] {
for use_fake_e1 in [false, true] {
for use_fake_e2 in [false, true] {
for use_fake_m in [false, true] {
for use_fake_r in [false, true] {
let (public_commit, private_commit) = commit(
a.clone(),
b.clone(),
c1.clone(),
c2.clone(),
if use_fake_r {
fake_r.clone()
} else {
r.clone()
},
if use_fake_e1 {
fake_e1.clone()
} else {
e1.clone()
},
if use_fake_m {
fake_m.clone()
} else {
m.clone()
},
if use_fake_e2 {
fake_e2.clone()
} else {
e2.clone()
},
&public_param,
rng,
);
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
let proof = prove(
(&public_param, &public_commit),
&private_commit,
load,
rng,
);
assert_eq!(
verify(&proof, (&public_param, &public_commit)).is_err(),
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
);
}
}
}
}

View File

@@ -19,8 +19,10 @@ const {
FheInt256,
CompactCiphertextList,
ProvenCompactCiphertextList,
CompactPkePublicParams,
CompactPkeCrs,
ZkComputeLoad,
} = require("../pkg/tfhe.js");
const {CompactPkeCrs, ZkComputeLoad} = require("../pkg");
const {
randomBytes,
} = require('node:crypto');
@@ -467,6 +469,11 @@ test('hlapi_compact_ciphertext_list_with_proof', (t) => {
let crs = CompactPkeCrs.from_parameters(block_params, 2 + 32 + 1 + 256);
let public_params = crs.public_params();
const compress = false; // We don't compress as it's too slow on wasm
let serialized_pke_params = public_params.serialize(compress);
let validate = false; // Also too slow on wasm
public_params = CompactPkePublicParams.deserialize(serialized_pke_params, compress, validate);
let clear_u2 = 3;
let clear_i32 = -3284;
let clear_bool = true;

View File

@@ -1,6 +1,7 @@
use super::utils::*;
use crate::c_api::high_level_api::config::Config;
use crate::c_api::utils::get_ref_checked;
use crate::zk::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use std::ffi::c_int;
#[repr(C)]
@@ -21,12 +22,79 @@ impl From<ZkComputeLoad> for crate::zk::ZkComputeLoad {
pub struct CompactPkePublicParams(pub(crate) crate::core_crypto::entities::CompactPkePublicParams);
impl_destroy_on_type!(CompactPkePublicParams);
impl_serialize_deserialize_on_type!(CompactPkePublicParams);
/// Serializes the public params
///
/// If compress is true, the data will be compressed (less serialized bytes), however, this makes
/// the serialization process slower.
///
/// Also, the value to `compress` should match the value given to `is_compressed`
/// when deserializing.
#[no_mangle]
pub unsafe extern "C" fn compact_pke_public_params_serialize(
sself: *const CompactPkePublicParams,
compress: bool,
result: *mut crate::c_api::buffer::DynamicBuffer,
) -> ::std::os::raw::c_int {
crate::c_api::utils::catch_panic(|| {
crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap();
let wrapper = crate::c_api::utils::get_ref_checked(sself).unwrap();
let compress = if compress {
Compress::Yes
} else {
Compress::No
};
let mut buffer = vec![];
wrapper
.0
.serialize_with_mode(&mut buffer, compress)
.unwrap();
*result = buffer.into();
})
}
/// Deserializes the public params
///
/// If the data comes from compressed public params, then `is_compressed` must be true.
#[no_mangle]
pub unsafe extern "C" fn compact_pke_public_params_deserialize(
buffer_view: crate::c_api::buffer::DynamicBufferView,
is_compressed: bool,
validate: bool,
result: *mut *mut CompactPkePublicParams,
) -> ::std::os::raw::c_int {
crate::c_api::utils::catch_panic(|| {
crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap();
*result = std::ptr::null_mut();
let deserialized = crate::zk::CompactPkePublicParams::deserialize_with_mode(
buffer_view.as_slice(),
if is_compressed {
Compress::Yes
} else {
Compress::No
},
if validate {
Validate::Yes
} else {
Validate::No
},
)
.unwrap();
let heap_allocated_object = Box::new(CompactPkePublicParams(deserialized));
*result = Box::into_raw(heap_allocated_object);
})
}
pub struct CompactPkeCrs(pub(crate) crate::core_crypto::entities::CompactPkeCrs);
impl_destroy_on_type!(CompactPkeCrs);
impl_serialize_deserialize_on_type!(CompactPkeCrs);
#[no_mangle]
pub unsafe extern "C" fn compact_pke_crs_from_config(

View File

@@ -3,7 +3,7 @@ use wasm_bindgen::prelude::*;
use crate::js_on_wasm_api::js_high_level_api::config::TfheConfig;
use crate::js_on_wasm_api::js_high_level_api::{catch_panic_result, into_js_error};
use crate::js_on_wasm_api::shortint::ShortintParameters;
use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
#[derive(Copy, Clone, Eq, PartialEq)]
#[wasm_bindgen]
pub enum ZkComputeLoad {
@@ -29,16 +29,45 @@ pub struct CompactPkePublicParams(pub(crate) crate::zk::CompactPkePublicParams);
#[wasm_bindgen]
impl CompactPkePublicParams {
#[wasm_bindgen]
pub fn serialize(&self) -> Result<Vec<u8>, JsError> {
catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error))
pub fn serialize(&self, compress: bool) -> Result<Vec<u8>, JsError> {
catch_panic_result(|| {
let mut data = vec![];
self.0
.serialize_with_mode(
&mut data,
if compress {
Compress::Yes
} else {
Compress::No
},
)
.map_err(into_js_error)?;
Ok(data)
})
}
#[wasm_bindgen]
pub fn deserialize(buffer: &[u8]) -> Result<CompactPkePublicParams, JsError> {
pub fn deserialize(
buffer: &[u8],
is_compressed: bool,
validate: bool,
) -> Result<CompactPkePublicParams, JsError> {
catch_panic_result(|| {
bincode::deserialize(buffer)
.map(CompactPkePublicParams)
.map_err(into_js_error)
crate::zk::CompactPkePublicParams::deserialize_with_mode(
buffer,
if is_compressed {
Compress::Yes
} else {
Compress::No
},
if validate {
Validate::Yes
} else {
Validate::No
},
)
.map(CompactPkePublicParams)
.map_err(into_js_error)
})
}
}

View File

@@ -1,4 +1,4 @@
use crate::core_crypto::commons::math::random::{BoundedDistribution, Deserialize, Serialize};
use crate::core_crypto::commons::math::random::BoundedDistribution;
use crate::core_crypto::prelude::*;
use rand_core::RngCore;
use std::cmp::Ordering;
@@ -7,6 +7,7 @@ use std::fmt::Debug;
use tfhe_zk_pok::proofs::pke::crs_gen;
pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad;
pub use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
type Curve = tfhe_zk_pok::curve_api::Bls12_446;
pub type CompactPkeProof = tfhe_zk_pok::proofs::pke::Proof<Curve>;
pub type CompactPkePublicParams = tfhe_zk_pok::proofs::pke::PublicParams<Curve>;
@@ -29,7 +30,6 @@ impl ZkVerificationOutCome {
}
}
#[derive(Serialize, Deserialize)]
pub struct CompactPkeCrs {
public_params: CompactPkePublicParams,
}