feat(tfhe): add WASM and C API bindings and tests

This commit is contained in:
Arthur Meyre
2023-01-12 13:55:02 +01:00
parent 384850f7fa
commit 5945a52eba
13 changed files with 370 additions and 8 deletions

View File

@@ -11,6 +11,9 @@ void test_default_keygen_w_serde(void) {
BooleanCiphertext *ct = NULL;
Buffer ct_ser_buffer = {.pointer = NULL, .length = 0};
BooleanCiphertext *deser_ct = NULL;
BooleanCompressedCiphertext *cct = NULL;
BooleanCompressedCiphertext *deser_cct = NULL;
BooleanCiphertext *decompressed_ct = NULL;
int gen_keys_ok = boolean_gen_keys_with_default_parameters(&cks, &sks);
assert(gen_keys_ok == 0);
@@ -37,10 +40,34 @@ void test_default_keygen_w_serde(void) {
assert(result == true);
int c_encrypt_ok = boolean_client_key_encrypt_compressed(cks, true, &cct);
assert(c_encrypt_ok == 0);
int c_ser_ok = boolean_serialize_compressed_ciphertext(cct, &ct_ser_buffer);
assert(c_ser_ok == 0);
deser_view.pointer = ct_ser_buffer.pointer;
deser_view.length = ct_ser_buffer.length;
int c_deser_ok = boolean_deserialize_compressed_ciphertext(deser_view, &deser_cct);
assert(c_deser_ok == 0);
int decomp_ok = boolean_decompress_ciphertext(cct, &decompressed_ct);
assert(decomp_ok == 0);
bool c_result = false;
int c_decrypt_ok = boolean_client_key_decrypt(cks, decompressed_ct, &c_result);
assert(c_decrypt_ok == 0);
assert(c_result == true);
destroy_boolean_client_key(cks);
destroy_boolean_server_key(sks);
destroy_boolean_ciphertext(ct);
destroy_boolean_ciphertext(deser_ct);
destroy_boolean_compressed_ciphertext(cct);
destroy_boolean_compressed_ciphertext(deser_cct);
destroy_boolean_ciphertext(decompressed_ct);
destroy_buffer(&ct_ser_buffer);
}

View File

@@ -12,6 +12,9 @@ void test_predefined_keygen_w_serde(void) {
ShortintCiphertext *ct = NULL;
Buffer ct_ser_buffer = {.pointer = NULL, .length = 0};
ShortintCiphertext *deser_ct = NULL;
ShortintCompressedCiphertext *cct = NULL;
ShortintCompressedCiphertext *deser_cct = NULL;
ShortintCiphertext *decompressed_ct = NULL;
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
@@ -41,11 +44,35 @@ void test_predefined_keygen_w_serde(void) {
assert(result == 3);
int c_encrypt_ok = shortint_client_key_encrypt_compressed(cks, 3, &cct);
assert(c_encrypt_ok == 0);
int c_ser_ok = shortint_serialize_compressed_ciphertext(cct, &ct_ser_buffer);
assert(c_ser_ok == 0);
deser_view.pointer = ct_ser_buffer.pointer;
deser_view.length = ct_ser_buffer.length;
int c_deser_ok = shortint_deserialize_compressed_ciphertext(deser_view, &deser_cct);
assert(c_deser_ok == 0);
int decomp_ok = shortint_decompress_ciphertext(cct, &decompressed_ct);
assert(decomp_ok == 0);
uint64_t c_result = -1;
int c_decrypt_ok = shortint_client_key_decrypt(cks, decompressed_ct, &c_result);
assert(c_decrypt_ok == 0);
assert(c_result == 3);
destroy_shortint_client_key(cks);
destroy_shortint_server_key(sks);
destroy_shortint_parameters(params);
destroy_shortint_ciphertext(ct);
destroy_shortint_ciphertext(deser_ct);
destroy_shortint_compressed_ciphertext(cct);
destroy_shortint_compressed_ciphertext(deser_cct);
destroy_shortint_ciphertext(decompressed_ct);
destroy_buffer(&ct_ser_buffer);
}

View File

@@ -30,6 +30,23 @@ test('boolean_encrypt_decrypt', (t) => {
// No equality tests here, as wasm stores pointers which will always differ
});
test('boolean_compressed_encrypt_decrypt', (t) => {
let params = Boolean.get_parameters(BooleanParameterSet.Default);
let cks = Boolean.new_client_key(params);
let ct = Boolean.encrypt_compressed(cks, true);
let serialized_cks = Boolean.serialize_client_key(cks);
let deserialized_cks = Boolean.deserialize_client_key(serialized_cks);
let serialized_ct = Boolean.serialize_compressed_ciphertext(ct);
let deserialized_ct = Boolean.deserialize_compressed_ciphertext(serialized_ct);
let decompressed_ct = Boolean.decompress_ciphertext(deserialized_ct);
let decrypted = Boolean.decrypt(deserialized_cks, decompressed_ct);
assert.deepStrictEqual(decrypted, true);
});
test('boolean_public_encrypt_decrypt', (t) => {
let params = Boolean.get_parameters(BooleanParameterSet.Default);
let cks = Boolean.new_client_key(params);
@@ -92,6 +109,23 @@ test('shortint_encrypt_decrypt', (t) => {
// No equality tests here, as wasm stores pointers which will always differ
});
test('shortint_compressed_encrypt_decrypt', (t) => {
let params = Shortint.get_parameters(2, 2);
let cks = Shortint.new_client_key(params);
let ct = Shortint.encrypt_compressed(cks, BigInt(3));
let serialized_cks = Shortint.serialize_client_key(cks);
let deserialized_cks = Shortint.deserialize_client_key(serialized_cks);
let serialized_ct = Shortint.serialize_compressed_ciphertext(ct);
let deserialized_ct = Shortint.deserialize_compressed_ciphertext(serialized_ct);
let decompressed_ct = Shortint.decompress_ciphertext(deserialized_ct);
let decrypted = Shortint.decrypt(deserialized_cks, decompressed_ct);
assert.deepStrictEqual(decrypted, BigInt(3));
});
test('shortint_public_encrypt_decrypt', (t) => {
let params = Shortint.get_parameters(2, 0);
let cks = Shortint.new_client_key(params);

View File

@@ -6,6 +6,10 @@ use crate::boolean;
pub struct BooleanCiphertext(pub(in crate::c_api) boolean::ciphertext::Ciphertext);
pub struct BooleanCompressedCiphertext(
pub(in crate::c_api) boolean::ciphertext::CompressedCiphertext,
);
#[no_mangle]
pub unsafe extern "C" fn boolean_serialize_ciphertext(
ciphertext: *const BooleanCiphertext,
@@ -42,3 +46,62 @@ pub unsafe extern "C" fn boolean_deserialize_ciphertext(
*result = Box::into_raw(heap_allocated_ciphertext);
})
}
#[no_mangle]
pub unsafe extern "C" fn boolean_decompress_ciphertext(
compressed_ciphertext: *mut BooleanCompressedCiphertext,
result: *mut *mut BooleanCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let compressed_ciphertext = get_mut_checked(compressed_ciphertext).unwrap();
let ciphertext = compressed_ciphertext.0.clone().into();
let heap_allocated_ciphertext = Box::new(BooleanCiphertext(ciphertext));
*result = Box::into_raw(heap_allocated_ciphertext);
})
}
#[no_mangle]
pub unsafe extern "C" fn boolean_serialize_compressed_ciphertext(
ciphertext: *const BooleanCompressedCiphertext,
result: *mut Buffer,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
let ciphertext = get_ref_checked(ciphertext).unwrap();
let buffer: Buffer = bincode::serialize(&ciphertext.0).unwrap().into();
*result = buffer;
})
}
#[no_mangle]
pub unsafe extern "C" fn boolean_deserialize_compressed_ciphertext(
buffer_view: BufferView,
result: *mut *mut BooleanCompressedCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let ciphertext: boolean::ciphertext::CompressedCiphertext =
bincode::deserialize(buffer_view.into()).unwrap();
let heap_allocated_ciphertext = Box::new(BooleanCompressedCiphertext(ciphertext));
*result = Box::into_raw(heap_allocated_ciphertext);
})
}

