diff --git a/tfhe/docs/guides/zk-pok.md b/tfhe/docs/guides/zk-pok.md index 09d30d670..1256425ed 100644 --- a/tfhe/docs/guides/zk-pok.md +++ b/tfhe/docs/guides/zk-pok.md @@ -12,7 +12,7 @@ Using this feature is straightforward: during encryption, the client generates t ```rust use rand::prelude::*; -use tfhe::prelude::FheDecrypt; +use tfhe::prelude::*; use tfhe::set_server_key; use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; @@ -45,9 +45,8 @@ pub fn main() -> Result<(), Box> { // Verify the ciphertexts let expander = proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?; - - let a: tfhe::FheUint64 = expander.get(0).unwrap()?; - let b: tfhe::FheUint64 = expander.get(1).unwrap()?; + let a: tfhe::FheUint64 = expander.get(0)?.unwrap(); + let b: tfhe::FheUint64 = expander.get(1)?.unwrap(); a + b }; @@ -80,7 +79,7 @@ This works essentially in the same way as before. Additionally, you need to indi ```rust use rand::prelude::*; -use tfhe::prelude::FheDecrypt; +use tfhe::prelude::*; use tfhe::set_server_key; use tfhe::zk::{CompactPkeCrs, ZkComputeLoad}; @@ -119,9 +118,10 @@ pub fn main() -> Result<(), Box> { set_server_key(server_key); // Verify the ciphertexts - let expander = proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?; - let a: tfhe::FheUint64 = expander.get(0).unwrap()?; - let b: tfhe::FheUint64 = expander.get(1).unwrap()?; + let expander = + proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?; + let a: tfhe::FheUint64 = expander.get(0)?.unwrap(); + let b: tfhe::FheUint64 = expander.get(1)?.unwrap(); a + b }; diff --git a/tfhe/src/c_api/high_level_api/compact_list.rs b/tfhe/src/c_api/high_level_api/compact_list.rs index 4473e4e99..7e0f0ffeb 100644 --- a/tfhe/src/c_api/high_level_api/compact_list.rs +++ b/tfhe/src/c_api/high_level_api/compact_list.rs @@ -15,6 +15,7 @@ use crate::c_api::high_level_api::utils::{ #[cfg(feature = "zk-pok")] use crate::c_api::high_level_api::zk::{CompactPkePublicParams, ZkComputeLoad}; use crate::c_api::utils::{catch_panic, get_mut_checked, get_ref_checked}; +use crate::prelude::CiphertextList; use std::ffi::c_int; pub struct CompactCiphertextListBuilder(crate::high_level_api::CompactCiphertextListBuilder); diff --git a/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs index 1b04e08dd..03224feee 100644 --- a/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/c_api/high_level_api/compressed_ciphertext_list.rs @@ -8,6 +8,7 @@ use crate::c_api::high_level_api::utils::{ impl_destroy_on_type, impl_serialize_deserialize_on_type, }; use crate::c_api::utils::{catch_panic, get_mut_checked, get_ref_checked}; +use crate::prelude::CiphertextList; use std::ffi::c_int; pub struct CompressedCiphertextListBuilder(crate::high_level_api::CompressedCiphertextListBuilder); diff --git a/tfhe/src/high_level_api/backward_compatibility/booleans.rs b/tfhe/src/high_level_api/backward_compatibility/booleans.rs index 2605b8965..8ef15a662 100644 --- a/tfhe/src/high_level_api/backward_compatibility/booleans.rs +++ b/tfhe/src/high_level_api/backward_compatibility/booleans.rs @@ -7,6 +7,7 @@ use crate::high_level_api::booleans::{ InnerBoolean, InnerBooleanVersionOwned, InnerCompressedFheBool, }; use crate::integer::ciphertext::{CompactCiphertextList, DataKind}; +use crate::prelude::CiphertextList; use crate::{ CompactCiphertextList as HlCompactCiphertextList, CompressedFheBool, Error, FheBool, Tag, }; @@ -111,7 +112,7 @@ impl CompactFheBool { let block = list .inner .get::(0) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|b| b.ok_or_else(|| Error::new("Failed to expand compact list".to_string())))??; let mut ciphertext = FheBool::new(block, Tag::default()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); @@ -148,7 +149,9 @@ impl CompactFheBoolList { let block = list .inner .get::(idx) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|list| { + list.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; let mut ciphertext = FheBool::new(block, Tag::default()); ciphertext.ciphertext.move_to_device_of_server_key_if_set(); diff --git a/tfhe/src/high_level_api/backward_compatibility/integers.rs b/tfhe/src/high_level_api/backward_compatibility/integers.rs index de05f2b0a..fca57b4ab 100644 --- a/tfhe/src/high_level_api/backward_compatibility/integers.rs +++ b/tfhe/src/high_level_api/backward_compatibility/integers.rs @@ -16,6 +16,7 @@ use crate::integer::ciphertext::{ CompressedRadixCiphertext as IntegerCompressedRadixCiphertext, CompressedSignedRadixCiphertext as IntegerCompressedSignedRadixCiphertext, DataKind, }; +use crate::prelude::CiphertextList; use crate::shortint::ciphertext::CompressedModulusSwitchedCiphertext; use crate::shortint::{Ciphertext, ServerKey}; use crate::{CompactCiphertextList as HlCompactCiphertextList, Error, Tag}; @@ -277,7 +278,9 @@ where let ct = list .inner .get::(0) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|list| { + list.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheInt::new(ct, Tag::default())) } } @@ -316,7 +319,9 @@ where let ct = list .inner .get::(idx) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|list| { + list.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheInt::new(ct, Tag::default())) }) .collect::, _>>() @@ -353,7 +358,9 @@ where let ct = list .inner .get::(0) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|ct| { + ct.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheUint::new(ct, Tag::default())) } } @@ -391,7 +398,9 @@ where let ct = list .inner .get::(idx) - .ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??; + .map(|ct| { + ct.ok_or_else(|| Error::new("Failed to expand compact list".to_string())) + })??; Ok(FheUint::new(ct, Tag::default())) }) .collect::, _>>() diff --git a/tfhe/src/high_level_api/compact_list.rs b/tfhe/src/high_level_api/compact_list.rs index 0cc86b279..723cd7c29 100644 --- a/tfhe/src/high_level_api/compact_list.rs +++ b/tfhe/src/high_level_api/compact_list.rs @@ -16,6 +16,7 @@ use crate::integer::parameters::{ IntegerCompactCiphertextListUnpackingMode, }; use crate::named::Named; +use crate::prelude::CiphertextList; use crate::shortint::MessageModulus; #[cfg(feature = "zk-pok")] pub use zk::ProvenCompactCiphertextList; @@ -369,27 +370,27 @@ pub struct CompactCiphertextListExpander { tag: Tag, } -impl CompactCiphertextListExpander { - pub fn len(&self) -> usize { +impl CiphertextList for CompactCiphertextListExpander { + fn len(&self) -> usize { self.inner.len() } - pub fn is_empty(&self) -> bool { + fn is_empty(&self) -> bool { self.len() == 0 } - pub fn get_kind_of(&self, index: usize) -> Option { + fn get_kind_of(&self, index: usize) -> Option { self.inner.get_kind_of(index).and_then(|data_kind| { crate::FheTypes::from_data_kind(data_kind, self.inner.message_modulus()) }) } - pub fn get(&self, index: usize) -> Option> + fn get(&self, index: usize) -> crate::Result> where T: Expandable + Tagged, { let mut expanded = self.inner.get::(index); - if let Some(Ok(inner)) = &mut expanded { + if let Ok(Some(inner)) = &mut expanded { inner.tag_mut().set_data(self.tag.data()); } expanded @@ -543,15 +544,15 @@ mod tests { let e: u8 = e.decrypt(&ck); assert_eq!(e, 3); - assert!(expander.get::(5).is_none()); + assert!(expander.get::(5).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } } @@ -605,15 +606,15 @@ mod tests { let e: u8 = e.decrypt(&ck); assert_eq!(e, 3); - assert!(expander.get::(5).is_none()); + assert!(expander.get::(5).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } } @@ -668,15 +669,15 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(expander.get::(4).is_none()); + assert!(expander.get::(4).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } let unverified_expander = compact_list.expand_without_verification().unwrap(); @@ -696,7 +697,7 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(unverified_expander.get::(4).is_none()); + assert!(unverified_expander.get::(4).unwrap().is_none()); } } @@ -757,15 +758,15 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(expander.get::(4).is_none()); + assert!(expander.get::(4).unwrap().is_none()); } { // Incorrect type - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); // Correct type but wrong number of bits - assert!(expander.get::(0).unwrap().is_err()); + assert!(expander.get::(0).is_err()); } let unverified_expander = compact_list.expand_without_verification().unwrap(); @@ -785,7 +786,7 @@ mod tests { let d: u8 = d.decrypt(&ck); assert_eq!(d, 3); - assert!(unverified_expander.get::(4).is_none()); + assert!(unverified_expander.get::(4).unwrap().is_none()); } } } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 83e215344..2bb06ff95 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -12,7 +12,7 @@ use crate::integer::gpu::ciphertext::compressed_ciphertext_list::{ CudaCompressible, CudaExpandable, }; use crate::named::Named; -use crate::prelude::Tagged; +use crate::prelude::{CiphertextList, Tagged}; use crate::shortint::Ciphertext; use crate::{FheBool, FheInt, FheUint, Tag}; @@ -233,8 +233,8 @@ impl Tagged for CompressedCiphertextList { } } -impl CompressedCiphertextList { - pub fn len(&self) -> usize { +impl CiphertextList for CompressedCiphertextList { + fn len(&self) -> usize { match &self.inner { InnerCompressedCiphertextList::Cpu(inner) => inner.len(), #[cfg(feature = "gpu")] @@ -242,7 +242,7 @@ impl CompressedCiphertextList { } } - pub fn is_empty(&self) -> bool { + fn is_empty(&self) -> bool { match &self.inner { InnerCompressedCiphertextList::Cpu(inner) => inner.len() == 0, #[cfg(feature = "gpu")] @@ -250,7 +250,7 @@ impl CompressedCiphertextList { } } - pub fn get_kind_of(&self, index: usize) -> Option { + fn get_kind_of(&self, index: usize) -> Option { match &self.inner { InnerCompressedCiphertextList::Cpu(inner) => Some(match inner.get_kind_of(index)? { DataKind::Unsigned(n) => { @@ -342,7 +342,7 @@ impl CompressedCiphertextList { } } - pub fn get(&self, index: usize) -> crate::Result> + fn get(&self, index: usize) -> crate::Result> where T: HlExpandable + Tagged, { @@ -394,7 +394,9 @@ impl CompressedCiphertextList { } } } +} +impl CompressedCiphertextList { pub fn into_raw_parts(self) -> (crate::integer::ciphertext::CompressedCiphertextList, Tag) { let Self { inner, tag } = self; match inner { diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs index 5baf64ad9..128cea0a9 100644 --- a/tfhe/src/high_level_api/prelude.rs +++ b/tfhe/src/high_level_api/prelude.rs @@ -7,10 +7,10 @@ //! use tfhe::prelude::*; //! ``` pub use crate::high_level_api::traits::{ - BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin, - FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse, - OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight, - RotateRightAssign, Tagged, + BitSlice, CiphertextList, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, + FheMax, FheMin, FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, + FheTryTrivialEncrypt, IfThenElse, OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, + RotateLeftAssign, RotateRight, RotateRightAssign, Tagged, }; pub use crate::conformance::ParameterSetConformant; diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs index 850ec3c95..d5b7f2611 100644 --- a/tfhe/src/high_level_api/traits.rs +++ b/tfhe/src/high_level_api/traits.rs @@ -4,6 +4,8 @@ use crate::error::InvalidRangeError; use crate::high_level_api::ClientKey; use crate::{FheBool, Tag}; +use super::compressed_ciphertext_list::HlExpandable; + /// Trait used to have a generic way of creating a value of a FHE type /// from a native value. /// @@ -199,3 +201,12 @@ pub trait Tagged { fn tag_mut(&mut self) -> &mut Tag; } + +pub trait CiphertextList { + fn len(&self) -> usize; + fn is_empty(&self) -> bool; + fn get_kind_of(&self, index: usize) -> Option; + fn get(&self, index: usize) -> crate::Result> + where + T: HlExpandable + Tagged; +} diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index b77fb963e..33d730d54 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -287,12 +287,13 @@ impl CompactCiphertextListExpander { .map(|block| (block, current_info)) } - pub fn get(&self, index: usize) -> Option> + pub fn get(&self, index: usize) -> crate::Result> where T: Expandable, { self.blocks_of(index) .map(|(blocks, kind)| T::from_expanded_blocks(blocks.to_owned(), kind)) + .transpose() } pub(crate) fn message_modulus(&self) -> MessageModulus { diff --git a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs index bfec4879a..6168e9e73 100644 --- a/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs +++ b/tfhe/src/js_on_wasm_api/js_high_level_api/integers.rs @@ -1051,12 +1051,13 @@ macro_rules! define_expander_get_method { #[wasm_bindgen] pub fn [] (&mut self, index: usize) -> Result<[], JsError> { catch_panic_result(|| { - self.0.get::]>(index) - .map_or_else( - || Err(JsError::new(&format!("Index {index} is out of bounds"))), - |a| a.map_err(into_js_error), - ) - .map([]) + self.0.get::]>(index) + .map_err(into_js_error) + .map(|val| + val.map_or_else( + || Err(JsError::new(&format!("Index {index} is out of bounds"))), + |val| Ok([](val)) + ))? }) } )* @@ -1077,11 +1078,12 @@ macro_rules! define_expander_get_method { pub fn [] (&mut self, index: usize) -> Result<[], JsError> { catch_panic_result(|| { self.0.get::]>(index) - .map_or_else( - || Err(JsError::new(&format!("Index {index} is out of bounds"))), - |a| a.map_err(into_js_error), - ) - .map([]) + .map_err(into_js_error) + .map(|val| + val.map_or_else( + || Err(JsError::new(&format!("Index {index} is out of bounds"))), + |val| Ok([](val)) + ))? }) } )* @@ -1103,11 +1105,13 @@ impl CompactCiphertextListExpander { catch_panic_result(|| { self.0 .get::(index) - .map_or_else( - || Err(JsError::new(&format!("Index {index} is out of bounds"))), - |a| a.map_err(into_js_error), - ) - .map(FheBool) + .map_err(into_js_error) + .map(|val| { + val.map_or_else( + || Err(JsError::new(&format!("Index {index} is out of bounds"))), + |val| Ok(FheBool(val)), + ) + })? }) } diff --git a/tfhe/tests/backward_compatibility/high_level_api.rs b/tfhe/tests/backward_compatibility/high_level_api.rs index 86385d32e..5422426e7 100644 --- a/tfhe/tests/backward_compatibility/high_level_api.rs +++ b/tfhe/tests/backward_compatibility/high_level_api.rs @@ -6,7 +6,7 @@ use tfhe::backward_compatibility::integers::{ CompactFheInt8, CompactFheInt8List, CompactFheUint8, CompactFheUint8List, }; -use tfhe::prelude::{FheDecrypt, FheEncrypt}; +use tfhe::prelude::{CiphertextList, FheDecrypt, FheEncrypt}; use tfhe::shortint::PBSParameters; use tfhe::{ set_server_key, ClientKey, CompactCiphertextList, CompressedCiphertextList, @@ -261,6 +261,7 @@ pub fn test_hl_bool_ciphertext_list( /// Test HL ciphertext list: loads the ciphertext list and compare the decrypted values to the ones /// in the metadata. + pub fn test_hl_heterogeneous_ciphertext_list( dir: &Path, test: &HlHeterogeneousCiphertextListTest, @@ -276,33 +277,25 @@ pub fn test_hl_heterogeneous_ciphertext_list( set_server_key(server_key); if test.compressed { - test_hl_heterogeneous_ciphertext_list_compressed( - load_and_unversionize(dir, test, format)?, - &key, - test, - ) + let list: CompressedCiphertextList = load_and_unversionize(dir, test, format)?; + test_hl_heterogeneous_ciphertext_list_elements(list, &key, test) } else { - test_hl_heterogeneous_ciphertext_list_compact( - load_and_unversionize(dir, test, format)?, - &key, - test, - ) + let list: CompactCiphertextList = load_and_unversionize(dir, test, format)?; + test_hl_heterogeneous_ciphertext_list_elements(list.expand().unwrap(), &key, test) } .map(|_| test.success(format)) .map_err(|msg| test.failure(msg, format)) } -pub fn test_hl_heterogeneous_ciphertext_list_compact( - list: CompactCiphertextList, +pub fn test_hl_heterogeneous_ciphertext_list_elements( + list: CtList, key: &ClientKey, test: &HlHeterogeneousCiphertextListTest, ) -> Result<(), String> { - let ct_list = list.expand().unwrap(); - - for idx in 0..(ct_list.len()) { + for idx in 0..(list.len()) { match test.data_kinds[idx] { DataKind::Bool => { - let ct: FheBool = ct_list.get(idx).unwrap().unwrap(); + let ct: FheBool = list.get(idx).unwrap().unwrap(); let clear = ct.decrypt(key); if clear != (test.clear_values[idx] != 0) { return Err(format!( @@ -312,7 +305,7 @@ pub fn test_hl_heterogeneous_ciphertext_list_compact( } } DataKind::Signed => { - let ct: FheInt8 = ct_list.get(idx).unwrap().unwrap(); + let ct: FheInt8 = list.get(idx).unwrap().unwrap(); let clear: i8 = ct.decrypt(key); if clear != test.clear_values[idx] as i8 { return Err(format!( @@ -323,52 +316,7 @@ pub fn test_hl_heterogeneous_ciphertext_list_compact( } } DataKind::Unsigned => { - let ct: FheUint8 = ct_list.get(idx).unwrap().unwrap(); - let clear: u8 = ct.decrypt(key); - if clear != test.clear_values[idx] as u8 { - return Err(format!( - "Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}", - clear, test.clear_values[idx] - )); - } - } - }; - } - Ok(()) -} - -pub fn test_hl_heterogeneous_ciphertext_list_compressed( - list: CompressedCiphertextList, - key: &ClientKey, - test: &HlHeterogeneousCiphertextListTest, -) -> Result<(), String> { - let ct_list = list; - - for idx in 0..(ct_list.len()) { - match test.data_kinds[idx] { - DataKind::Bool => { - let ct: FheBool = ct_list.get(idx).unwrap().unwrap(); - let clear = ct.decrypt(key); - if clear != (test.clear_values[idx] != 0) { - return Err(format!( - "Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}", - clear, test.clear_values[idx] - )); - } - } - DataKind::Signed => { - let ct: FheInt8 = ct_list.get(idx).unwrap().unwrap(); - let clear: i8 = ct.decrypt(key); - if clear != test.clear_values[idx] as i8 { - return Err(format!( - "Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}", - clear, - (test.clear_values[idx] as i8) - )); - } - } - DataKind::Unsigned => { - let ct: FheUint8 = ct_list.get(idx).unwrap().unwrap(); + let ct: FheUint8 = list.get(idx).unwrap().unwrap(); let clear: u8 = ct.decrypt(key); if clear != test.clear_values[idx] as u8 { return Err(format!(