refactor(hl)!: use a trait for common ciphertext lists methods

BREAKING CHANGE:
- The `CiphertextList` trait needs to be in scope to use the common methods of
the `CompressedCiphertextList` and `CompactCiphertextListExpander`
- The `.get` of the `CompactCiphertextListExpander` now returns a
`Result<Option>` instead of an `Option<Result>`
This commit is contained in:
Nicolas Sarlin
2024-09-09 18:13:08 +02:00
committed by Nicolas Sarlin
parent e91d532a36
commit 0d49d19a13
12 changed files with 106 additions and 125 deletions

View File

@@ -12,7 +12,7 @@ Using this feature is straightforward: during encryption, the client generates t
```rust
use rand::prelude::*;
use tfhe::prelude::FheDecrypt;
use tfhe::prelude::*;
use tfhe::set_server_key;
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
@@ -45,9 +45,8 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
// Verify the ciphertexts
let expander = proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?;
let a: tfhe::FheUint64 = expander.get(0).unwrap()?;
let b: tfhe::FheUint64 = expander.get(1).unwrap()?;
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
a + b
};
@@ -80,7 +79,7 @@ This works essentially in the same way as before. Additionally, you need to indi
```rust
use rand::prelude::*;
use tfhe::prelude::FheDecrypt;
use tfhe::prelude::*;
use tfhe::set_server_key;
use tfhe::zk::{CompactPkeCrs, ZkComputeLoad};
@@ -119,9 +118,10 @@ pub fn main() -> Result<(), Box<dyn std::error::Error>> {
set_server_key(server_key);
// Verify the ciphertexts
let expander = proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?;
let a: tfhe::FheUint64 = expander.get(0).unwrap()?;
let b: tfhe::FheUint64 = expander.get(1).unwrap()?;
let expander =
proven_compact_list.verify_and_expand(public_zk_params, &public_key, &metadata)?;
let a: tfhe::FheUint64 = expander.get(0)?.unwrap();
let b: tfhe::FheUint64 = expander.get(1)?.unwrap();
a + b
};

View File

@@ -15,6 +15,7 @@ use crate::c_api::high_level_api::utils::{
#[cfg(feature = "zk-pok")]
use crate::c_api::high_level_api::zk::{CompactPkePublicParams, ZkComputeLoad};
use crate::c_api::utils::{catch_panic, get_mut_checked, get_ref_checked};
use crate::prelude::CiphertextList;
use std::ffi::c_int;
pub struct CompactCiphertextListBuilder(crate::high_level_api::CompactCiphertextListBuilder);

View File

@@ -8,6 +8,7 @@ use crate::c_api::high_level_api::utils::{
impl_destroy_on_type, impl_serialize_deserialize_on_type,
};
use crate::c_api::utils::{catch_panic, get_mut_checked, get_ref_checked};
use crate::prelude::CiphertextList;
use std::ffi::c_int;
pub struct CompressedCiphertextListBuilder(crate::high_level_api::CompressedCiphertextListBuilder);

View File

@@ -7,6 +7,7 @@ use crate::high_level_api::booleans::{
InnerBoolean, InnerBooleanVersionOwned, InnerCompressedFheBool,
};
use crate::integer::ciphertext::{CompactCiphertextList, DataKind};
use crate::prelude::CiphertextList;
use crate::{
CompactCiphertextList as HlCompactCiphertextList, CompressedFheBool, Error, FheBool, Tag,
};
@@ -111,7 +112,7 @@ impl CompactFheBool {
let block = list
.inner
.get::<crate::integer::BooleanBlock>(0)
.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??;
.map(|b| b.ok_or_else(|| Error::new("Failed to expand compact list".to_string())))??;
let mut ciphertext = FheBool::new(block, Tag::default());
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
@@ -148,7 +149,9 @@ impl CompactFheBoolList {
let block = list
.inner
.get::<crate::integer::BooleanBlock>(idx)
.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??;
.map(|list| {
list.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))
})??;
let mut ciphertext = FheBool::new(block, Tag::default());
ciphertext.ciphertext.move_to_device_of_server_key_if_set();

View File

@@ -16,6 +16,7 @@ use crate::integer::ciphertext::{
CompressedRadixCiphertext as IntegerCompressedRadixCiphertext,
CompressedSignedRadixCiphertext as IntegerCompressedSignedRadixCiphertext, DataKind,
};
use crate::prelude::CiphertextList;
use crate::shortint::ciphertext::CompressedModulusSwitchedCiphertext;
use crate::shortint::{Ciphertext, ServerKey};
use crate::{CompactCiphertextList as HlCompactCiphertextList, Error, Tag};
@@ -277,7 +278,9 @@ where
let ct = list
.inner
.get::<crate::integer::SignedRadixCiphertext>(0)
.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??;
.map(|list| {
list.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))
})??;
Ok(FheInt::new(ct, Tag::default()))
}
}
@@ -316,7 +319,9 @@ where
let ct = list
.inner
.get::<crate::integer::SignedRadixCiphertext>(idx)
.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??;
.map(|list| {
list.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))
})??;
Ok(FheInt::new(ct, Tag::default()))
})
.collect::<Result<Vec<_>, _>>()
@@ -353,7 +358,9 @@ where
let ct = list
.inner
.get::<crate::integer::RadixCiphertext>(0)
.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??;
.map(|ct| {
ct.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))
})??;
Ok(FheUint::new(ct, Tag::default()))
}
}
@@ -391,7 +398,9 @@ where
let ct = list
.inner
.get::<crate::integer::RadixCiphertext>(idx)
.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))??;
.map(|ct| {
ct.ok_or_else(|| Error::new("Failed to expand compact list".to_string()))
})??;
Ok(FheUint::new(ct, Tag::default()))
})
.collect::<Result<Vec<_>, _>>()