View File

@@ -5,7 +5,7 @@ use std::os::raw::c_int;
use crate::boolean;
use super::BooleanCiphertext;
use super::{BooleanCiphertext, BooleanCompressedCiphertext};
pub struct BooleanClientKey(pub(in crate::c_api) boolean::client_key::ClientKey);
#[no_mangle]
@@ -52,6 +52,29 @@ pub unsafe extern "C" fn boolean_client_key_encrypt(
})
}
#[no_mangle]
pub unsafe extern "C" fn boolean_client_key_encrypt_compressed(
client_key: *const BooleanClientKey,
value_to_encrypt: bool,
result: *mut *mut BooleanCompressedCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let client_key = get_ref_checked(client_key).unwrap();
let heap_allocated_ciphertext = Box::new(BooleanCompressedCiphertext(
client_key.0.encrypt_compressed(value_to_encrypt),
));
*result = Box::into_raw(heap_allocated_ciphertext);
})
}
#[no_mangle]
pub unsafe extern "C" fn boolean_client_key_decrypt(
client_key: *const BooleanClientKey,

View File

@@ -3,8 +3,8 @@ use std::os::raw::c_int;
use super::parameters::BooleanParameters;
use super::{
BooleanCiphertext, BooleanClientKey, BooleanCompressedServerKey, BooleanPublicKey,
BooleanServerKey,
BooleanCiphertext, BooleanClientKey, BooleanCompressedCiphertext, BooleanCompressedServerKey,
BooleanPublicKey, BooleanServerKey,
};
#[no_mangle]
@@ -66,3 +66,14 @@ pub unsafe extern "C" fn destroy_boolean_ciphertext(
drop(Box::from_raw(boolean_ciphertext));
})
}
#[no_mangle]
pub unsafe extern "C" fn destroy_boolean_compressed_ciphertext(
boolean_ciphertext: *mut BooleanCompressedCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(boolean_ciphertext).unwrap();
drop(Box::from_raw(boolean_ciphertext));
})
}

