From 9cc97f9ab5cd3d0ee5a1feb92e585cc043b037ec Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Wed, 12 Jun 2024 19:15:16 +0200 Subject: [PATCH] feat(zk): impl CanonicalSerialize/Deserialize This is to allow specifying whether data should be compressed as compression and validation adds a very signigicant overhead especially in wasm where deserialization goes from 6 min to 450ms --- tfhe-zk-pok/src/curve_api.rs | 4 +- tfhe-zk-pok/src/curve_api/bls12_381.rs | 24 +++- tfhe-zk-pok/src/curve_api/bls12_446.rs | 24 +++- tfhe-zk-pok/src/lib.rs | 2 + tfhe-zk-pok/src/proofs/mod.rs | 41 ++++++- tfhe-zk-pok/src/proofs/pke.rs | 110 +++++++++++------- tfhe/js_on_wasm_tests/test-hlapi-signed.js | 9 +- tfhe/src/c_api/high_level_api/zk.rs | 72 +++++++++++- .../js_on_wasm_api/js_high_level_api/zk.rs | 43 +++++-- tfhe/src/zk.rs | 4 +- 10 files changed, 274 insertions(+), 59 deletions(-) diff --git a/tfhe-zk-pok/src/curve_api.rs b/tfhe-zk-pok/src/curve_api.rs index 88a9ac47e..48dc43ccc 100644 --- a/tfhe-zk-pok/src/curve_api.rs +++ b/tfhe-zk-pok/src/curve_api.rs @@ -137,8 +137,8 @@ pub trait PairingGroupOps: pub trait Curve { type Zp: FieldOps; - type G1: CurveGroupOps + serde::Serialize + for<'de> serde::Deserialize<'de>; - type G2: CurveGroupOps + serde::Serialize + for<'de> serde::Deserialize<'de>; + type G1: CurveGroupOps + CanonicalSerialize + CanonicalDeserialize; + type G2: CurveGroupOps + CanonicalSerialize + CanonicalDeserialize; type Gt: PairingGroupOps; } diff --git a/tfhe-zk-pok/src/curve_api/bls12_381.rs b/tfhe-zk-pok/src/curve_api/bls12_381.rs index a4bb50dc0..8c5bb2c01 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_381.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_381.rs @@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 6]) -> [u8; 6 * 8] { mod g1 { use super::*; - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[derive( + Copy, + Clone, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] #[repr(transparent)] pub struct G1 { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] @@ -169,7 +179,17 @@ mod g1 { mod g2 { use super::*; - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[derive( + Copy, + Clone, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] #[repr(transparent)] pub struct G2 { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] diff --git a/tfhe-zk-pok/src/curve_api/bls12_446.rs b/tfhe-zk-pok/src/curve_api/bls12_446.rs index 4769acb67..d628ed708 100644 --- a/tfhe-zk-pok/src/curve_api/bls12_446.rs +++ b/tfhe-zk-pok/src/curve_api/bls12_446.rs @@ -36,7 +36,17 @@ fn bigint_to_bytes(x: [u64; 7]) -> [u8; 7 * 8] { mod g1 { use super::*; - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[derive( + Copy, + Clone, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] #[repr(transparent)] pub struct G1 { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] @@ -168,7 +178,17 @@ mod g1 { mod g2 { use super::*; - #[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] + #[derive( + Copy, + Clone, + PartialEq, + Eq, + Serialize, + Deserialize, + Hash, + CanonicalSerialize, + CanonicalDeserialize, + )] #[repr(transparent)] pub struct G2 { #[serde(serialize_with = "ark_se", deserialize_with = "ark_de")] diff --git a/tfhe-zk-pok/src/lib.rs b/tfhe-zk-pok/src/lib.rs index b75cb224d..934b15957 100644 --- a/tfhe-zk-pok/src/lib.rs +++ b/tfhe-zk-pok/src/lib.rs @@ -1,3 +1,5 @@ +pub use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; + pub mod curve_446; pub mod curve_api; pub mod proofs; diff --git a/tfhe-zk-pok/src/proofs/mod.rs b/tfhe-zk-pok/src/proofs/mod.rs index bc6f0cf1a..3c902e44b 100644 --- a/tfhe-zk-pok/src/proofs/mod.rs +++ b/tfhe-zk-pok/src/proofs/mod.rs @@ -1,5 +1,8 @@ use crate::curve_api::{Curve, CurveGroupOps, FieldOps, PairingGroupOps}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate, +}; use core::ops::{Index, IndexMut}; use rand::RngCore; @@ -7,6 +10,36 @@ use rand::RngCore; #[repr(transparent)] struct OneBased(T); +impl Valid for OneBased { + fn check(&self) -> Result<(), SerializationError> { + self.0.check() + } +} + +impl CanonicalDeserialize for OneBased { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + T::deserialize_with_mode(reader, compress, validate).map(Self) + } +} + +impl CanonicalSerialize for OneBased { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.serialized_size(compress) + } +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ComputeLoad { Proof, @@ -42,7 +75,13 @@ impl> IndexMut for OneBased { } } -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive( + Clone, Debug, serde::Serialize, serde::Deserialize, CanonicalSerialize, CanonicalDeserialize, +)] +#[serde(bound( + deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", + serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" +))] struct GroupElements { g_list: OneBased>, g_hat_list: OneBased>, diff --git a/tfhe-zk-pok/src/proofs/pke.rs b/tfhe-zk-pok/src/proofs/pke.rs index d32aa0fb3..1682cf3bd 100644 --- a/tfhe-zk-pok/src/proofs/pke.rs +++ b/tfhe-zk-pok/src/proofs/pke.rs @@ -6,7 +6,7 @@ fn bit_iter(x: u64, nbits: u32) -> impl Iterator { (0..nbits).map(move |idx| ((x >> idx) & 1) != 0) } -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)] pub struct PublicParams { g_lists: GroupElements, big_d: usize, @@ -52,6 +52,10 @@ impl PublicParams { } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(bound( + deserialize = "G: Curve, G::G1: serde::Deserialize<'de>, G::G2: serde::Deserialize<'de>", + serialize = "G: Curve, G::G1: serde::Serialize, G::G2: serde::Serialize" +))] pub struct Proof { c_hat: G::G2, c_y: G::G1, @@ -992,49 +996,75 @@ mod tests { m_roundtrip[i] = result; } - let public_param = crs_gen::(d, k, b_i, q, t, rng); + type Curve = crate::curve_api::Bls12_446; - for use_fake_e1 in [false, true] { - for use_fake_e2 in [false, true] { - for use_fake_m in [false, true] { - for use_fake_r in [false, true] { - let (public_commit, private_commit) = commit( - a.clone(), - b.clone(), - c1.clone(), - c2.clone(), - if use_fake_r { - fake_r.clone() - } else { - r.clone() - }, - if use_fake_e1 { - fake_e1.clone() - } else { - e1.clone() - }, - if use_fake_m { - fake_m.clone() - } else { - m.clone() - }, - if use_fake_e2 { - fake_e2.clone() - } else { - e2.clone() - }, - &public_param, - rng, - ); + let serialize_then_deserialize = + |public_param: &PublicParams, + compress: Compress| + -> Result, SerializationError> { + let mut data = Vec::new(); + public_param.serialize_with_mode(&mut data, compress)?; - for load in [ComputeLoad::Proof, ComputeLoad::Verify] { - let proof = - prove((&public_param, &public_commit), &private_commit, load, rng); + PublicParams::deserialize_with_mode(data.as_slice(), compress, Validate::No) + }; - assert_eq!( - verify(&proof, (&public_param, &public_commit)).is_err(), - use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m + let original_public_param = crs_gen::(d, k, b_i, q, t, rng); + let public_param_that_was_compressed = + serialize_then_deserialize(&original_public_param, Compress::No).unwrap(); + let public_param_that_was_not_compressed = + serialize_then_deserialize(&original_public_param, Compress::Yes).unwrap(); + + for public_param in [ + original_public_param, + public_param_that_was_compressed, + public_param_that_was_not_compressed, + ] { + for use_fake_e1 in [false, true] { + for use_fake_e2 in [false, true] { + for use_fake_m in [false, true] { + for use_fake_r in [false, true] { + let (public_commit, private_commit) = commit( + a.clone(), + b.clone(), + c1.clone(), + c2.clone(), + if use_fake_r { + fake_r.clone() + } else { + r.clone() + }, + if use_fake_e1 { + fake_e1.clone() + } else { + e1.clone() + }, + if use_fake_m { + fake_m.clone() + } else { + m.clone() + }, + if use_fake_e2 { + fake_e2.clone() + } else { + e2.clone() + }, + &public_param, + rng, ); + + for load in [ComputeLoad::Proof, ComputeLoad::Verify] { + let proof = prove( + (&public_param, &public_commit), + &private_commit, + load, + rng, + ); + + assert_eq!( + verify(&proof, (&public_param, &public_commit)).is_err(), + use_fake_e1 || use_fake_e2 || use_fake_r || use_fake_m + ); + } } } } diff --git a/tfhe/js_on_wasm_tests/test-hlapi-signed.js b/tfhe/js_on_wasm_tests/test-hlapi-signed.js index 214af7745..6e22e9f75 100644 --- a/tfhe/js_on_wasm_tests/test-hlapi-signed.js +++ b/tfhe/js_on_wasm_tests/test-hlapi-signed.js @@ -19,8 +19,10 @@ const { FheInt256, CompactCiphertextList, ProvenCompactCiphertextList, + CompactPkePublicParams, + CompactPkeCrs, + ZkComputeLoad, } = require("../pkg/tfhe.js"); -const {CompactPkeCrs, ZkComputeLoad} = require("../pkg"); const { randomBytes, } = require('node:crypto'); @@ -467,6 +469,11 @@ test('hlapi_compact_ciphertext_list_with_proof', (t) => { let crs = CompactPkeCrs.from_parameters(block_params, 2 + 32 + 1 + 256); let public_params = crs.public_params(); + const compress = false; // We don't compress as it's too slow on wasm + let serialized_pke_params = public_params.serialize(compress); + let validate = false; // Also too slow on wasm + public_params = CompactPkePublicParams.deserialize(serialized_pke_params, compress, validate); + let clear_u2 = 3; let clear_i32 = -3284; let clear_bool = true; diff --git a/tfhe/src/c_api/high_level_api/zk.rs b/tfhe/src/c_api/high_level_api/zk.rs index 37d02d06b..dfa9a2044 100644 --- a/tfhe/src/c_api/high_level_api/zk.rs +++ b/tfhe/src/c_api/high_level_api/zk.rs @@ -1,6 +1,7 @@ use super::utils::*; use crate::c_api::high_level_api::config::Config; use crate::c_api::utils::get_ref_checked; +use crate::zk::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; use std::ffi::c_int; #[repr(C)] @@ -21,12 +22,79 @@ impl From for crate::zk::ZkComputeLoad { pub struct CompactPkePublicParams(pub(crate) crate::core_crypto::entities::CompactPkePublicParams); impl_destroy_on_type!(CompactPkePublicParams); -impl_serialize_deserialize_on_type!(CompactPkePublicParams); + +/// Serializes the public params +/// +/// If compress is true, the data will be compressed (less serialized bytes), however, this makes +/// the serialization process slower. +/// +/// Also, the value to `compress` should match the value given to `is_compressed` +/// when deserializing. +#[no_mangle] +pub unsafe extern "C" fn compact_pke_public_params_serialize( + sself: *const CompactPkePublicParams, + compress: bool, + result: *mut crate::c_api::buffer::DynamicBuffer, +) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); + + let wrapper = crate::c_api::utils::get_ref_checked(sself).unwrap(); + + let compress = if compress { + Compress::Yes + } else { + Compress::No + }; + let mut buffer = vec![]; + wrapper + .0 + .serialize_with_mode(&mut buffer, compress) + .unwrap(); + + *result = buffer.into(); + }) +} + +/// Deserializes the public params +/// +/// If the data comes from compressed public params, then `is_compressed` must be true. +#[no_mangle] +pub unsafe extern "C" fn compact_pke_public_params_deserialize( + buffer_view: crate::c_api::buffer::DynamicBufferView, + is_compressed: bool, + validate: bool, + result: *mut *mut CompactPkePublicParams, +) -> ::std::os::raw::c_int { + crate::c_api::utils::catch_panic(|| { + crate::c_api::utils::check_ptr_is_non_null_and_aligned(result).unwrap(); + + *result = std::ptr::null_mut(); + + let deserialized = crate::zk::CompactPkePublicParams::deserialize_with_mode( + buffer_view.as_slice(), + if is_compressed { + Compress::Yes + } else { + Compress::No + }, + if validate { + Validate::Yes + } else { + Validate::No + }, + ) + .unwrap(); + + let heap_allocated_object = Box::new(CompactPkePublicParams(deserialized)); + + *result = Box::into_raw(heap_allocated_object); + }) +} pub struct CompactPkeCrs(pub(crate) crate::core_crypto::entities::CompactPkeCrs); impl_destroy_on_type!(CompactPkeCrs); -impl_serialize_deserialize_on_type!(CompactPkeCrs); #[no_mangle] pub unsafe extern "C" fn compact_pke_crs_from_config( diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs index 0fcf327f5..be95bdeb4 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/zk.rs @@ -3,7 +3,7 @@ use wasm_bindgen::prelude::*; use crate::js_on_wasm_api::js_high_level_api::config::TfheConfig; use crate::js_on_wasm_api::js_high_level_api::{catch_panic_result, into_js_error}; use crate::js_on_wasm_api::shortint::ShortintParameters; - +use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; #[derive(Copy, Clone, Eq, PartialEq)] #[wasm_bindgen] pub enum ZkComputeLoad { @@ -29,16 +29,45 @@ pub struct CompactPkePublicParams(pub(crate) crate::zk::CompactPkePublicParams); #[wasm_bindgen] impl CompactPkePublicParams { #[wasm_bindgen] - pub fn serialize(&self) -> Result, JsError> { - catch_panic_result(|| bincode::serialize(&self.0).map_err(into_js_error)) + pub fn serialize(&self, compress: bool) -> Result, JsError> { + catch_panic_result(|| { + let mut data = vec![]; + self.0 + .serialize_with_mode( + &mut data, + if compress { + Compress::Yes + } else { + Compress::No + }, + ) + .map_err(into_js_error)?; + Ok(data) + }) } #[wasm_bindgen] - pub fn deserialize(buffer: &[u8]) -> Result { + pub fn deserialize( + buffer: &[u8], + is_compressed: bool, + validate: bool, + ) -> Result { catch_panic_result(|| { - bincode::deserialize(buffer) - .map(CompactPkePublicParams) - .map_err(into_js_error) + crate::zk::CompactPkePublicParams::deserialize_with_mode( + buffer, + if is_compressed { + Compress::Yes + } else { + Compress::No + }, + if validate { + Validate::Yes + } else { + Validate::No + }, + ) + .map(CompactPkePublicParams) + .map_err(into_js_error) }) } } diff --git a/tfhe/src/zk.rs b/tfhe/src/zk.rs index 3c7b376e2..6e56bdeba 100644 --- a/tfhe/src/zk.rs +++ b/tfhe/src/zk.rs @@ -1,4 +1,4 @@ -use crate::core_crypto::commons::math::random::{BoundedDistribution, Deserialize, Serialize}; +use crate::core_crypto::commons::math::random::BoundedDistribution; use crate::core_crypto::prelude::*; use rand_core::RngCore; use std::cmp::Ordering; @@ -7,6 +7,7 @@ use std::fmt::Debug; use tfhe_zk_pok::proofs::pke::crs_gen; pub use tfhe_zk_pok::proofs::ComputeLoad as ZkComputeLoad; +pub use tfhe_zk_pok::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate}; type Curve = tfhe_zk_pok::curve_api::Bls12_446; pub type CompactPkeProof = tfhe_zk_pok::proofs::pke::Proof; pub type CompactPkePublicParams = tfhe_zk_pok::proofs::pke::PublicParams; @@ -29,7 +30,6 @@ impl ZkVerificationOutCome { } } -#[derive(Serialize, Deserialize)] pub struct CompactPkeCrs { public_params: CompactPkePublicParams, }