mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(zk): impl CanonicalSerialize/Deserialize
This is to allow specifying whether data should be compressed as compression and validation adds a very signigicant overhead especially in wasm where deserialization goes from 6 min to 450ms
This commit is contained in:
@@ -137,8 +137,8 @@ pub trait PairingGroupOps<Zp, G1, G2>:
|
||||
|
||||
pub trait Curve {
|
||||
type Zp: FieldOps;
|
||||
type G1: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
|
||||
type G2: CurveGroupOps<Self::Zp> + serde::Serialize + for<'de> serde::Deserialize<'de>;
|
||||
type G1: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
|
||||
type G2: CurveGroupOps<Self::Zp> + CanonicalSerialize + CanonicalDeserialize;
|
||||
type Gt: PairingGroupOps<Self::Zp, Self::G1, Self::G2>;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] {
|
||||
mod g1 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G1 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
@@ -169,7 +179,17 @@ mod g1 {
|
||||
mod g2 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G2 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
|
||||
@@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] {
|
||||
mod g1 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G1 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
@@ -168,7 +178,17 @@ mod g1 {
|
||||
mod g2 {
|
||||
use super::*;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
Hash,
|
||||
CanonicalSerialize,
|
||||
CanonicalDeserialize,
|
||||
)]
|
||||
#[repr(transparent)]
|
||||
pub struct G2 {
|
||||
#[serde(serialize_with = "ark_se", deserialize_with = "ark_de")]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
|
||||
|
||||
pub mod curve_446;
|
||||
pub mod curve_api;
|
||||
pub mod proofs;
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps};
|
||||
|
||||
use ark_serialize::{
|
||||
CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate,
|
||||
};
|
||||
use core::ops::{Index, IndexMut};
|
||||
use rand::RngCore;
|
||||
|
||||
@@ -7,6 +10,36 @@ use rand::RngCore;
|
||||
#[repr(transparent)]
|
||||
struct OneBased<T: ?Sized>(T);
|
||||
|
||||
impl<T: Valid> Valid for OneBased<T> {
|
||||
fn check(&self) -> Result<(), SerializationError> {
|
||||
self.0.check()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CanonicalDeserialize> CanonicalDeserialize for OneBased<T> {
|
||||
fn deserialize_with_mode<R: ark_serialize::Read>(
|
||||
reader: R,
|
||||
compress: Compress,
|
||||
validate: Validate,
|
||||
) -> Result<Self, SerializationError> {
|
||||
T::deserialize_with_mode(reader, compress, validate).map(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CanonicalSerialize> CanonicalSerialize for OneBased<T> {
|
||||
fn serialize_with_mode<W: ark_serialize::Write>(
|
||||
&self,
|
||||
writer: W,
|
||||
compress: Compress,
|
||||
) -> Result<(), SerializationError> {
|
||||
self.0.serialize_with_mode(writer, compress)
|
||||
}
|
||||
|
||||
fn serialized_size(&self, compress: Compress) -> usize {
|
||||
self.0.serialized_size(compress)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum ComputeLoad {
|
||||
Proof,
|
||||
@@ -42,7 +75,13 @@ impl<T: ?Sized + IndexMut<usize>> IndexMut<usize> for OneBased<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(
|
||||
Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize,
|
||||
)]
|
||||
#[serde(bound(
|
||||
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
|
||||
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
|
||||
))]
|
||||
struct GroupElements<G: Curve> {
|
||||
g_list: OneBased<Vec<G::G1>>,
|
||||
g_hat_list: OneBased<Vec<G::G2>>,
|
||||
|
||||
@@ -6,7 +6,7 @@ fn bit_iter(x: u64, nbits: u32) -> impl Iterator<Item = bool> {
|
||||
(0..nbits).map(move |idx| ((x >> idx) & 1) != 0)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)]
|
||||
pub struct PublicParams<G: Curve> {
|
||||
g_lists: GroupElements<G>,
|
||||
big_d: usize,
|
||||
@@ -52,6 +52,10 @@ impl<G: Curve> PublicParams<G> {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(bound(
|
||||
deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>",
|
||||
serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize"
|
||||
))]
|
||||
pub struct Proof<G: Curve> {
|
||||
c_hat: G::G2,
|
||||
c_y: G::G1,
|
||||
@@ -992,49 +996,75 @@ mod tests {
|
||||
m_roundtrip[i] = result;
|
||||
}
|
||||
|
||||
let public_param = crs_gen::<crate::curve_api::Bls12_446>(d, k, b_i, q, t, rng);
|
||||
type Curve = crate::curve_api::Bls12_446;
|
||||
|
||||
for use_fake_e1 in [false, true] {
|
||||
for use_fake_e2 in [false, true] {
|
||||
for use_fake_m in [false, true] {
|
||||
for use_fake_r in [false, true] {
|
||||
let (public_commit, private_commit) = commit(
|
||||
a.clone(),
|
||||
b.clone(),
|
||||
c1.clone(),
|
||||
c2.clone(),
|
||||
if use_fake_r {
|
||||
fake_r.clone()
|
||||
} else {
|
||||
r.clone()
|
||||
},
|
||||
if use_fake_e1 {
|
||||
fake_e1.clone()
|
||||
} else {
|
||||
e1.clone()
|
||||
},
|
||||
if use_fake_m {
|
||||
fake_m.clone()
|
||||
} else {
|
||||
m.clone()
|
||||
},
|
||||
if use_fake_e2 {
|
||||
fake_e2.clone()
|
||||
} else {
|
||||
e2.clone()
|
||||
},
|
||||
&public_param,
|
||||
rng,
|
||||
);
|
||||
let serialize_then_deserialize =
|
||||
|public_param: &PublicParams<Curve>,
|
||||
compress: Compress|
|
||||
-> Result<PublicParams<Curve>, SerializationError> {
|
||||
let mut data = Vec::new();
|
||||
public_param.serialize_with_mode(&mut data, compress)?;
|
||||
|
||||
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
|
||||
let proof =
|
||||
prove((&public_param, &public_commit), &private_commit, load, rng);
|
||||
PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No)
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
verify(&proof, (&public_param, &public_commit)).is_err(),
|
||||
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
|
||||
let original_public_param = crs_gen::<Curve>(d, k, b_i, q, t, rng);
|
||||
let public_param_that_was_compressed =
|
||||
serialize_then_deserialize(&original_public_param, Compress::No).unwrap();
|
||||
let public_param_that_was_not_compressed =
|
||||
serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap();
|
||||
|
||||
for public_param in [
|
||||
original_public_param,
|
||||
public_param_that_was_compressed,
|
||||
public_param_that_was_not_compressed,
|
||||
] {
|
||||
for use_fake_e1 in [false, true] {
|
||||
for use_fake_e2 in [false, true] {
|
||||
for use_fake_m in [false, true] {
|
||||
for use_fake_r in [false, true] {
|
||||
let (public_commit, private_commit) = commit(
|
||||
a.clone(),
|
||||
b.clone(),
|
||||
c1.clone(),
|
||||
c2.clone(),
|
||||
if use_fake_r {
|
||||
fake_r.clone()
|
||||
} else {
|
||||
r.clone()
|
||||
},
|
||||
if use_fake_e1 {
|
||||
fake_e1.clone()
|
||||
} else {
|
||||
e1.clone()
|
||||
},
|
||||
if use_fake_m {
|
||||
fake_m.clone()
|
||||
} else {
|
||||
m.clone()
|
||||
},
|
||||
if use_fake_e2 {
|
||||
fake_e2.clone()
|
||||
} else {
|
||||
e2.clone()
|
||||
},
|
||||
&public_param,
|
||||
rng,
|
||||
);
|
||||
|
||||
for load in [ComputeLoad::Proof, ComputeLoad::Verify] {
|
||||
let proof = prove(
|
||||
(&public_param, &public_commit),
|
||||
&private_commit,
|
||||
load,
|
||||
rng,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
verify(&proof, (&public_param, &public_commit)).is_err(),
|
||||
use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,10 @@ const {
|
||||
FheInt256,
|
||||
CompactCiphertextList,
|
||||
ProvenCompactCiphertextList,
|
||||
CompactPkePublicParams,
|
||||
CompactPkeCrs,
|
||||
ZkComputeLoad,
|
||||
} = require("../pkg/tfhe.js");
|
||||
const {CompactPkeCrs, ZkComputeLoad} = require("../pkg");
|
||||
const {
|
||||
randomBytes,
|
||||
} = require('node:crypto');
|
||||
@@ -467,6 +469,11 @@ test('hlapi_compact_ciphertext_list_with_proof', (t) => {
|
||||
let crs = CompactPkeCrs.from_parameters(block_params, 2 + 32 + 1 + 256);
|
||||
let public_params = crs.public_params();
|
||||
|
||||
const compress = false; // We don't compress as it's too slow on wasm
|
||||
let serialized_pke_params = public_params.serialize(compress);
|
||||
let validate = false; // Also too slow on wasm
|
||||
public_params = CompactPkePublicParams.deserialize(serialized_pke_params, compress, validate);
|
||||
|
||||
let clear_u2 = 3;
|
||||
let clear_i32 = -3284;
|
||||
let clear_bool = true;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use super::utils::*;
|
||||
use crate::c_api::high_level_api::config::Config;
|
||||
use crate::c_api::utils::get_ref_checked;
|
||||
use crate::zk::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
|
||||
use std::ffi::c_int;
|
||||
|
||||
#[repr(C)]
|
||||
@@ -21,12 +22,79 @@ impl From<ZkComputeLoad> for crate::zk::ZkComputeLoad {
|
||||
|
||||
pub struct CompactPkePublicParams(pub(crate) crate::core_crypto::entities::CompactPkePublicParams);
|
||||
impl_destroy_on_type!(CompactPkePublicParams);
|
||||
impl_serialize_deserialize_on_type!(CompactPkePublicParams);
|
||||
|
||||
/// Serializes the public params
|
||||
///
|
||||
/// If compress is true, the data will be compressed (less serialized bytes), however, this makes
|
||||
/// the serialization process slower.
|
||||
///
|
||||
/// Also, the value to `compress` should match the value given to `is_compressed`
|
||||
/// when deserializing.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compact_pke_public_params_serialize(
|
||||
sself: *const CompactPkePublicParams,
|
||||
compress: bool,
|
||||
result: *mut crate::c_api::buffer::DynamicBuffer,
|
||||
) -> ::std::os::raw::c_int {
|
||||
crate::c_api::utils::catch_panic(|| {
|
||||
crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
|
||||
let wrapper = crate::c_api::utils::get_ref_checked(sself).unwrap();
|
||||
|
||||
let compress = if compress {
|
||||
Compress::Yes
|
||||
} else {
|
||||
Compress::No
|
||||
};
|
||||
let mut buffer = vec![];
|
||||
wrapper
|
||||
.0
|
||||
.serialize_with_mode(&mut buffer, compress)
|
||||
.unwrap();
|
||||
|
||||
*result = buffer.into();
|
||||
})
|
||||
}
|
||||
|
||||
/// Deserializes the public params
|
||||
///
|
||||
/// If the data comes from compressed public params, then `is_compressed` must be true.
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compact_pke_public_params_deserialize(
|
||||
buffer_view: crate::c_api::buffer::DynamicBufferView,
|
||||
is_compressed: bool,
|
||||
validate: bool,
|
||||
result: *mut *mut CompactPkePublicParams,
|
||||
) -> ::std::os::raw::c_int {
|
||||
crate::c_api::utils::catch_panic(|| {
|
||||
crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
|
||||
*result = std::ptr::null_mut();
|
||||
|
||||
let deserialized = crate::zk::CompactPkePublicParams::deserialize_with_mode(
|
||||
buffer_view.as_slice(),
|
||||
if is_compressed {
|
||||
Compress::Yes
|
||||
} else {
|
||||
Compress::No
|
||||
},
|
||||
if validate {
|
||||
Validate::Yes
|
||||
} else {
|
||||
Validate::No
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let heap_allocated_object = Box::new(CompactPkePublicParams(deserialized));
|
||||
|
||||
*result = Box::into_raw(heap_allocated_object);
|
||||
})
|
||||
}
|
||||
|
||||
pub struct CompactPkeCrs(pub(crate) crate::core_crypto::entities::CompactPkeCrs);
|
||||
|
||||
impl_destroy_on_type!(CompactPkeCrs);
|
||||
impl_serialize_deserialize_on_type!(CompactPkeCrs);
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compact_pke_crs_from_config(
|
||||
|
||||
@@ -3,7 +3,7 @@ use wasm_bindgen::prelude::*;
|
||||
use crate::js_on_wasm_api::js_high_level_api::config::TfheConfig;
|
||||
use crate::js_on_wasm_api::js_high_level_api::{catch_panic_result, into_js_error};
|
||||
use crate::js_on_wasm_api::shortint::ShortintParameters;
|
||||
|
||||
use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
#[wasm_bindgen]
|
||||
pub enum ZkComputeLoad {
|
||||
@@ -29,16 +29,45 @@ pub struct CompactPkePublicParams(pub(crate) crate::zk::CompactPkePublicParams);
|
||||
#[wasm_bindgen]
|
||||
impl CompactPkePublicParams {
|
||||
#[wasm_bindgen]
|
||||
pub fn serialize(&self) -> Result<Vec<u8>, JsError> {
|
||||
catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error))
|
||||
pub fn serialize(&self, compress: bool) -> Result<Vec<u8>, JsError> {
|
||||
catch_panic_result(|| {
|
||||
let mut data = vec![];
|
||||
self.0
|
||||
.serialize_with_mode(
|
||||
&mut data,
|
||||
if compress {
|
||||
Compress::Yes
|
||||
} else {
|
||||
Compress::No
|
||||
},
|
||||
)
|
||||
.map_err(into_js_error)?;
|
||||
Ok(data)
|
||||
})
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn deserialize(buffer: &[u8]) -> Result<CompactPkePublicParams, JsError> {
|
||||
pub fn deserialize(
|
||||
buffer: &[u8],
|
||||
is_compressed: bool,
|
||||
validate: bool,
|
||||
) -> Result<CompactPkePublicParams, JsError> {
|
||||
catch_panic_result(|| {
|
||||
bincode::deserialize(buffer)
|
||||
.map(CompactPkePublicParams)
|
||||
.map_err(into_js_error)
|
||||
crate::zk::CompactPkePublicParams::deserialize_with_mode(
|
||||
buffer,
|
||||
if is_compressed {
|
||||
Compress::Yes
|
||||
} else {
|
||||
Compress::No
|
||||
},
|
||||
if validate {
|
||||
Validate::Yes
|
||||
} else {
|
||||
Validate::No
|
||||
},
|
||||
)
|
||||
.map(CompactPkePublicParams)
|
||||
.map_err(into_js_error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::core_crypto::commons::math::random::{BoundedDistribution, Deserialize, Serialize};
|
||||
use crate::core_crypto::commons::math::random::BoundedDistribution;
|
||||
use crate::core_crypto::prelude::*;
|
||||
use rand_core::RngCore;
|
||||
use std::cmp::Ordering;
|
||||
@@ -7,6 +7,7 @@ use std::fmt::Debug;
|
||||
use tfhe_zk_pok::proofs::pke::crs_gen;
|
||||
|
||||
pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad;
|
||||
pub use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
|
||||
type Curve = tfhe_zk_pok::curve_api::Bls12_446;
|
||||
pub type CompactPkeProof = tfhe_zk_pok::proofs::pke::Proof<Curve>;
|
||||
pub type CompactPkePublicParams = tfhe_zk_pok::proofs::pke::PublicParams<Curve>;
|
||||
@@ -29,7 +30,6 @@ impl ZkVerificationOutCome {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CompactPkeCrs {
|
||||
public_params: CompactPkePublicParams,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user