View File

@@ -10,7 +10,7 @@ use std::os::raw::c_int;
use crate::boolean;
pub use ciphertext::BooleanCiphertext;
pub use ciphertext::{BooleanCiphertext, BooleanCompressedCiphertext};
pub use client_key::BooleanClientKey;
pub use public_key::BooleanPublicKey;
pub use server_key::{BooleanCompressedServerKey, BooleanServerKey};

View File

@@ -5,6 +5,9 @@ use std::os::raw::c_int;
use crate::shortint;
pub struct ShortintCiphertext(pub(in crate::c_api) shortint::ciphertext::Ciphertext);
pub struct ShortintCompressedCiphertext(
pub(in crate::c_api) shortint::ciphertext::CompressedCiphertext,
);
#[no_mangle]
pub unsafe extern "C" fn shortint_ciphertext_set_degree(
@@ -68,3 +71,62 @@ pub unsafe extern "C" fn shortint_deserialize_ciphertext(
*result = Box::into_raw(heap_allocated_ciphertext);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_decompress_ciphertext(
compressed_ciphertext: *mut ShortintCompressedCiphertext,
result: *mut *mut ShortintCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let compressed_ciphertext = get_mut_checked(compressed_ciphertext).unwrap();
let ciphertext = compressed_ciphertext.0.clone().into();
let heap_allocated_ciphertext = Box::new(ShortintCiphertext(ciphertext));
*result = Box::into_raw(heap_allocated_ciphertext);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_serialize_compressed_ciphertext(
ciphertext: *const ShortintCompressedCiphertext,
result: *mut Buffer,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
let ciphertext = get_ref_checked(ciphertext).unwrap();
let buffer: Buffer = bincode::serialize(&ciphertext.0).unwrap().into();
*result = buffer;
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_deserialize_compressed_ciphertext(
buffer_view: BufferView,
result: *mut *mut ShortintCompressedCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let ciphertext: shortint::ciphertext::CompressedCiphertext =
bincode::deserialize(buffer_view.into()).unwrap();
let heap_allocated_ciphertext = Box::new(ShortintCompressedCiphertext(ciphertext));
*result = Box::into_raw(heap_allocated_ciphertext);
})
}

View File

@@ -5,7 +5,7 @@ use std::os::raw::c_int;
use crate::shortint;
use super::ShortintCiphertext;
use super::{ShortintCiphertext, ShortintCompressedCiphertext};
pub struct ShortintClientKey(pub(in crate::c_api) shortint::client_key::ClientKey);
#[no_mangle]
@@ -52,6 +52,29 @@ pub unsafe extern "C" fn shortint_client_key_encrypt(
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_client_key_encrypt_compressed(
client_key: *const ShortintClientKey,
value_to_encrypt: u64,
result: *mut *mut ShortintCompressedCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let client_key = get_ref_checked(client_key).unwrap();
let heap_allocated_ciphertext = Box::new(ShortintCompressedCiphertext(
client_key.0.encrypt_compressed(value_to_encrypt),
));
*result = Box::into_raw(heap_allocated_ciphertext);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_client_key_decrypt(
client_key: *const ShortintClientKey,

View File

@@ -4,8 +4,8 @@ use std::os::raw::c_int;
use super::parameters::ShortintParameters;
use super::{
ShortintBivariatePBSAccumulator, ShortintCiphertext, ShortintClientKey,
ShortintCompressedPublicKey, ShortintCompressedServerKey, ShortintPBSAccumulator,
ShortintPublicKey, ShortintServerKey,
ShortintCompressedCiphertext, ShortintCompressedPublicKey, ShortintCompressedServerKey,
ShortintPBSAccumulator, ShortintPublicKey, ShortintServerKey,
};
#[no_mangle]
@@ -79,6 +79,17 @@ pub unsafe extern "C" fn destroy_shortint_ciphertext(
})
}
#[no_mangle]
pub unsafe extern "C" fn destroy_shortint_compressed_ciphertext(
shortint_ciphertext: *mut ShortintCompressedCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(shortint_ciphertext).unwrap();
drop(Box::from_raw(shortint_ciphertext));
})
}
#[no_mangle]
pub unsafe extern "C" fn destroy_shortint_pbs_accumulator(
pbs_accumulator: *mut ShortintPBSAccumulator,

View File

@@ -10,7 +10,7 @@ use std::os::raw::c_int;
use crate::shortint;
pub use ciphertext::ShortintCiphertext;
pub use ciphertext::{ShortintCiphertext, ShortintCompressedCiphertext};
pub use client_key::ShortintClientKey;
pub use public_key::{ShortintCompressedPublicKey, ShortintPublicKey};
pub use server_key::pbs::{ShortintBivariatePBSAccumulator, ShortintPBSAccumulator};

View File

@@ -8,6 +8,9 @@ use std::panic::set_hook;
#[wasm_bindgen]
pub struct BooleanCiphertext(pub(crate) crate::boolean::ciphertext::Ciphertext);
#[wasm_bindgen]
pub struct BooleanCompressedCiphertext(pub(crate) crate::boolean::ciphertext::CompressedCiphertext);
#[wasm_bindgen]
pub struct BooleanClientKey(pub(crate) crate::boolean::client_key::ClientKey);
@@ -136,6 +139,23 @@ impl Boolean {
BooleanCiphertext(client_key.0.encrypt(message))
}
#[wasm_bindgen]
pub fn encrypt_compressed(
client_key: &BooleanClientKey,
message: bool,
) -> BooleanCompressedCiphertext {
set_hook(Box::new(console_error_panic_hook::hook));
BooleanCompressedCiphertext(client_key.0.encrypt_compressed(message))
}
#[wasm_bindgen]
pub fn decompress_ciphertext(
compressed_ciphertext: &BooleanCompressedCiphertext,
) -> BooleanCiphertext {
set_hook(Box::new(console_error_panic_hook::hook));
BooleanCiphertext(compressed_ciphertext.0.clone().into())
}
#[wasm_bindgen]
pub fn encrypt_with_public_key(
public_key: &BooleanPublicKey,
@@ -173,6 +193,25 @@ impl Boolean {
.map(BooleanCiphertext)
}
#[wasm_bindgen]
pub fn serialize_compressed_ciphertext(
ciphertext: &BooleanCompressedCiphertext,
) -> Result<Vec<u8>, JsError> {
set_hook(Box::new(console_error_panic_hook::hook));
bincode::serialize(&ciphertext.0)
.map_err(|e| wasm_bindgen::JsError::new(format!("{e:?}").as_str()))
}
#[wasm_bindgen]
pub fn deserialize_compressed_ciphertext(
buffer: &[u8],
) -> Result<BooleanCompressedCiphertext, JsError> {
set_hook(Box::new(console_error_panic_hook::hook));
bincode::deserialize(buffer)
.map_err(|e| wasm_bindgen::JsError::new(format!("{e:?}").as_str()))
.map(BooleanCompressedCiphertext)
}
#[wasm_bindgen]
pub fn serialize_client_key(client_key: &BooleanClientKey) -> Result<Vec<u8>, JsError> {
set_hook(Box::new(console_error_panic_hook::hook));

View File

@@ -8,6 +8,11 @@ use std::panic::set_hook;
#[wasm_bindgen]
pub struct ShortintCiphertext(pub(crate) crate::shortint::ciphertext::Ciphertext);
#[wasm_bindgen]
pub struct ShortintCompressedCiphertext(
pub(crate) crate::shortint::ciphertext::CompressedCiphertext,
);
#[wasm_bindgen]
pub struct ShortintClientKey(pub(crate) crate::shortint::ClientKey);
@@ -192,6 +197,24 @@ impl Shortint {
ShortintCiphertext(client_key.0.encrypt(message))
}
#[wasm_bindgen]
pub fn encrypt_compressed(
client_key: &ShortintClientKey,
message: u64,
) -> ShortintCompressedCiphertext {
set_hook(Box::new(console_error_panic_hook::hook));
ShortintCompressedCiphertext(client_key.0.encrypt_compressed(message))
}
#[wasm_bindgen]
pub fn decompress_ciphertext(
compressed_ciphertext: &ShortintCompressedCiphertext,
) -> ShortintCiphertext {
set_hook(Box::new(console_error_panic_hook::hook));
ShortintCiphertext(compressed_ciphertext.0.clone().into())
}
#[wasm_bindgen]
pub fn encrypt_with_public_key(
public_key: &ShortintPublicKey,
@@ -233,6 +256,25 @@ impl Shortint {
.map(ShortintCiphertext)
}
#[wasm_bindgen]
pub fn serialize_compressed_ciphertext(
ciphertext: &ShortintCompressedCiphertext,
) -> Result<Vec<u8>, JsError> {
set_hook(Box::new(console_error_panic_hook::hook));
bincode::serialize(&ciphertext.0)
.map_err(|e| wasm_bindgen::JsError::new(format!("{e:?}").as_str()))
}
#[wasm_bindgen]
pub fn deserialize_compressed_ciphertext(
buffer: &[u8],
) -> Result<ShortintCompressedCiphertext, JsError> {
set_hook(Box::new(console_error_panic_hook::hook));
bincode::deserialize(buffer)
.map_err(|e| wasm_bindgen::JsError::new(format!("{e:?}").as_str()))
.map(ShortintCompressedCiphertext)
}
#[wasm_bindgen]
pub fn serialize_client_key(client_key: &ShortintClientKey) -> Result<Vec<u8>, JsError> {
set_hook(Box::new(console_error_panic_hook::hook));