View File

@@ -16,6 +16,7 @@ use crate::integer::parameters::{
IntegerCompactCiphertextListUnpackingMode,
};
use crate::named::Named;
use crate::prelude::CiphertextList;
use crate::shortint::MessageModulus;
#[cfg(feature = "zk-pok")]
pub use zk::ProvenCompactCiphertextList;
@@ -369,27 +370,27 @@ pub struct CompactCiphertextListExpander {
tag: Tag,
}
impl CompactCiphertextListExpander {
pub fn len(&self) -> usize {
impl CiphertextList for CompactCiphertextListExpander {
fn len(&self) -> usize {
self.inner.len()
}
pub fn is_empty(&self) -> bool {
fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
self.inner.get_kind_of(index).and_then(|data_kind| {
crate::FheTypes::from_data_kind(data_kind, self.inner.message_modulus())
})
}
pub fn get<T>(&self, index: usize) -> Option<crate::Result<T>>
fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
where
T: Expandable + Tagged,
{
let mut expanded = self.inner.get::<T>(index);
if let Some(Ok(inner)) = &mut expanded {
if let Ok(Some(inner)) = &mut expanded {
inner.tag_mut().set_data(self.tag.data());
}
expanded
@@ -543,15 +544,15 @@ mod tests {
let e: u8 = e.decrypt(&ck);
assert_eq!(e, 3);
assert!(expander.get::<FheBool>(5).is_none());
assert!(expander.get::<FheBool>(5).unwrap().is_none());
}
{
// Incorrect type
assert!(expander.get::<FheInt64>(0).unwrap().is_err());
assert!(expander.get::<FheInt64>(0).is_err());
// Correct type but wrong number of bits
assert!(expander.get::<FheUint16>(0).unwrap().is_err());
assert!(expander.get::<FheUint16>(0).is_err());
}
}
@@ -605,15 +606,15 @@ mod tests {
let e: u8 = e.decrypt(&ck);
assert_eq!(e, 3);
assert!(expander.get::<FheBool>(5).is_none());
assert!(expander.get::<FheBool>(5).unwrap().is_none());
}
{
// Incorrect type
assert!(expander.get::<FheInt64>(0).unwrap().is_err());
assert!(expander.get::<FheInt64>(0).is_err());
// Correct type but wrong number of bits
assert!(expander.get::<FheUint16>(0).unwrap().is_err());
assert!(expander.get::<FheUint16>(0).is_err());
}
}
@@ -668,15 +669,15 @@ mod tests {
let d: u8 = d.decrypt(&ck);
assert_eq!(d, 3);
assert!(expander.get::<FheBool>(4).is_none());
assert!(expander.get::<FheBool>(4).unwrap().is_none());
}
{
// Incorrect type
assert!(expander.get::<FheInt64>(0).unwrap().is_err());
assert!(expander.get::<FheInt64>(0).is_err());
// Correct type but wrong number of bits
assert!(expander.get::<FheUint16>(0).unwrap().is_err());
assert!(expander.get::<FheUint16>(0).is_err());
}
let unverified_expander = compact_list.expand_without_verification().unwrap();
@@ -696,7 +697,7 @@ mod tests {
let d: u8 = d.decrypt(&ck);
assert_eq!(d, 3);
assert!(unverified_expander.get::<FheBool>(4).is_none());
assert!(unverified_expander.get::<FheBool>(4).unwrap().is_none());
}
}
@@ -757,15 +758,15 @@ mod tests {
let d: u8 = d.decrypt(&ck);
assert_eq!(d, 3);
assert!(expander.get::<FheBool>(4).is_none());
assert!(expander.get::<FheBool>(4).unwrap().is_none());
}
{
// Incorrect type
assert!(expander.get::<FheInt64>(0).unwrap().is_err());
assert!(expander.get::<FheInt64>(0).is_err());
// Correct type but wrong number of bits
assert!(expander.get::<FheUint16>(0).unwrap().is_err());
assert!(expander.get::<FheUint16>(0).is_err());
}
let unverified_expander = compact_list.expand_without_verification().unwrap();
@@ -785,7 +786,7 @@ mod tests {
let d: u8 = d.decrypt(&ck);
assert_eq!(d, 3);
assert!(unverified_expander.get::<FheBool>(4).is_none());
assert!(unverified_expander.get::<FheBool>(4).unwrap().is_none());
}
}
}

View File

@@ -12,7 +12,7 @@ use crate::integer::gpu::ciphertext::compressed_ciphertext_list::{
CudaCompressible, CudaExpandable,
};
use crate::named::Named;
use crate::prelude::Tagged;
use crate::prelude::{CiphertextList, Tagged};
use crate::shortint::Ciphertext;
use crate::{FheBool, FheInt, FheUint, Tag};
@@ -233,8 +233,8 @@ impl Tagged for CompressedCiphertextList {
}
}
impl CompressedCiphertextList {
pub fn len(&self) -> usize {
impl CiphertextList for CompressedCiphertextList {
fn len(&self) -> usize {
match &self.inner {
InnerCompressedCiphertextList::Cpu(inner) => inner.len(),
#[cfg(feature = "gpu")]
@@ -242,7 +242,7 @@ impl CompressedCiphertextList {
}
}
pub fn is_empty(&self) -> bool {
fn is_empty(&self) -> bool {
match &self.inner {
InnerCompressedCiphertextList::Cpu(inner) => inner.len() == 0,
#[cfg(feature = "gpu")]
@@ -250,7 +250,7 @@ impl CompressedCiphertextList {
}
}
pub fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes> {
match &self.inner {
InnerCompressedCiphertextList::Cpu(inner) => Some(match inner.get_kind_of(index)? {
DataKind::Unsigned(n) => {
@@ -342,7 +342,7 @@ impl CompressedCiphertextList {
}
}
pub fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
where
T: HlExpandable + Tagged,
{
@@ -394,7 +394,9 @@ impl CompressedCiphertextList {
}
}
}
}
impl CompressedCiphertextList {
pub fn into_raw_parts(self) -> (crate::integer::ciphertext::CompressedCiphertextList, Tag) {
let Self { inner, tag } = self;
match inner {

View File

@@ -7,10 +7,10 @@
//! use tfhe::prelude::*;
//! ```
pub use crate::high_level_api::traits::{
BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse,
OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign, Tagged,
BitSlice, CiphertextList, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch,
FheMax, FheMin, FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt,
FheTryTrivialEncrypt, IfThenElse, OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft,
RotateLeftAssign, RotateRight, RotateRightAssign, Tagged,
};
pub use crate::conformance::ParameterSetConformant;

View File

@@ -4,6 +4,8 @@ use crate::error::InvalidRangeError;
use crate::high_level_api::ClientKey;
use crate::{FheBool, Tag};
use super::compressed_ciphertext_list::HlExpandable;
/// Trait used to have a generic way of creating a value of a FHE type
/// from a native value.
///
@@ -199,3 +201,12 @@ pub trait Tagged {
fn tag_mut(&mut self) -> &mut Tag;
}
pub trait CiphertextList {
fn len(&self) -> usize;
fn is_empty(&self) -> bool;
fn get_kind_of(&self, index: usize) -> Option<crate::FheTypes>;
fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
where
T: HlExpandable + Tagged;
}

View File

@@ -287,12 +287,13 @@ impl CompactCiphertextListExpander {
.map(|block| (block, current_info))
}
pub fn get<T>(&self, index: usize) -> Option<crate::Result<T>>
pub fn get<T>(&self, index: usize) -> crate::Result<Option<T>>
where
T: Expandable,
{
self.blocks_of(index)
.map(|(blocks, kind)| T::from_expanded_blocks(blocks.to_owned(), kind))
.transpose()
}
pub(crate) fn message_modulus(&self) -> MessageModulus {

View File

@@ -1051,12 +1051,13 @@ macro_rules! define_expander_get_method {
#[wasm_bindgen]
pub fn [<get_uint $num_bits>] (&mut self, index: usize) -> Result<[<FheUint $num_bits>], JsError> {
catch_panic_result(|| {
self.0.get::<crate::[<FheUint $num_bits>]>(index)
.map_or_else(
|| Err(JsError::new(&format!("Index {index} is out of bounds"))),
|a| a.map_err(into_js_error),
)
.map([<FheUint $num_bits>])
self.0.get::<crate::[<FheUint $num_bits>]>(index)
.map_err(into_js_error)
.map(|val|
val.map_or_else(
|| Err(JsError::new(&format!("Index {index} is out of bounds"))),
|val| Ok([<FheUint $num_bits>](val))
))?
})
}
)*
@@ -1077,11 +1078,12 @@ macro_rules! define_expander_get_method {
pub fn [<get_int $num_bits>] (&mut self, index: usize) -> Result<[<FheInt $num_bits>], JsError> {
catch_panic_result(|| {
self.0.get::<crate::[<FheInt $num_bits>]>(index)
.map_or_else(
|| Err(JsError::new(&format!("Index {index} is out of bounds"))),
|a| a.map_err(into_js_error),
)
.map([<FheInt $num_bits>])
.map_err(into_js_error)
.map(|val|
val.map_or_else(
|| Err(JsError::new(&format!("Index {index} is out of bounds"))),
|val| Ok([<FheInt $num_bits>](val))
))?
})
}
)*
@@ -1103,11 +1105,13 @@ impl CompactCiphertextListExpander {
catch_panic_result(|| {
self.0
.get::<crate::FheBool>(index)
.map_or_else(
|| Err(JsError::new(&format!("Index {index} is out of bounds"))),
|a| a.map_err(into_js_error),
)
.map(FheBool)
.map_err(into_js_error)
.map(|val| {
val.map_or_else(
|| Err(JsError::new(&format!("Index {index} is out of bounds"))),
|val| Ok(FheBool(val)),
)
})?
})
}

View File

@@ -6,7 +6,7 @@ use tfhe::backward_compatibility::integers::{
CompactFheInt8, CompactFheInt8List, CompactFheUint8, CompactFheUint8List,
};
use tfhe::prelude::{FheDecrypt, FheEncrypt};
use tfhe::prelude::{CiphertextList, FheDecrypt, FheEncrypt};
use tfhe::shortint::PBSParameters;
use tfhe::{
set_server_key, ClientKey, CompactCiphertextList, CompressedCiphertextList,
@@ -261,6 +261,7 @@ pub fn test_hl_bool_ciphertext_list(
/// Test HL ciphertext list: loads the ciphertext list and compare the decrypted values to the ones
/// in the metadata.
pub fn test_hl_heterogeneous_ciphertext_list(
dir: &Path,
test: &HlHeterogeneousCiphertextListTest,
@@ -276,33 +277,25 @@ pub fn test_hl_heterogeneous_ciphertext_list(
set_server_key(server_key);
if test.compressed {
test_hl_heterogeneous_ciphertext_list_compressed(
load_and_unversionize(dir, test, format)?,
&key,
test,
)
let list: CompressedCiphertextList = load_and_unversionize(dir, test, format)?;
test_hl_heterogeneous_ciphertext_list_elements(list, &key, test)
} else {
test_hl_heterogeneous_ciphertext_list_compact(
load_and_unversionize(dir, test, format)?,
&key,
test,
)
let list: CompactCiphertextList = load_and_unversionize(dir, test, format)?;
test_hl_heterogeneous_ciphertext_list_elements(list.expand().unwrap(), &key, test)
}
.map(|_| test.success(format))
.map_err(|msg| test.failure(msg, format))
}
pub fn test_hl_heterogeneous_ciphertext_list_compact(
list: CompactCiphertextList,
pub fn test_hl_heterogeneous_ciphertext_list_elements<CtList: CiphertextList>(
list: CtList,
key: &ClientKey,
test: &HlHeterogeneousCiphertextListTest,
) -> Result<(), String> {
let ct_list = list.expand().unwrap();
for idx in 0..(ct_list.len()) {
for idx in 0..(list.len()) {
match test.data_kinds[idx] {
DataKind::Bool => {
let ct: FheBool = ct_list.get(idx).unwrap().unwrap();
let ct: FheBool = list.get(idx).unwrap().unwrap();
let clear = ct.decrypt(key);
if clear != (test.clear_values[idx] != 0) {
return Err(format!(
@@ -312,7 +305,7 @@ pub fn test_hl_heterogeneous_ciphertext_list_compact(
}
}
DataKind::Signed => {
let ct: FheInt8 = ct_list.get(idx).unwrap().unwrap();
let ct: FheInt8 = list.get(idx).unwrap().unwrap();
let clear: i8 = ct.decrypt(key);
if clear != test.clear_values[idx] as i8 {
return Err(format!(
@@ -323,52 +316,7 @@ pub fn test_hl_heterogeneous_ciphertext_list_compact(
}
}
DataKind::Unsigned => {
let ct: FheUint8 = ct_list.get(idx).unwrap().unwrap();
let clear: u8 = ct.decrypt(key);
if clear != test.clear_values[idx] as u8 {
return Err(format!(
"Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}",
clear, test.clear_values[idx]
));
}
}
};
}
Ok(())
}
pub fn test_hl_heterogeneous_ciphertext_list_compressed(
list: CompressedCiphertextList,
key: &ClientKey,
test: &HlHeterogeneousCiphertextListTest,
) -> Result<(), String> {
let ct_list = list;
for idx in 0..(ct_list.len()) {
match test.data_kinds[idx] {
DataKind::Bool => {
let ct: FheBool = ct_list.get(idx).unwrap().unwrap();
let clear = ct.decrypt(key);
if clear != (test.clear_values[idx] != 0) {
return Err(format!(
"Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}",
clear, test.clear_values[idx]
));
}
}
DataKind::Signed => {
let ct: FheInt8 = ct_list.get(idx).unwrap().unwrap();
let clear: i8 = ct.decrypt(key);
if clear != test.clear_values[idx] as i8 {
return Err(format!(
"Invalid decrypted cleartext:\n Expected :\n{:?}\nGot:\n{:?}",
clear,
(test.clear_values[idx] as i8)
));
}
}
DataKind::Unsigned => {
let ct: FheUint8 = ct_list.get(idx).unwrap().unwrap();
let ct: FheUint8 = list.get(idx).unwrap().unwrap();
let clear: u8 = ct.decrypt(key);
if clear != test.clear_values[idx] as u8 {
return Err(format!(