From 6948d91cc6e82cccdaeaa89b2544bbcba588f9fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Mon, 28 Apr 2025 14:28:02 +0200 Subject: [PATCH] feat(frontend-rust): support interop with tfhers ciphertexts --- .../include/concretelang/Common/Values.h | 2 +- .../Python/concrete/compiler/tfhers_int.py | 1 + .../compiler/lib/Bindings/Rust/CMakeLists.txt | 12 +- .../compiler/lib/Common/Values.cpp | 2 +- frontends/concrete-rust/Makefile | 24 +- .../concrete-keygen/src/keyasm.rs | 7 +- .../concrete-rust/concrete-macro/Cargo.toml | 1 + .../concrete-macro/src/fast_path_hasher.rs | 4 +- .../concrete-macro/src/generation.rs | 503 ++++++++++++++---- .../concrete-rust/concrete-macro/src/lib.rs | 96 ++-- .../concrete-rust/concrete-macro/src/unzip.rs | 3 +- frontends/concrete-rust/concrete/Cargo.toml | 3 + frontends/concrete-rust/concrete/build.rs | 6 +- frontends/concrete-rust/concrete/src/ffi.h | 89 +++- frontends/concrete-rust/concrete/src/ffi.rs | 17 + frontends/concrete-rust/concrete/src/lib.rs | 7 +- .../concrete-rust/concrete/src/protocol.rs | 120 +++-- .../concrete/src/tfhe/from_value.rs | 65 +++ .../concrete/src/tfhe/into_key.rs | 30 ++ .../concrete/src/tfhe/into_value.rs | 74 +++ .../concrete-rust/concrete/src/tfhe/mod.rs | 11 + .../concrete-rust/concrete/src/tfhe/spec.rs | 189 +++++++ .../concrete-rust/concrete/src/tfhe/types.rs | 133 +++++ .../src/utils}/configuration.rs | 93 +--- .../concrete/src/utils/from_value.rs | 19 + .../concrete/src/utils/into_value.rs | 12 + .../concrete-rust/concrete/src/utils/mod.rs | 4 + .../concrete/src/utils/python.rs | 94 ++++ frontends/concrete-rust/test/Cargo.toml | 1 + .../__pycache__/test_tfhers.cpython-310.pyc | Bin 0 -> 1132 bytes frontends/concrete-rust/test/python/test.py | 16 + .../concrete-rust/test/python/test_tfhers.py | 35 ++ frontends/concrete-rust/test/src/default.rs | 41 ++ frontends/concrete-rust/test/src/lib.rs | 38 +- frontends/concrete-rust/test/src/main.rs | 29 +- frontends/concrete-rust/test/src/test.zip | Bin 1696 -> 2456 bytes .../concrete-rust/test/src/test_tfhers.zip | Bin 0 -> 3633 bytes frontends/concrete-rust/test/src/tfhers.rs | 51 ++ .../src/concrete-protocol.capnp | 216 ++++---- 39 files changed, 1583 insertions(+), 465 deletions(-) create mode 100644 frontends/concrete-rust/concrete/src/tfhe/from_value.rs create mode 100644 frontends/concrete-rust/concrete/src/tfhe/into_key.rs create mode 100644 frontends/concrete-rust/concrete/src/tfhe/into_value.rs create mode 100644 frontends/concrete-rust/concrete/src/tfhe/mod.rs create mode 100644 frontends/concrete-rust/concrete/src/tfhe/spec.rs create mode 100644 frontends/concrete-rust/concrete/src/tfhe/types.rs rename frontends/concrete-rust/{concrete-macro/src => concrete/src/utils}/configuration.rs (61%) create mode 100644 frontends/concrete-rust/concrete/src/utils/from_value.rs create mode 100644 frontends/concrete-rust/concrete/src/utils/into_value.rs create mode 100644 frontends/concrete-rust/concrete/src/utils/mod.rs create mode 100644 frontends/concrete-rust/concrete/src/utils/python.rs create mode 100644 frontends/concrete-rust/test/python/__pycache__/test_tfhers.cpython-310.pyc create mode 100644 frontends/concrete-rust/test/python/test.py create mode 100644 frontends/concrete-rust/test/python/test_tfhers.py create mode 100644 frontends/concrete-rust/test/src/default.rs create mode 100644 frontends/concrete-rust/test/src/test_tfhers.zip create mode 100644 frontends/concrete-rust/test/src/tfhers.rs diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h index aff5b92e7..173c60ea1 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Values.h @@ -159,7 +159,7 @@ struct Value { /// Turns a server value to a client value, without interpreting the kind of /// value. - static Value fromRawTransportValue(TransportValue transportVal); + static Value fromRawTransportValue(const TransportValue &transportVal); /// Turns a client value to a raw (without kind info attached) server value. TransportValue intoRawTransportValue() const; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py index 17b9e6513..f3c2cb25a 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/tfhers_int.py @@ -22,6 +22,7 @@ class TfhersExporter: """Convert Concrete value to TFHErs and serialize it. Args: + value (Value): value to export info (TfhersFheIntDescription): description of the TFHErs integer to export to diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/CMakeLists.txt index ba544d44f..aa28ea325 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/CMakeLists.txt @@ -12,8 +12,16 @@ if(LINUX) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,-Bsymbolic") endif() -target_link_libraries(ConcreteRust PRIVATE ConcretelangSupport ConcretelangClientLib ConcretelangServerLib - ConcretelangRuntimeStatic) +target_link_libraries( + ConcreteRust + PRIVATE ConcretelangSupport + ConcretelangClientLib + ConcretelangServerLib + ConcretelangRuntimeStatic + LLVMSupport + capnp + capnp-json + kj) if(APPLE) find_library(SECURITY_FRAMEWORK Security) diff --git a/compilers/concrete-compiler/compiler/lib/Common/Values.cpp b/compilers/concrete-compiler/compiler/lib/Common/Values.cpp index 9dbd5e2db..f64bbc87f 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Values.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Values.cpp @@ -25,7 +25,7 @@ using concretelang::protocol::vectorToProtoPayload; namespace concretelang { namespace values { -Value Value::fromRawTransportValue(TransportValue transportVal) { +Value Value::fromRawTransportValue(const TransportValue &transportVal) { Value output; auto integerPrecision = transportVal.asReader().getRawInfo().getIntegerPrecision(); diff --git a/frontends/concrete-rust/Makefile b/frontends/concrete-rust/Makefile index 844eb4cfa..ce4661212 100644 --- a/frontends/concrete-rust/Makefile +++ b/frontends/concrete-rust/Makefile @@ -1,8 +1,28 @@ CARGO_SUBDIRS := concrete concrete-macro test concrete-keygen -.PHONY: test +.PHONY: test clean test: @for dir in $(CARGO_SUBDIRS); do \ echo "Running cargo test in $$dir..."; \ - COMPILER_BUILD_DIRECTORY=../../../compilers/concrete-compiler/compiler/build cargo test --manifest-path $$dir/Cargo.toml || exit 1; \ + COMPILER_BUILD_DIRECTORY=../../../compilers/concrete-compiler/compiler/build cargo test --all-features --manifest-path $$dir/Cargo.toml -- --nocapture || exit 1; \ done + +clean: + @for dir in $(CARGO_SUBDIRS); do \ + echo "Cleaning target folder in $$dir..."; \ + rm -rf $$dir/target || exit 1; \ + done + +run: + cd test && \ + COMPILER_BUILD_DIRECTORY=../../../compilers/concrete-compiler/compiler/build cargo run + +format: + @for dir in $(CARGO_SUBDIRS); do \ + echo "Running format in $$dir..."; \ + COMPILER_BUILD_DIRECTORY=$(COMPILER_BUILD_DIRECTORY) cargo fmt --manifest-path $$dir/Cargo.toml || exit 1; \ + done + +regen_test_zips: + python test/python/test.py + python test/python/test_tfhers.py diff --git a/frontends/concrete-rust/concrete-keygen/src/keyasm.rs b/frontends/concrete-rust/concrete-keygen/src/keyasm.rs index f6bc071db..950f4456a 100644 --- a/frontends/concrete-rust/concrete-keygen/src/keyasm.rs +++ b/frontends/concrete-rust/concrete-keygen/src/keyasm.rs @@ -30,8 +30,11 @@ fn assemble_keyset_from_zip( // Read keyset info let mut keyset_info_file = archive.by_name(KEYSET_INFO_FILENAME).unwrap(); let mut keyset_info_buffer = Vec::new(); - keyset_info_file.read_to_end(&mut keyset_info_buffer).unwrap(); - let keyset_info_message = concrete_protocol_capnp::read_capnp_from_buffer(&keyset_info_buffer).unwrap(); + keyset_info_file + .read_to_end(&mut keyset_info_buffer) + .unwrap(); + let keyset_info_message = + concrete_protocol_capnp::read_capnp_from_buffer(&keyset_info_buffer).unwrap(); let keyset_info_proto: concrete_protocol_capnp::keyset_info::Reader<'_> = concrete_protocol_capnp::get_reader_from_message(&keyset_info_message).unwrap(); diff --git a/frontends/concrete-rust/concrete-macro/Cargo.toml b/frontends/concrete-rust/concrete-macro/Cargo.toml index 47b3a2bac..bddecd27e 100644 --- a/frontends/concrete-rust/concrete-macro/Cargo.toml +++ b/frontends/concrete-rust/concrete-macro/Cargo.toml @@ -21,4 +21,5 @@ proc-macro2 = "1.0" zip = "2.2.2" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +itertools = "0.14.0" concrete = { version = "2.10.1", path = "../concrete", features = ["compiler"]} diff --git a/frontends/concrete-rust/concrete-macro/src/fast_path_hasher.rs b/frontends/concrete-rust/concrete-macro/src/fast_path_hasher.rs index d8b08f122..fb9dd1adf 100644 --- a/frontends/concrete-rust/concrete-macro/src/fast_path_hasher.rs +++ b/frontends/concrete-rust/concrete-macro/src/fast_path_hasher.rs @@ -1,4 +1,6 @@ -use std::{hash::{Hash, Hasher}, os::unix::fs::MetadataExt, path::PathBuf}; +use std::hash::{Hash, Hasher}; +use std::os::unix::fs::MetadataExt; +use std::path::PathBuf; pub struct FastPathHasher { path: PathBuf, diff --git a/frontends/concrete-rust/concrete-macro/src/generation.rs b/frontends/concrete-rust/concrete-macro/src/generation.rs index 51aad87c5..057215cfd 100644 --- a/frontends/concrete-rust/concrete-macro/src/generation.rs +++ b/frontends/concrete-rust/concrete-macro/src/generation.rs @@ -1,7 +1,49 @@ -use concrete::protocol::{CircuitInfo, ProgramInfo}; +use concrete::protocol::{CircuitInfo, ProgramInfo, TypeInfo}; +use concrete::tfhe::{FunctionSpec, IntegerType}; +use itertools::multizip; use quote::{format_ident, quote}; -pub fn generate_unsafe_binding(pi: &ProgramInfo) -> proc_macro2::TokenStream { +pub fn generate(pi: &ProgramInfo, hash: u64) -> proc_macro2::TokenStream { + let lib_name = format!("concrete-artifact-{hash}"); + let unsafe_binding = generate_unsafe_binding(&pi); + let infos = generate_infos(&pi); + let keyset = generate_keyset(&pi); + let client = generate_client(&pi); + let server = generate_server(&pi); + + let links = if cfg!(target_os = "macos") { + quote! { + #[link(name = "ConcretelangRuntime")] + #[link(name = "omp")] + } + } else if cfg!(target_os = "linux") { + quote! { + #[link(name = "ConcretelangRuntime")] + #[link(name = "hpx_iostreams")] + #[link(name = "hpx_core")] + #[link(name = "hpx")] + #[link(name = "omp")] + } + } else { + panic!("Unsupported platform"); + }; + + quote! { + #infos + #keyset + #client + #server + + #[doc(hidden)] + pub mod _binding { + #[link(name = #lib_name, kind="static")] + #links + #unsafe_binding + } + } +} + +fn generate_unsafe_binding(pi: &ProgramInfo) -> proc_macro2::TokenStream { let func_defs = pi .circuits .iter() @@ -19,7 +61,7 @@ pub fn generate_unsafe_binding(pi: &ProgramInfo) -> proc_macro2::TokenStream { } } -pub fn generate_infos(pi: &ProgramInfo) -> proc_macro2::TokenStream { +fn generate_infos(pi: &ProgramInfo) -> proc_macro2::TokenStream { quote! { pub static PROGRAM_INFO: std::sync::LazyLock<::concrete::protocol::ProgramInfo> = std::sync::LazyLock::new(|| { #pi @@ -27,26 +69,115 @@ pub fn generate_infos(pi: &ProgramInfo) -> proc_macro2::TokenStream { } } -pub fn generate_keyset() -> proc_macro2::TokenStream { +fn generate_keyset(pi: &ProgramInfo) -> proc_macro2::TokenStream { + let mut need_providing = pi + .circuits + .iter() + .flat_map(|ci| { + multizip(( + std::iter::repeat(&ci.name), + std::iter::successors(Some(0), |a| Some(a + 1)), + ci.inputs.iter(), + pi.tfhers_specs + .as_ref() + .unwrap() + .get_func(&ci.name) + .unwrap() + .input_types + .iter(), + )) + }) + .filter(|(_, _, _, spec)| spec.is_some()) + .collect::>(); + + need_providing.dedup_by_key(|a| { + let TypeInfo::lweCiphertext(ref ct) = a.2.typeInfo else { + unreachable!() + }; + ct.encryption.keyId + }); + + let fields_idents = need_providing + .iter() + .map(|(func_name, ith, ..)| format_ident!("{func_name}_{ith}")) + .collect::>(); + + let fields_types = need_providing + .iter() + .map(|_| { + quote! { Option>} + }) + .collect::>(); + + let methods = need_providing.iter().map(|(func_name, ith, gi, ..)| { + let method_name = format_ident!("with_key_for_{func_name}_{ith}_arg"); + let ident = format_ident!("{func_name}_{ith}"); + let TypeInfo::lweCiphertext(ref ct) = gi.typeInfo else { + unreachable!() + }; + let kid = ct.encryption.keyId; + quote! { + pub fn #method_name(mut self, key: &::tfhe::ClientKey) -> Self { + let mut key = Some(<::tfhe::ClientKey as concrete::tfhe::IntoLweSecretKey>::into_lwe_secret_key(key, Some(#kid))); + if self.#ident.is_some() { + assert_eq!(self.#ident.as_mut().unwrap().pin_mut().get_buffer(), key.as_mut().unwrap().pin_mut().get_buffer(), "Tried to set the same underlying key twice, with a different key. Something must be wrong..."); + return self; + } + self.#ident = key; + self + } + } + }); + quote! { - pub fn new_keyset( - secret_csprng: std::pin::Pin<&mut ::concrete::common::SecretCsprng>, - encryption_csprng: std::pin::Pin<&mut ::concrete::common::EncryptionCsprng> - ) -> ::concrete::UniquePtr<::concrete::common::Keyset> { - ::concrete::common::Keyset::new( - &PROGRAM_INFO.keyset, - secret_csprng, - encryption_csprng - ) + #[derive(Default)] + pub struct KeysetBuilder{ + #(#fields_idents: #fields_types),* + } + + impl KeysetBuilder { + pub fn new() -> Self { + return Self::default(); + } + + #(#methods),* + + pub fn generate( + self, + secret_csprng: std::pin::Pin<&mut ::concrete::common::SecretCsprng>, + encryption_csprng: std::pin::Pin<&mut ::concrete::common::EncryptionCsprng> + ) -> ::concrete::UniquePtr<::concrete::common::Keyset> { + ::concrete::common::Keyset::new( + &PROGRAM_INFO.keyset, + secret_csprng, + encryption_csprng, + vec![ + #( + self.#fields_idents.expect(concat!("Missing tfhers key ", stringify!(#fields_idents))) + ),* + ] + ) + } } } } -pub(crate) fn generate_client(program_info: &ProgramInfo) -> proc_macro2::TokenStream { +fn generate_client(program_info: &ProgramInfo) -> proc_macro2::TokenStream { let client_functions = program_info .circuits .iter() - .map(|ci| generate_client_function(ci)); + .map(|ci| { + ( + ci, + program_info + .tfhers_specs + .as_ref() + .unwrap() + .get_func(&ci.name) + .unwrap(), + ) + }) + .map(|(ci, ts)| generate_client_function(ci, Some(ts))); quote! { pub mod client { #(#client_functions)* @@ -54,87 +185,120 @@ pub(crate) fn generate_client(program_info: &ProgramInfo) -> proc_macro2::TokenS } } -fn generate_client_function_prepare_inputs(circuit_info: &CircuitInfo) -> proc_macro2::TokenStream { +fn generate_client_function_prepare_inputs( + circuit_info: &CircuitInfo, + tfhers_spec: Option, +) -> proc_macro2::TokenStream { let ith = (0..circuit_info.inputs.len()).collect::>(); - let input_idents = circuit_info.inputs.iter().enumerate().map(|(ith, _)| format_ident!("arg_{ith}")).collect::>(); - let input_types = circuit_info.inputs.iter().map(|gi| { - match (gi.rawInfo.integerPrecision, gi.rawInfo.isSigned) { - (8, true) => quote! {::concrete::common::Tensor}, - (8, false) => quote! {::concrete::common::Tensor}, - (16, true) => quote! {::concrete::common::Tensor}, - (16, false) => quote! {::concrete::common::Tensor}, - (32, true) => quote! {::concrete::common::Tensor}, - (32, false) => quote! {::concrete::common::Tensor}, - (64, true) => quote! {::concrete::common::Tensor}, - (64, false) => quote! {::concrete::common::Tensor}, - _ => unreachable!(), - } - }).collect::>(); - let output_types = circuit_info.inputs.iter().map(|_| { - quote! {::concrete::UniquePtr<::concrete::common::TransportValue>} - }).collect::>(); - - quote!{ + let input_specs = tfhers_spec.map_or(vec![None; circuit_info.inputs.len()], |v| { + v.input_types.to_owned() + }); + let input_idents = circuit_info + .inputs + .iter() + .enumerate() + .map(|(ith, _)| format_ident!("arg_{ith}")) + .collect::>(); + let input_types = multizip((input_specs.iter(), circuit_info.inputs.iter())) + .map(|(spec, gate_info)| { + match ( + spec, + gate_info.rawInfo.integerPrecision, + gate_info.rawInfo.isSigned, + ) { + (Some(_), _, _) => quote! {()}, + (_, 8, true) => quote! {::concrete::common::Tensor}, + (_, 8, false) => quote! {::concrete::common::Tensor}, + (_, 16, true) => quote! {::concrete::common::Tensor}, + (_, 16, false) => quote! {::concrete::common::Tensor}, + (_, 32, true) => quote! {::concrete::common::Tensor}, + (_, 32, false) => quote! {::concrete::common::Tensor}, + (_, 64, true) => quote! {::concrete::common::Tensor}, + (_, 64, false) => quote! {::concrete::common::Tensor}, + _ => unreachable!(), + } + }) + .collect::>(); + let output_types = input_specs + .iter() + .map(|spec| match spec { + Some(_) => quote! {()}, + None => quote! {::concrete::UniquePtr<::concrete::common::TransportValue>}, + }) + .collect::>(); + let preparations = multizip((input_specs.iter(), input_types.iter(), input_idents.iter(), ith.iter())) + .map(|(spec, typ, ident, ith)|{ + match spec { + Some(..) => quote!{()}, + None => quote!{self.0.pin_mut().prepare_input(<#typ as ::concrete::utils::into_value::IntoValue>::into_value(#ident), #ith)} + } + }); + quote! { pub fn prepare_inputs(&mut self, #(#input_idents: #input_types),*) -> (#(#output_types),*) { - ( - #( - self.0.pin_mut().prepare_input(::concrete::common::Value::from_tensor(#input_idents), #ith) - ),* - ) + (#(#preparations),*) } } } -fn generate_client_function_process_outputs(circuit_info: &CircuitInfo) -> proc_macro2::TokenStream { +fn generate_client_function_process_outputs( + circuit_info: &CircuitInfo, + tfhers_spec: Option, +) -> proc_macro2::TokenStream { let ith = (0..circuit_info.outputs.len()).collect::>(); - let input_idents = circuit_info.outputs.iter().enumerate().map(|(ith, _)| format_ident!("res_{ith}")).collect::>(); - let input_types = circuit_info.outputs.iter().map(|_| { - quote! {::concrete::UniquePtr<::concrete::common::TransportValue>} - }).collect::>(); - let output_types = circuit_info.outputs.iter().map(|gi| { - match (gi.rawInfo.integerPrecision, gi.rawInfo.isSigned) { - (8, true) => quote! {::concrete::common::Tensor}, - (8, false) => quote! {::concrete::common::Tensor}, - (16, true) => quote! {::concrete::common::Tensor}, - (16, false) => quote! {::concrete::common::Tensor}, - (32, true) => quote! {::concrete::common::Tensor}, - (32, false) => quote! {::concrete::common::Tensor}, - (64, true) => quote! {::concrete::common::Tensor}, - (64, false) => quote! {::concrete::common::Tensor}, - _ => unreachable!(), - } - }).collect::>(); - let output_unwrap = circuit_info.outputs.iter().map(|gi| { - match (gi.rawInfo.integerPrecision, gi.rawInfo.isSigned) { - (8, true) => quote! {get_tensor::().unwrap()}, - (8, false) => quote! {get_tensor::().unwrap()}, - (16, true) => quote! {get_tensor::().unwrap()}, - (16, false) => quote! {get_tensor::().unwrap()}, - (32, true) => quote! {get_tensor::().unwrap()}, - (32, false) => quote! {get_tensor::().unwrap()}, - (64, true) => quote! {get_tensor::().unwrap()}, - (64, false) => quote! {get_tensor::().unwrap()}, - _ => unreachable!(), - } - }).collect::>(); - - - quote!{ + let output_specs = tfhers_spec.map_or(vec![None; circuit_info.outputs.len()], |v| { + v.output_types.to_owned() + }); + let input_idents = circuit_info + .outputs + .iter() + .enumerate() + .map(|(ith, _)| format_ident!("res_{ith}")) + .collect::>(); + let input_types = output_specs + .iter() + .map(|spec| match spec { + Some(_) => quote! {()}, + None => quote! {::concrete::UniquePtr<::concrete::common::TransportValue>}, + }) + .collect::>(); + let output_types = multizip((circuit_info.outputs.iter(), output_specs.iter())) + .map( + |(gi, ts)| match (ts, gi.rawInfo.integerPrecision, gi.rawInfo.isSigned) { + (Some(_), _, _) => quote! {()}, + (_, 8, true) => quote! {::concrete::common::Tensor}, + (_, 8, false) => quote! {::concrete::common::Tensor}, + (_, 16, true) => quote! {::concrete::common::Tensor}, + (_, 16, false) => quote! {::concrete::common::Tensor}, + (_, 32, true) => quote! {::concrete::common::Tensor}, + (_, 32, false) => quote! {::concrete::common::Tensor}, + (_, 64, true) => quote! {::concrete::common::Tensor}, + (_, 64, false) => quote! {::concrete::common::Tensor}, + _ => unreachable!(), + }, + ) + .collect::>(); + let unwrappers = multizip((output_specs.iter(), output_types.iter(), input_idents.iter(), ith.iter())) + .map(|(spec, typ, ident, ith)|{ + match spec { + Some(_) => quote!{()}, + None => quote!{<#typ as ::concrete::utils::from_value::FromValue>::from_value((), self.0.pin_mut().process_output(#ident, #ith))}, + } + }); + quote! { pub fn process_outputs(&mut self, #(#input_idents: #input_types),*) -> (#(#output_types),*) { - ( - #( - self.0.pin_mut().process_output(#input_idents, #ith).#output_unwrap - ),* - ) + (#(#unwrappers),*) } } } -fn generate_client_function(circuit_info: &CircuitInfo) -> proc_macro2::TokenStream { +fn generate_client_function( + circuit_info: &CircuitInfo, + tfhers_spec: Option, +) -> proc_macro2::TokenStream { let function_identifier = format_ident!("{}", circuit_info.name); - - let prepare_inputs = generate_client_function_prepare_inputs(circuit_info); - let process_outputs = generate_client_function_process_outputs(circuit_info); + let prepare_inputs = generate_client_function_prepare_inputs(circuit_info, tfhers_spec.clone()); + let process_outputs = + generate_client_function_process_outputs(circuit_info, tfhers_spec.clone()); quote! { pub mod #function_identifier{ @@ -166,11 +330,22 @@ fn generate_client_function(circuit_info: &CircuitInfo) -> proc_macro2::TokenStr } } -pub(crate) fn generate_server(program_info: &ProgramInfo) -> proc_macro2::TokenStream { +fn generate_server(program_info: &ProgramInfo) -> proc_macro2::TokenStream { let server_functions = program_info .circuits .iter() - .map(|ci| generate_server_function(ci)); + .map(|ci| { + ( + ci, + program_info + .tfhers_specs + .as_ref() + .unwrap() + .get_func(&ci.name) + .unwrap(), + ) + }) + .map(|(ci, spec)| generate_server_function(ci, Some(spec))); quote! { pub mod server { #(#server_functions)* @@ -178,11 +353,13 @@ pub(crate) fn generate_server(program_info: &ProgramInfo) -> proc_macro2::TokenS } } -fn generate_server_function(circuit_info: &CircuitInfo) -> proc_macro2::TokenStream { - +fn generate_server_function( + circuit_info: &CircuitInfo, + tfhers_spec: Option, +) -> proc_macro2::TokenStream { let function_identifier = format_ident!("{}", circuit_info.name); let binding_identifier = format_ident!("_mlir_concrete_{}", circuit_info.name); - let invoke = generate_server_function_invoke(circuit_info); + let invoke = generate_server_function_invoke(circuit_info, tfhers_spec); quote! { pub mod #function_identifier{ @@ -207,26 +384,150 @@ fn generate_server_function(circuit_info: &CircuitInfo) -> proc_macro2::TokenStr } } } - } -fn generate_server_function_invoke(circuit_info: &CircuitInfo) -> proc_macro2::TokenStream { - let args_idents = (0..circuit_info.inputs.len()).map(|a| format_ident!("arg_{a}")).collect::>(); - let args_types = (0..circuit_info.inputs.len()).map(|_| quote!{::concrete::UniquePtr<::concrete::common::TransportValue>}).collect::>(); - let results_idents = (0..circuit_info.outputs.len()).map(|a| format_ident!("res_{a}")).collect::>(); - let results_types = (0..circuit_info.outputs.len()).map(|_| quote!{::concrete::UniquePtr<::concrete::common::TransportValue>}).collect::>(); +fn generate_types(ts: &Option) -> proc_macro2::TokenStream { + match ts { + Some(IntegerType { + bit_width: 2, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt2}, + Some(IntegerType { + bit_width: 2, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint2}, + Some(IntegerType { + bit_width: 4, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt4}, + Some(IntegerType { + bit_width: 4, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint4}, + Some(IntegerType { + bit_width: 6, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt6}, + Some(IntegerType { + bit_width: 6, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint6}, + Some(IntegerType { + bit_width: 8, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt8}, + Some(IntegerType { + bit_width: 8, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint8}, + Some(IntegerType { + bit_width: 10, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt10}, + Some(IntegerType { + bit_width: 10, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint10}, + Some(IntegerType { + bit_width: 12, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt12}, + Some(IntegerType { + bit_width: 12, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint12}, + Some(IntegerType { + bit_width: 14, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt14}, + Some(IntegerType { + bit_width: 14, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint14}, + Some(IntegerType { + bit_width: 16, + is_signed: true, + .. + }) => quote! {::tfhe::FheInt16}, + Some(IntegerType { + bit_width: 16, + is_signed: false, + .. + }) => quote! {::tfhe::FheUint16}, + None => quote! {::concrete::UniquePtr<::concrete::common::TransportValue>}, + _ => unreachable!(), + } +} + +fn generate_server_function_invoke( + circuit_info: &CircuitInfo, + tfhers_spec: Option, +) -> proc_macro2::TokenStream { + let input_specs = tfhers_spec + .clone() + .map_or(vec![None; circuit_info.inputs.len()], |v| { + v.input_types.to_owned() + }); + let output_specs = tfhers_spec + .clone() + .map_or(vec![None; circuit_info.outputs.len()], |v| { + v.output_types.to_owned() + }); + let args_idents = (0..circuit_info.inputs.len()) + .map(|a| format_ident!("arg_{a}")) + .collect::>(); + let args_types = input_specs + .iter() + .map(|s| generate_types(s)) + .collect::>(); + let results_idents = (0..circuit_info.outputs.len()) + .map(|a| format_ident!("res_{a}")) + .collect::>(); + let results_types = output_specs + .iter() + .map(|s| generate_types(s)) + .collect::>(); let output_len = circuit_info.outputs.len(); - quote!{ + let preludes = multizip((input_specs.iter(), args_idents.iter(), args_types.iter(), circuit_info.inputs.iter())) + .map(|(spec, ident, typ, gi)| { + match spec{ + Some(_) => { + let type_info_json_string = serde_json::to_string(&gi.typeInfo).unwrap(); + quote!{ <#typ as ::concrete::utils::into_value::IntoValue>::into_value(#ident).into_transport_value(#type_info_json_string)} + } + None => quote!{#ident} + } + }); + + let postludes = multizip((output_specs.iter(), results_idents.iter(), results_types.iter())) + .map(|(spec, ident, typ)|{ + match spec { + Some(s) => quote!{ <#typ as ::concrete::utils::from_value::FromValue>::from_value(#s, #ident.to_value())}, + None => quote!{#ident} + } + }); + + quote! { pub fn invoke(&mut self, server_keyset: &::concrete::common::ServerKeyset, #(#args_idents: #args_types),*) -> (#(#results_types),*) { - let inputs = vec![ - #(#args_idents),* - ]; + let inputs = vec![#(#preludes),*]; let output = self.0.pin_mut().call(server_keyset, inputs); let [#(#results_idents),*] = <[::concrete::UniquePtr<::concrete::common::TransportValue>; #output_len]>::try_from(output).unwrap(); - ( - #(#results_idents),* - ) + (#(#postludes),*) } } } diff --git a/frontends/concrete-rust/concrete-macro/src/lib.rs b/frontends/concrete-rust/concrete-macro/src/lib.rs index a476fc054..f93a018e6 100644 --- a/frontends/concrete-rust/concrete-macro/src/lib.rs +++ b/frontends/concrete-rust/concrete-macro/src/lib.rs @@ -1,14 +1,13 @@ #![allow(stable_features)] #![feature(file_lock)] -#[allow(unused)] -use quote::quote; -use concrete::{compiler, protocol::ProgramInfo}; -use configuration::Configuration; +use concrete::compiler; +use concrete::protocol::ProgramInfo; use proc_macro::{ TokenStream, {self}, }; -use std::{fs::read_to_string, path::PathBuf}; +use std::fs::read_to_string; use std::hash::{DefaultHasher, Hash, Hasher}; +use std::path::PathBuf; use syn::LitStr; const CONCRETE_BUILD_DIR: &'static str = env!("CONCRETE_BUILD_DIR"); @@ -17,14 +16,14 @@ const PATH_PROGRAM_INFO: &'static str = "program_info.concrete.params.json"; const PATH_CIRCUIT: &'static str = "circuit.mlir"; const PATH_COMPOSITION_RULES: &'static str = "composition_rules.json"; const PATH_SIMULATED: &'static str = "is_simulated"; +const CLIENT_SPECS: &'static str = "client.specs.json"; const PATH_CONFIGURATION: &'static str = "configuration.json"; const DEFAULT_GLOBAL_P_ERROR: Option = Some(0.00001); const DEFAULT_P_ERROR: Option = None; -mod configuration; mod fast_path_hasher; -mod unzip; mod generation; +mod unzip; #[proc_macro] pub fn from_concrete_python_export_zip(input: TokenStream) -> TokenStream { @@ -61,7 +60,9 @@ pub fn from_concrete_python_export_zip(input: TokenStream) -> TokenStream { .open(concrete_build_dir.join(format!("{hash_val}.lock"))) .expect("Failed to open lock file."); - lock_file.lock().expect("Failed to acquire lock on the lock file"); + lock_file + .lock() + .expect("Failed to acquire lock on the lock file"); let concrete_hash_dir = concrete_build_dir.join(format!("{hash_val}")); if !concrete_hash_dir.exists() { @@ -91,15 +92,25 @@ pub fn from_concrete_python_export_zip(input: TokenStream) -> TokenStream { if !config_path.exists() { panic!("Missing `configuration.json` file in the export. Did you save your server with the `via_mlir` option ?"); } - let configuration_string = read_to_string(config_path).expect("Failed to read configuration to string"); - let conf: Configuration = serde_json::from_str(configuration_string.as_str()).expect("Failed to deserialize configuration"); + let configuration_string = + read_to_string(config_path).expect("Failed to read configuration to string"); + let conf: concrete::utils::configuration::Configuration = + serde_json::from_str(configuration_string.as_str()) + .expect("Failed to deserialize configuration"); + + let client_specs_path = concrete_hash_dir.join(CLIENT_SPECS); + if !client_specs_path.exists() { + panic!("Missing `client.specs.json` file in the export. Did you save your server with the `via_mlir` option ?"); + } if !composition_rules_path.exists() { panic!("Missing `composition_rules.json` file in the export. Did you save your server with the `via_mlir` option ?"); } - let composition_rules_string = read_to_string(composition_rules_path).expect("Failed to read composition rules to string"); + let composition_rules_string = read_to_string(composition_rules_path) + .expect("Failed to read composition rules to string"); let composition_rules: Vec = - serde_json::from_str(composition_rules_string.as_str()).expect("Failed to deserialize composition rules"); + serde_json::from_str(composition_rules_string.as_str()) + .expect("Failed to deserialize composition rules"); let mut opts = compiler::CompilationOptions::new(); opts.pin_mut() @@ -133,22 +144,22 @@ pub fn from_concrete_python_export_zip(input: TokenStream) -> TokenStream { } match conf.parameter_selection_strategy { - configuration::ParameterSelectionStrategy::V0 => { + concrete::utils::configuration::ParameterSelectionStrategy::V0 => { opts.pin_mut().set_optimizer_strategy(0) } - configuration::ParameterSelectionStrategy::Mono => { + concrete::utils::configuration::ParameterSelectionStrategy::Mono => { opts.pin_mut().set_optimizer_strategy(1) } - configuration::ParameterSelectionStrategy::Multi => { + concrete::utils::configuration::ParameterSelectionStrategy::Multi => { opts.pin_mut().set_optimizer_strategy(2) } } match conf.multi_parameter_strategy { - configuration::MultiParameterStrategy::Precision => { + concrete::utils::configuration::MultiParameterStrategy::Precision => { opts.pin_mut().set_optimizer_multi_parameter_strategy(0) } - configuration::MultiParameterStrategy::PrecisionAndNorm2 => { + concrete::utils::configuration::MultiParameterStrategy::PrecisionAndNorm2 => { opts.pin_mut().set_optimizer_multi_parameter_strategy(1) } } @@ -163,8 +174,12 @@ pub fn from_concrete_python_export_zip(input: TokenStream) -> TokenStream { opts.pin_mut() .set_keyset_restriction(&conf.keyset_restriction.map(|a| a.0).unwrap_or("".into())); match conf.security_level { - configuration::SecurityLevel::Security128Bits => opts.pin_mut().set_security_level(128), - configuration::SecurityLevel::Security132Bits => opts.pin_mut().set_security_level(132), + concrete::utils::configuration::SecurityLevel::Security128Bits => { + opts.pin_mut().set_security_level(128) + } + concrete::utils::configuration::SecurityLevel::Security132Bits => { + opts.pin_mut().set_security_level(132) + } } for rule in composition_rules { @@ -208,40 +223,29 @@ pub fn from_concrete_python_export_zip(input: TokenStream) -> TokenStream { let output_path = concrete_build_dir.join(format!("libconcrete-artifact-{hash_val}.a")); if !output_path.exists() { - std::fs::copy(concrete_hash_dir.join(PATH_STATIC_LIB), output_path) - .unwrap(); + std::fs::copy(concrete_hash_dir.join(PATH_STATIC_LIB), output_path).unwrap(); } + let client_specs_path = concrete_hash_dir.join(CLIENT_SPECS); + if !client_specs_path.exists() { + panic!("Missing `client.specs.json` file in the export. Did you save your server with the `via_mlir` option ?"); + } + let mut client_specs: ProgramInfo = + serde_json::from_reader(std::fs::File::open(client_specs_path).unwrap()).unwrap(); + client_specs.eventually_patch_tfhers_specs(); + let concrete_program_info_path = concrete_hash_dir.join(PATH_PROGRAM_INFO); if !concrete_program_info_path.exists() { panic!("Missing `program_info.concrete.params.json` file after compilation. Something is wrong. Delete target folder and re-compile."); } - let program_info: ProgramInfo = serde_json::from_reader( - std::fs::File::open(concrete_program_info_path).unwrap(), - ) - .unwrap(); + let mut program_info: ProgramInfo = + serde_json::from_reader(std::fs::File::open(concrete_program_info_path).unwrap()).unwrap(); + + program_info.tfhers_specs = client_specs.tfhers_specs.clone(); + + // assert_eq!(client_specs, program_info, "Export client specs, and compiled program info do not match. Something is wrong. Get in touch with developers."); lock_file.unlock().unwrap(); - let lib_name = format!("concrete-artifact-{hash_val}"); - let unsafe_binding = generation::generate_unsafe_binding(&program_info); - let infos = generation::generate_infos(&program_info); - let keyset = generation::generate_keyset(); - let client = generation::generate_client(&program_info); - let server = generation::generate_server(&program_info); - - quote! { - #infos - #keyset - #client - #server - - #[doc(hidden)] - pub mod _binding { - #[link(name = "ConcretelangRuntime", kind="dylib")] - #[link(name = #lib_name, kind="static")] - #unsafe_binding - } - } - .into() + generation::generate(&program_info, hash_val).into() } diff --git a/frontends/concrete-rust/concrete-macro/src/unzip.rs b/frontends/concrete-rust/concrete-macro/src/unzip.rs index 6f690de32..a8bfc18ab 100644 --- a/frontends/concrete-rust/concrete-macro/src/unzip.rs +++ b/frontends/concrete-rust/concrete-macro/src/unzip.rs @@ -28,8 +28,7 @@ pub fn unzip(zip_path: &Path, to: &Path) { use std::os::unix::fs::PermissionsExt; if let Some(mode) = file.unix_mode() { - std::fs::set_permissions(&outpath, std::fs::Permissions::from_mode(mode)) - .unwrap(); + std::fs::set_permissions(&outpath, std::fs::Permissions::from_mode(mode)).unwrap(); } } } diff --git a/frontends/concrete-rust/concrete/Cargo.toml b/frontends/concrete-rust/concrete/Cargo.toml index c89f59599..02bcabf8c 100644 --- a/frontends/concrete-rust/concrete/Cargo.toml +++ b/frontends/concrete-rust/concrete/Cargo.toml @@ -18,6 +18,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" quote = {version = "1.0", optional =true } proc-macro2 = {version = "1.0", optional = true} +tfhe = {version = "1.1", features = ["integer"], optional = true} [build-dependencies] zip = "2.6" @@ -28,4 +29,6 @@ cxx-build = "1.0" ar = "0.9" [features] +default = ["tfhe-rs", "compiler"] compiler = ["dep:quote", "dep:proc-macro2"] +tfhe-rs = ["dep:tfhe"] diff --git a/frontends/concrete-rust/concrete/build.rs b/frontends/concrete-rust/concrete/build.rs index 64da7a21a..c9c0ec6df 100644 --- a/frontends/concrete-rust/concrete/build.rs +++ b/frontends/concrete-rust/concrete/build.rs @@ -9,7 +9,7 @@ const ARTIFACTS: &[&str] = if cfg!(target_os = "macos") { &[ "libConcreteRust.dylib", "libConcretelangRuntime.dylib", - "libomp.dylib" + "libomp.dylib", ] } else if cfg!(target_os = "linux") { &[ @@ -76,14 +76,14 @@ fn copy_local_artifacts(out_dir: &Path, mut lib_dir: PathBuf) { lib_dir.push("lib"); lib_dir = lib_dir.canonicalize().unwrap(); do_with_lock(&out_dir.join(INSTALL_LOCK), || { - for art in ARTIFACTS { + for art in ARTIFACTS { + println!("cargo::rerun-if-changed={}", &lib_dir.join(art).display()); std::fs::copy(&lib_dir.join(art), &out_dir.join(art)).unwrap(); } }); } fn fetch_artifacts(out_dir: &Path) { - struct Archive(Vec); impl Handler for Archive { fn write(&mut self, data: &[u8]) -> Result { diff --git a/frontends/concrete-rust/concrete/src/ffi.h b/frontends/concrete-rust/concrete/src/ffi.h index c11045bfa..10d4e9d43 100644 --- a/frontends/concrete-rust/concrete/src/ffi.h +++ b/frontends/concrete-rust/concrete/src/ffi.h @@ -19,6 +19,7 @@ #include "concretelang/Support/V0Parameters.h" #include "cxx.h" #include +#include #include #include #include @@ -176,6 +177,16 @@ std::unique_ptr compile(rust::Str sources, } template struct Key : T { + + static std::unique_ptr> _from_buffer_and_info(rust::Slice buffer_slice, rust::Str info_json) { + auto info = typename T::InfoType(); + auto info_string = std::string(info_json); + assert(info.readJsonFromString(info_string).has_value()); + auto buffer = std::make_shared>(buffer_slice.begin(), buffer_slice.end()); + auto output =std::make_unique(buffer, info); + return std::unique_ptr>(reinterpret_cast *>(output.release())); + } + rust::Slice get_buffer() { auto buffer = this->getBuffer(); return {buffer.data(), buffer.size()}; @@ -186,6 +197,12 @@ template struct Key : T { }; typedef Key LweSecretKey; + +inline std::unique_ptr +_lwe_secret_key_from_buffer_and_info(rust::Slice buffer_slice, rust::Str info_json) { + return LweSecretKey::_from_buffer_and_info(buffer_slice, info_json); +} + typedef Key LweBootstrapKey; typedef Key LweKeyswitchKey; typedef Key PackingKeyswitchKey; @@ -301,11 +318,17 @@ struct Keyset : concretelang::keysets::Keyset { std::unique_ptr _keyset_new(rust::Str keyset_info, SecretCsprng &secret_csprng, - EncryptionCsprng &encryption_csprng) { + EncryptionCsprng &encryption_csprng, + rust::Slice> initial_keys) { auto info = Message(); info.readJsonFromString(std::string(keyset_info)).value(); + auto map = std::map(); + for (auto &key : initial_keys) { + auto info = key->getInfo(); + map.insert(std::make_pair(info.asReader().getId(), std::move(*key.release()))); + } auto output = std::make_unique( - info, secret_csprng, encryption_csprng); + info, secret_csprng, encryption_csprng, map); return std::unique_ptr(reinterpret_cast(output.release())); } @@ -345,6 +368,24 @@ const auto _tensor_i32_new = _tensor_new; const auto _tensor_u64_new = _tensor_new; const auto _tensor_i64_new = _tensor_new; +struct TransportValue : concretelang::values::TransportValue { + std::unique_ptr to_owned() const { + return std::make_unique(*this); + } + + rust::Vec serialize() const { + auto output = rust::Vec(); + auto vec_ostream = VecOStream(output); + auto ostream = std::ostream(&vec_ostream); + this->writeBinaryToOstream( + ostream + ).value(); + ostream.flush(); + return output; + } + +}; + struct Value : concretelang::values::Value { bool _has_element_type_u8() const { return hasElementType(); } bool _has_element_type_i8() const { return hasElementType(); } @@ -403,9 +444,19 @@ struct Value : concretelang::values::Value { } rust::Slice get_dimensions() const { - auto vecref = getDimensions(); + const auto& vecref = getDimensions(); return {vecref.data(), vecref.size()}; } + + std::unique_ptr into_transport_value(rust::Str type_info_json) const { + auto first = intoRawTransportValue(); + auto info = Message(); + info.readJsonFromString(std::string(type_info_json)).value(); + first.asBuilder().setTypeInfo(info.asReader()); + auto output = + std::make_unique<::concretelang::values::TransportValue>(first); + return std::unique_ptr(reinterpret_cast(output.release())); + } }; template @@ -423,23 +474,6 @@ const auto _value_from_tensor_i32 = _value_from_tensor; const auto _value_from_tensor_u64 = _value_from_tensor; const auto _value_from_tensor_i64 = _value_from_tensor; -struct TransportValue : concretelang::values::TransportValue { - std::unique_ptr to_owned() const { - return std::make_unique(*this); - } - - rust::Vec serialize() const { - auto output = rust::Vec(); - auto vec_ostream = VecOStream(output); - auto ostream = std::ostream(&vec_ostream); - this->writeBinaryToOstream( - ostream - ).value(); - ostream.flush(); - return output; - } -}; - std::unique_ptr _deserialize_transport_value(rust::Slice slice) { auto output = TransportValue(); auto slice_istream = SliceIStream(slice); @@ -448,6 +482,12 @@ std::unique_ptr _deserialize_transport_value(rust::Slice(output); } +std::unique_ptr _transport_value_to_value(TransportValue const &tv) { + auto output = + std::make_unique<::concretelang::values::Value>(::concretelang::values::Value::fromRawTransportValue(tv)); + return std::unique_ptr(reinterpret_cast(output.release())); +} + struct ClientFunction : concretelang::clientlib::ClientCircuit { std::unique_ptr prepare_input(std::unique_ptr arg, size_t pos) { @@ -543,7 +583,14 @@ struct ServerFunction : concretelang::serverlib::ServerCircuit { for (size_t i = 0; i < args.length(); i++) { oargs.push_back(*args[i].release()); } - auto res = std::make_unique>(call(keys, oargs).value()); + auto maybe_res = call(keys, oargs); + if (maybe_res.has_error()){ + std::cout << "Failed to perform call:\n"; + std::cout << maybe_res.error().mesg; + std::cout.flush(); + assert(false); + } + auto res = std::make_unique>(maybe_res.value()); return std::unique_ptr>( reinterpret_cast *>(res.release())); } diff --git a/frontends/concrete-rust/concrete/src/ffi.rs b/frontends/concrete-rust/concrete/src/ffi.rs index 4f4061d27..3850281b4 100644 --- a/frontends/concrete-rust/concrete/src/ffi.rs +++ b/frontends/concrete-rust/concrete/src/ffi.rs @@ -8,6 +8,7 @@ use crate::protocol::{ CircuitInfo, KeysetInfo, LweBootstrapKeyInfo, LweKeyswitchKeyInfo, LweSecretKeyInfo, PackingKeyswitchKeyInfo, ProgramInfo, }; +use crate::utils::into_value::IntoValue; use cxx::{CxxVector, SharedPtr, UniquePtr}; #[cxx::bridge(namespace = "concrete_rust")] @@ -144,6 +145,11 @@ mod ffi { fn get_buffer(self: Pin<&mut LweSecretKey>) -> &[u64]; #[doc(hidden)] fn _get_info_json(self: &LweSecretKey) -> String; + #[doc(hidden)] + fn _lwe_secret_key_from_buffer_and_info( + buffer: &[u64], + info: &str, + ) -> UniquePtr; /// A Keyset object holding both the [`ClientKeyset`] and the [`ServerKeyset`]. type Keyset; @@ -152,6 +158,7 @@ mod ffi { keyset_info_json: &str, secret_csprng: Pin<&mut SecretCsprng>, encryption_csprng: Pin<&mut EncryptionCsprng>, + initial_keys: &mut [UniquePtr], ) -> UniquePtr; /// Return the associated server keyset. fn get_server(self: &Keyset) -> UniquePtr; @@ -319,6 +326,7 @@ mod ffi { #[doc(hidden)] fn _get_tensor_i64(self: &Value) -> UniquePtr; fn get_dimensions(self: &Value) -> &[usize]; + fn into_transport_value(self: &Value, type_info_json: &str) -> UniquePtr; /// A serialized value which can be transported between the server and the client. type TransportValue; @@ -328,6 +336,8 @@ mod ffi { fn serialize(self: &TransportValue) -> Vec; #[doc(hidden)] fn _deserialize_transport_value(bytes: &[u8]) -> UniquePtr; + #[doc(hidden)] + fn _transport_value_to_value(tv: &TransportValue) -> UniquePtr; // ------------------------------------------------------------------------------------------- Client @@ -429,6 +439,7 @@ mod ffi { } pub use ffi::*; +use serde_json::map::IntoValues; impl ServerKeyset { /// Deserialize a server keyset from bytes. @@ -449,6 +460,10 @@ impl TransportValue { pub fn deserialize(bytes: &[u8]) -> UniquePtr { _deserialize_transport_value(bytes) } + + pub fn to_value(&self) -> UniquePtr { + _transport_value_to_value(self) + } } impl ServerFunction { @@ -1062,11 +1077,13 @@ impl Keyset { keyset_info: &KeysetInfo, secret_csprng: Pin<&mut SecretCsprng>, encryption_csprng: Pin<&mut EncryptionCsprng>, + mut initial_keys: Vec>, ) -> UniquePtr { _keyset_new( &serde_json::to_string(keyset_info).unwrap(), secret_csprng, encryption_csprng, + initial_keys.as_mut_slice(), ) } } diff --git a/frontends/concrete-rust/concrete/src/lib.rs b/frontends/concrete-rust/concrete/src/lib.rs index 8f74414de..7cd9199a7 100644 --- a/frontends/concrete-rust/concrete/src/lib.rs +++ b/frontends/concrete-rust/concrete/src/lib.rs @@ -1,7 +1,9 @@ -pub use cxx::{UniquePtr, SharedPtr}; +pub use cxx::{SharedPtr, UniquePtr}; pub use ffi::c_void; mod ffi; +#[cfg(feature = "tfhe-rs")] +pub mod tfhe; #[cfg(feature = "compiler")] #[doc(hidden)] @@ -25,3 +27,6 @@ pub mod server { } pub mod protocol; + +#[doc(hidden)] +pub mod utils; diff --git a/frontends/concrete-rust/concrete/src/protocol.rs b/frontends/concrete-rust/concrete/src/protocol.rs index 145f8bffa..e2469015e 100644 --- a/frontends/concrete-rust/concrete/src/protocol.rs +++ b/frontends/concrete-rust/concrete/src/protocol.rs @@ -1,14 +1,50 @@ #![allow(non_camel_case_types, non_snake_case, unused)] +use crate::tfhe::ModuleSpec; use serde::{Deserialize, Serialize}; /// A complete program can be described by the ensemble of circuit signatures, and the description /// of the keyset that go with it. This structure regroup those informations. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct ProgramInfo { /// The informations on the keyset of the program. pub keyset: KeysetInfo, /// The informations for the different circuits of the program. pub circuits: Vec, + /// The tfhers spec. + pub tfhers_specs: Option, +} + +impl ProgramInfo { + // Generates a `tfhers_specs` field from the circuits informations. + // + // If the `tfhers_specs` field is not available, it means that no tfhers interoperrability is needed. + // We can generate a dummy `tfhers_specs` field to make further use of the program info object simpler. + pub fn eventually_patch_tfhers_specs(&mut self) { + if self.tfhers_specs.is_none() { + self.tfhers_specs = Some(ModuleSpec { + input_types_per_func: self + .circuits + .iter() + .map(|c| (c.name.clone(), vec![None; c.inputs.len()])) + .collect(), + output_types_per_func: self + .circuits + .iter() + .map(|c| (c.name.clone(), vec![None; c.outputs.len()])) + .collect(), + input_shapes_per_func: self + .circuits + .iter() + .map(|c| (c.name.clone(), vec![None; c.inputs.len()])) + .collect(), + output_shapes_per_func: self + .circuits + .iter() + .map(|c| (c.name.clone(), vec![None; c.outputs.len()])) + .collect(), + }); + } + } } /// A circuit signature can be described completely by the type informations for its input and @@ -17,7 +53,7 @@ pub struct ProgramInfo { /// Note: /// The order of the input and output lists matters. The order of values should be the same when /// executing the circuit. Also, the name is expected to be unique in the program. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct CircuitInfo { /// The ordered list of input types. pub inputs: Vec, @@ -29,7 +65,7 @@ pub struct CircuitInfo { /// A value flowing in or out of a circuit is expected to be of a given type, according to the /// signature of this circuit. This structure represents such a type in a circuit signature. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct GateInfo { /// The raw information that raw data must be possible to parse with. pub rawInfo: RawInfo, @@ -42,7 +78,7 @@ pub struct GateInfo { /// tensor of proper shape, signedness and precision before being pre-processed and passed to the /// computation. This structure represents the informations needed to parse this payload into the /// expected tensor. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct RawInfo { /// The shape of the tensor. pub shape: Shape, @@ -57,14 +93,14 @@ pub struct RawInfo { /// /// Note: /// If the dimensions vector is empty, the message is interpreted as a scalar. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct Shape { /// The dimensions of the value. pub dimensions: Vec, } /// The different possible type of values. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum TypeInfo { lweCiphertext(LweCiphertextTypeInfo), plaintext(PlaintextTypeInfo), @@ -73,7 +109,7 @@ pub enum TypeInfo { /// A plaintext value can flow in and out of a circuit. This structure represents the informations /// needed to verify and pre-or-post process this value. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PlaintextTypeInfo { /// The shape of the value. pub shape: Shape, @@ -85,7 +121,7 @@ pub struct PlaintextTypeInfo { /// A plaintext value can flow in and out of a circuit. This structure represents the informations /// needed to verify and pre-or-post process this value. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct IndexTypeInfo { /// The shape of the value. pub shape: Shape, @@ -103,7 +139,7 @@ pub struct IndexTypeInfo { /// would have if the values were cleartext. That is, it does not take into account the encryption /// process. The concrete shape is the final shape of the object accounting for the encryption, /// that usually add one or more dimension to the object. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweCiphertextTypeInfo { /// The abstract shape of the value. pub abstractShape: Shape, @@ -121,7 +157,7 @@ pub struct LweCiphertextTypeInfo { /// The encryption of a cleartext value requires some parameters to operate. This structure /// represents those parameters. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweCiphertextEncryptionInfo { /// The identifier of the secret key used to perform the encryption. pub keyId: u32, @@ -136,7 +172,7 @@ pub struct LweCiphertextEncryptionInfo { /// /// Note: /// Not all compressions are available for every types of evaluation keys or ciphertexts. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum Compression { none, seed, @@ -144,7 +180,7 @@ pub enum Compression { } /// The encoding of the value stored inside the ciphertext. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum LweCiphretextTypeInfo_Encoding { integer(IntegerCiphertextEncodingInfo), boolean(BooleanCiphertextEncodingInfo), @@ -152,7 +188,7 @@ pub enum LweCiphretextTypeInfo_Encoding { /// A ciphertext can be used to represent an integer value. This structure represents the /// informations needed to encode such an integer. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct IntegerCiphertextEncodingInfo { /// The bitwidth of the encoded integer. pub width: u32, @@ -163,7 +199,7 @@ pub struct IntegerCiphertextEncodingInfo { } /// The mode used to encode the integer. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum IntegerCiphertextEncodingInfo_Mode { native(IntegerCiphertextEncodingInfo_Mode_NativeMode), chunked(IntegerCiphertextEncodingInfo_Mode_ChunkedMode), @@ -173,12 +209,12 @@ pub enum IntegerCiphertextEncodingInfo_Mode { /// An integer of width from 1 to 8 bits can be encoded in a single ciphertext natively, by /// being shifted in the most significant bits. This structure represents this integer encoding /// mode. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct IntegerCiphertextEncodingInfo_Mode_NativeMode {} /// An integer of width from 1 to n can be encoded in a set of ciphertexts by chunking the bits /// of the original integer. This structure represents this integer encoding mode. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct IntegerCiphertextEncodingInfo_Mode_ChunkedMode { /// The number of chunks to be used. pub size: u32, @@ -188,7 +224,7 @@ pub struct IntegerCiphertextEncodingInfo_Mode_ChunkedMode { /// An integer of width 1 to 16 can be encoded in a set of ciphertexts, by decomposing a value /// using a set of pairwise coprimes. This structure represents this integer encoding mode. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct IntegerCiphertextEncodingInfo_Mode_CrtMode { /// The coprimes used to decompose the original value. pub moduli: Vec, @@ -196,12 +232,12 @@ pub struct IntegerCiphertextEncodingInfo_Mode_CrtMode { /// A ciphertext can be used to represent a boolean value. This structure represents such an /// encoding. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct BooleanCiphertextEncodingInfo {} /// Secret Keys can be drawn from different ranges of values, using different distributions. This /// enumeration encodes the different supported ways. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum KeyType { binary = 0, ternary = 1, @@ -209,14 +245,14 @@ pub enum KeyType { /// Ciphertext operations are performed using modular arithmetic. Depending on the use, different /// modulus can be used for the operations. This structure encodes the different supported ways. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct Modulus { /// The modulus expected to be used. pub modulus: Modulus_enum, } /// The modulus expected to be used. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub enum Modulus_enum { native(NativeModulus), powerOfTwo(PowerOfTwoModulus), @@ -231,14 +267,14 @@ pub enum Modulus_enum { /// /// Example: /// 2^64 when the ciphertext is stored using 64 bits integers. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct NativeModulus {} /// Operations are performed using a modulus that is a power of two. /// /// Example: /// 2^n for any n between 0 and the bitwidth of the integer used to store the ciphertext. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PowerOfTwoModulus { /// The power used to raise 2. pub power: u32, @@ -249,7 +285,7 @@ pub struct PowerOfTwoModulus { /// Example: /// n for any n between 0 and 2^N where N is the bitwidth of the integer used to store the /// ciphertext. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct IntegerModulus { /// The value used as modulus. pub modulus: u32, @@ -261,7 +297,7 @@ pub struct IntegerModulus { /// Note: /// Secret keys with same parameters are allowed to co-exist in a program, as long as they /// have different ids. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweSecretKeyInfo { /// The identifier of the key. pub id: u32, @@ -271,7 +307,7 @@ pub struct LweSecretKeyInfo { /// A secret key is parameterized by a few quantities of cryptographic importance. This structure /// represents those parameters. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweSecretKeyParams { /// The LWE dimension, e.g. the length of the key. pub lweDimension: u32, @@ -287,7 +323,7 @@ pub struct LweSecretKeyParams { /// Note: /// Keyswitch keys with same parameters, compression, input and output id, are allowed to co-exist /// in a program as long as they have different ids. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweKeyswitchKeyInfo { /// The identifier of the keyswitch key. pub id: u32, @@ -306,7 +342,7 @@ pub struct LweKeyswitchKeyInfo { /// /// Note: /// For now, only keys with the same input and output key types can be represented. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweKeyswitchKeyParams { /// The number of levels of the ciphertexts. pub levelCount: u32, @@ -333,7 +369,7 @@ pub struct LweKeyswitchKeyParams { /// Note: /// Packing keyswitch keys with same parameters, compression, input and output id, are allowed to /// co-exist in a program as long as they have different ids. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PackingKeyswitchKeyInfo { /// The identifier of the packing keyswitch key. pub id: u32, @@ -352,7 +388,7 @@ pub struct PackingKeyswitchKeyInfo { /// /// Note: /// For now, only keys with the same input and output key types can be represented. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct PackingKeyswitchKeyParams { /// The number of levels of the ciphertexts. pub levelCount: u32, @@ -382,7 +418,7 @@ pub struct PackingKeyswitchKeyParams { /// Note: /// Bootstrap keys with same parameters, compression, input and output id, are allowed to co-exist /// in a program as long as they have different ids. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweBootstrapKeyInfo { /// The identifier of the bootstrap key. pub id: u32, @@ -401,7 +437,7 @@ pub struct LweBootstrapKeyInfo { /// /// Note: /// For now, only keys with the same input and output key types can be represented. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct LweBootstrapKeyParams { /// The number of levels of the ciphertexts. pub levelCount: u32, @@ -425,7 +461,7 @@ pub struct LweBootstrapKeyParams { /// The keyset needed for an application can be described by an ensemble of descriptions of the /// different keys used in the program. This structure represents such a description. -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct KeysetInfo { /// The secret key descriptions. pub lweSecretKeys: Vec, @@ -455,10 +491,15 @@ mod to_tokens { .iter() .map(|circuit| quote! { #circuit }) .collect::>(); + let tfhers_specs = match &self.tfhers_specs { + Some(s) => quote! {Some(#s)}, + None => quote! {None}, + }; tokens.extend(quote! { ::concrete::protocol::ProgramInfo { keyset: #keyset, circuits: vec![#(#circuits),*], + tfhers_specs: #tfhers_specs } }); } @@ -975,3 +1016,16 @@ mod to_tokens { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deserialize_program_info() { + let string = r#" + {"keyset": {"lweSecretKeys": [{"id": 0, "params": {"lweDimension": 2048, "integerPrecision": 64, "keyType": "binary"}}, {"id": 1, "params": {"lweDimension": 4096, "integerPrecision": 64, "keyType": "binary"}}, {"id": 2, "params": {"lweDimension": 776, "integerPrecision": 64, "keyType": "binary"}}, {"id": 3, "params": {"lweDimension": 626, "integerPrecision": 64, "keyType": "binary"}}, {"id": 4, "params": {"lweDimension": 2048, "integerPrecision": 64, "keyType": "binary"}}], "lweBootstrapKeys": [{"id": 0, "inputId": 2, "outputId": 0, "params": {"levelCount": 2, "baseLog": 15, "glweDimension": 2, "polynomialSize": 1024, "variance": 8.442253112932959e-31, "integerPrecision": 64, "modulus": {"modulus": {"native": {}}}, "keyType": "binary", "inputLweDimension": 776}, "compression": "none"}, {"id": 1, "inputId": 3, "outputId": 4, "params": {"levelCount": 11, "baseLog": 4, "glweDimension": 4, "polynomialSize": 512, "variance": 8.442253112932959e-31, "integerPrecision": 64, "modulus": {"modulus": {"native": {}}}, "keyType": "binary", "inputLweDimension": 626}, "compression": "none"}], "lweKeyswitchKeys": [{"id": 0, "inputId": 1, "outputId": 2, "params": {"levelCount": 5, "baseLog": 3, "variance": 4.0324907628621766e-11, "integerPrecision": 64, "modulus": {"modulus": {"native": {}}}, "keyType": "binary", "inputLweDimension": 4096, "outputLweDimension": 776}, "compression": "none"}, {"id": 1, "inputId": 0, "outputId": 1, "params": {"levelCount": 1, "baseLog": 31, "variance": 4.70197740328915e-38, "integerPrecision": 64, "modulus": {"modulus": {"native": {}}}, "keyType": "binary", "inputLweDimension": 2048, "outputLweDimension": 4096}, "compression": "none"}, {"id": 2, "inputId": 1, "outputId": 3, "params": {"levelCount": 5, "baseLog": 2, "variance": 8.437693323536307e-09, "integerPrecision": 64, "modulus": {"modulus": {"native": {}}}, "keyType": "binary", "inputLweDimension": 4096, "outputLweDimension": 626}, "compression": "none"}, {"id": 3, "inputId": 4, "outputId": 1, "params": {"levelCount": 2, "baseLog": 21, "variance": 4.70197740328915e-38, "integerPrecision": 64, "modulus": {"modulus": {"native": {}}}, "keyType": "binary", "inputLweDimension": 2048, "outputLweDimension": 4096}, "compression": "none"}], "packingKeyswitchKeys": []}, "circuits": [{"inputs": [{"rawInfo": {"shape": {"dimensions": [8, 4097]}, "integerPrecision": 64, "isSigned": false}, "typeInfo": {"lweCiphertext": {"abstractShape": {"dimensions": [8]}, "concreteShape": {"dimensions": [8, 4097]}, "integerPrecision": 64, "encryption": {"keyId": 1, "variance": 4.70197740328915e-38, "lweDimension": 4096, "modulus": {"modulus": {"native": {}}}}, "compression": "none", "encoding": {"integer": {"width": 4, "isSigned": false, "mode": {"native": {}}}}}}}, {"rawInfo": {"shape": {"dimensions": [8, 4097]}, "integerPrecision": 64, "isSigned": false}, "typeInfo": {"lweCiphertext": {"abstractShape": {"dimensions": [8]}, "concreteShape": {"dimensions": [8, 4097]}, "integerPrecision": 64, "encryption": {"keyId": 1, "variance": 4.70197740328915e-38, "lweDimension": 4096, "modulus": {"modulus": {"native": {}}}}, "compression": "none", "encoding": {"integer": {"width": 4, "isSigned": false, "mode": {"native": {}}}}}}}], "outputs": [{"rawInfo": {"shape": {"dimensions": [8, 4097]}, "integerPrecision": 64, "isSigned": false}, "typeInfo": {"lweCiphertext": {"abstractShape": {"dimensions": [8]}, "concreteShape": {"dimensions": [8, 4097]}, "integerPrecision": 64, "encryption": {"keyId": 1, "variance": 4.70197740328915e-38, "lweDimension": 4096, "modulus": {"modulus": {"native": {}}}}, "compression": "none", "encoding": {"integer": {"width": 4, "isSigned": false, "mode": {"native": {}}}}}}}], "name": "my_func"}], "tfhers_specs": {"input_types_per_func": {"my_func": [{"is_signed": false, "bit_width": 16, "carry_width": 2, "msg_width": 2, "params": {"lwe_dimension": 909, "glwe_dimension": 1, "polynomial_size": 4096, "pbs_base_log": 15, "pbs_level": 2, "lwe_noise_distribution": 0, "glwe_noise_distribution": 2.168404344971009e-19, "encryption_key_choice": 0}}, {"is_signed": false, "bit_width": 16, "carry_width": 2, "msg_width": 2, "params": {"lwe_dimension": 909, "glwe_dimension": 1, "polynomial_size": 4096, "pbs_base_log": 15, "pbs_level": 2, "lwe_noise_distribution": 0, "glwe_noise_distribution": 2.168404344971009e-19, "encryption_key_choice": 0}}]}, "output_types_per_func": {"my_func": [{"is_signed": false, "bit_width": 16, "carry_width": 2, "msg_width": 2, "params": {"lwe_dimension": 909, "glwe_dimension": 1, "polynomial_size": 4096, "pbs_base_log": 15, "pbs_level": 2, "lwe_noise_distribution": 0, "glwe_noise_distribution": 2.168404344971009e-19, "encryption_key_choice": 0}}]}, "input_shapes_per_func": {"my_func": [[], []]}, "output_shapes_per_func": {"my_func": [[]]}}} + "#; + let val: ProgramInfo = serde_json::from_str(string).unwrap(); + } +} diff --git a/frontends/concrete-rust/concrete/src/tfhe/from_value.rs b/frontends/concrete-rust/concrete/src/tfhe/from_value.rs new file mode 100644 index 000000000..6886d14b2 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/tfhe/from_value.rs @@ -0,0 +1,65 @@ +use super::{EncryptionKeyChoice, IntegerType}; +use crate::ffi::Value; +use crate::utils::from_value::FromValue; +use cxx::UniquePtr; +use tfhe::core_crypto::prelude::LweCiphertext; +use tfhe::integer::ciphertext::{DataKind, Expandable}; +use tfhe::shortint::parameters::{Degree, NoiseLevel}; +use tfhe::shortint::{CarryModulus, Ciphertext, CiphertextModulus, MessageModulus, PBSOrder, AtomicPatternKind}; +use tfhe::{ + FheInt10, FheInt12, FheInt14, FheInt16, FheInt2, FheInt4, FheInt6, FheInt8, FheUint10, + FheUint12, FheUint14, FheUint16, FheUint2, FheUint4, FheUint6, FheUint8, +}; + +macro_rules! impl_from_value_integer { + ($ty:ty, $datakind:expr) => { + impl FromValue for $ty { + type Spec = IntegerType; + + fn from_value(s: Self::Spec, v: UniquePtr) -> Self { + let lwe_size = s.params.polynomial_size + 1; + let vals = v.get_tensor::().unwrap(); + let cts = (0..s.n_cts()) + .map(|i| { + Ciphertext::new( + LweCiphertext::from_container( + vals.values()[i * lwe_size..(i + 1) * lwe_size].to_vec(), + CiphertextModulus::new_native(), + ), + Degree::new(2u64.pow(s.msg_width as u32) - 1), + NoiseLevel::UNKNOWN, + MessageModulus(2u64.pow(s.msg_width as u32)), + CarryModulus(2u64.pow(s.carry_width as u32)), + match s.params.encryption_key_choice { + EncryptionKeyChoice::BIG => { + AtomicPatternKind::Standard(PBSOrder::KeyswitchBootstrap) + } + EncryptionKeyChoice::SMALL => { + AtomicPatternKind::Standard(PBSOrder::BootstrapKeyswitch) + } + }, + ) + }) + .collect(); + <$ty>::from_expanded_blocks(cts, $datakind).unwrap() + } + } + }; +} + +impl_from_value_integer!(FheUint2, DataKind::Unsigned(2)); +impl_from_value_integer!(FheUint4, DataKind::Unsigned(4)); +impl_from_value_integer!(FheUint6, DataKind::Unsigned(6)); +impl_from_value_integer!(FheUint8, DataKind::Unsigned(8)); +impl_from_value_integer!(FheUint10, DataKind::Unsigned(10)); +impl_from_value_integer!(FheUint12, DataKind::Unsigned(12)); +impl_from_value_integer!(FheUint14, DataKind::Unsigned(14)); +impl_from_value_integer!(FheUint16, DataKind::Unsigned(16)); +impl_from_value_integer!(FheInt2, DataKind::Signed(2)); +impl_from_value_integer!(FheInt4, DataKind::Signed(4)); +impl_from_value_integer!(FheInt6, DataKind::Signed(6)); +impl_from_value_integer!(FheInt8, DataKind::Signed(8)); +impl_from_value_integer!(FheInt10, DataKind::Signed(10)); +impl_from_value_integer!(FheInt12, DataKind::Signed(12)); +impl_from_value_integer!(FheInt14, DataKind::Signed(14)); +impl_from_value_integer!(FheInt16, DataKind::Signed(16)); diff --git a/frontends/concrete-rust/concrete/src/tfhe/into_key.rs b/frontends/concrete-rust/concrete/src/tfhe/into_key.rs new file mode 100644 index 000000000..a94152288 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/tfhe/into_key.rs @@ -0,0 +1,30 @@ +use crate::ffi::LweSecretKey; +use crate::protocol::{KeyType, LweSecretKeyInfo, LweSecretKeyParams}; +use cxx::UniquePtr; +use tfhe::ClientKey; + +pub trait IntoLweSecretKey { + fn into_lwe_secret_key(&self, id: Option) -> UniquePtr; +} + +impl IntoLweSecretKey for ClientKey { + fn into_lwe_secret_key(&self, id: Option) -> cxx::UniquePtr { + let (integer_ck, _, _, _, _) = self.clone().into_raw_parts(); + let shortint_ck = integer_ck.into_raw_parts(); + let (glwe_secret_key, _, _) = shortint_ck.into_raw_parts(); + let lwe_secret_key = glwe_secret_key.into_lwe_secret_key(); + let buffer = lwe_secret_key.as_view().into_container(); + let info = LweSecretKeyInfo { + id: id.unwrap_or(0), + params: LweSecretKeyParams { + lweDimension: buffer.len() as u32, + integerPrecision: 64, + keyType: KeyType::binary, + }, + }; + crate::ffi::_lwe_secret_key_from_buffer_and_info( + buffer, + &serde_json::to_string(&info).unwrap(), + ) + } +} diff --git a/frontends/concrete-rust/concrete/src/tfhe/into_value.rs b/frontends/concrete-rust/concrete/src/tfhe/into_value.rs new file mode 100644 index 000000000..419bf0ce6 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/tfhe/into_value.rs @@ -0,0 +1,74 @@ +use crate::ffi::{Tensor, Value}; +use crate::utils::into_value::IntoValue; +use cxx::UniquePtr; +use std::ptr::copy_nonoverlapping; +use tfhe::integer::IntegerCiphertext; +use tfhe::{ + FheInt10, FheInt12, FheInt14, FheInt16, FheInt2, FheInt4, FheInt6, FheInt8, FheUint10, + FheUint12, FheUint14, FheUint16, FheUint2, FheUint4, FheUint6, FheUint8, +}; + +macro_rules! impl_into_value { + ($($type:ty),*) => { + $( + impl IntoValue for $type { + fn into_value(self) -> UniquePtr { + let (radix, _, _) = self.into_raw_parts(); + let n_cts = radix.blocks().len(); + let lwe_size = radix.blocks()[0].ct.lwe_size().0; + let mut vals: Vec = Vec::with_capacity(n_cts * lwe_size); + + // SAFETY: We are setting the length of the vector to match its capacity. + // This is safe because the vector was allocated with exactly `n_cts * lwe_size` capacity, + // and we will initialize all elements before using them. + unsafe { vals.set_len(n_cts * lwe_size) }; + + for (i, block) in radix.blocks().iter().enumerate() { + unsafe { + // SAFETY: + // 1. `block.ct.as_view().into_container().as_ptr()` points to valid memory + // because `block.ct` is a valid ciphertext. + // 2. `vals.as_mut_ptr().add(i * lwe_size)` points to a valid, non-overlapping + // region of memory within the allocated vector because `vals` was allocated + // with sufficient capacity. + // 3. `lwe_size` elements are copied, which matches the size of the source and + // destination regions. + // 4. `block` and `vals` are non overlapping, because `block` is allocated before + // the function, and `vals` is allocated inside the scope of the function. + copy_nonoverlapping( + block.ct.as_view().into_container().as_ptr(), + vals.as_mut_ptr().add(i * lwe_size), + lwe_size, + ); + } + } + + Value::from_tensor(Tensor::::new(vals, vec![n_cts, lwe_size])) + } + } + )* + }; +} + +impl_into_value!( + FheInt2, FheInt4, FheInt6, FheInt8, FheInt10, FheInt12, FheInt14, FheInt16, FheUint2, FheUint4, + FheUint6, FheUint8, FheUint10, FheUint12, FheUint14, FheUint16 +); + +#[cfg(test)] +mod tests { + use super::*; + use tfhe::prelude::*; + use tfhe::{generate_keys, ConfigBuilder, FheUint8}; + + #[test] + fn test_into_value() { + let config = ConfigBuilder::default().build(); + let (client_key, _) = generate_keys(config); + let clear_a = 27u8; + let a = FheUint8::encrypt(clear_a, &client_key); + let val = a.into_value(); + assert!(val._has_element_type_u64()); + assert_eq!(val.get_dimensions(), [4, 2049]); + } +} diff --git a/frontends/concrete-rust/concrete/src/tfhe/mod.rs b/frontends/concrete-rust/concrete/src/tfhe/mod.rs new file mode 100644 index 000000000..a3e5a38fe --- /dev/null +++ b/frontends/concrete-rust/concrete/src/tfhe/mod.rs @@ -0,0 +1,11 @@ +mod from_value; +mod into_value; + +mod into_key; +pub use into_key::*; + +mod types; +pub use types::*; + +mod spec; +pub use spec::*; diff --git a/frontends/concrete-rust/concrete/src/tfhe/spec.rs b/frontends/concrete-rust/concrete/src/tfhe/spec.rs new file mode 100644 index 000000000..8da59d2a6 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/tfhe/spec.rs @@ -0,0 +1,189 @@ +use super::IntegerType; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +pub struct ModuleSpec { + pub input_types_per_func: HashMap>>, + pub output_types_per_func: HashMap>>, + pub input_shapes_per_func: HashMap>>>, + pub output_shapes_per_func: HashMap>>>, +} + +impl ModuleSpec { + pub fn get_func(&self, name: &str) -> Option { + if !self.input_types_per_func.contains_key(name) { + return None; + } + Some(FunctionSpec { + input_types: self.input_types_per_func.get(name).unwrap(), + output_types: self.output_types_per_func.get(name).unwrap(), + input_shapes: self.input_shapes_per_func.get(name).unwrap(), + output_shapes: self.output_shapes_per_func.get(name).unwrap(), + }) + } +} + +#[derive(Debug, Clone)] +pub struct FunctionSpec<'a> { + pub input_types: &'a Vec>, + pub output_types: &'a Vec>, + pub input_shapes: &'a Vec>>, + pub output_shapes: &'a Vec>>, +} + +#[cfg(feature = "compiler")] +mod to_tokens { + //! This module contains `ToTokens` implementations. This allows protocol + //! values to be interpolated in the `quote!` macro as constructors of the values. + //! Useful to construct static protocol values. + + use super::*; + use proc_macro2::TokenStream; + use quote::{quote, ToTokens}; + + impl ToTokens for ModuleSpec { + fn to_tokens(&self, tokens: &mut TokenStream) { + let input_types_per_func = &self + .input_types_per_func + .iter() + .map(|(key, value)| { + ( + key, + value.iter().map(|maybe_type| match maybe_type { + Some(typ) => quote! {Some(#typ)}, + None => quote! { None }, + }), + ) + }) + .map(|(key, value)| quote! { (#key.to_string(), vec![#(#value),*]) }) + .collect::>(); + let output_types_per_func = &self + .output_types_per_func + .iter() + .map(|(key, value)| { + ( + key, + value.iter().map(|maybe_type| match maybe_type { + Some(typ) => quote! {Some(#typ)}, + None => quote! { None }, + }), + ) + }) + .map(|(key, value)| quote! { (#key.to_string(), vec![#(#value),*]) }) + .collect::>(); + let input_shapes_per_func = &self + .input_shapes_per_func + .iter() + .map(|(key, value)| { + ( + key, + value.iter().map(|maybe_shape| match maybe_shape { + Some(shape) => quote! {Some(vec![#(#shape),*])}, + None => quote! { None }, + }), + ) + }) + .map(|(key, value)| quote! { (#key.to_string(), vec![#(#value),*]) }) + .collect::>(); + let output_shapes_per_func = &self + .output_shapes_per_func + .iter() + .map(|(key, value)| { + ( + key, + value.iter().map(|maybe_shape| match maybe_shape { + Some(shape) => quote! {Some(vec![#(#shape),*])}, + None => quote! { None }, + }), + ) + }) + .map(|(key, value)| quote! { (#key.to_string(), vec![#(#value),*]) }) + .collect::>(); + + tokens.extend(quote! { + ::concrete::tfhe::ModuleSpec { + input_types_per_func: vec![#(#input_types_per_func),*].into_iter().collect(), + output_types_per_func: vec![#(#output_types_per_func),*].into_iter().collect(), + input_shapes_per_func: vec![#(#input_shapes_per_func),*].into_iter().collect(), + output_shapes_per_func: vec![#(#output_shapes_per_func),*].into_iter().collect(), + } + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quote::quote; + + #[test] + fn test_deserialize_spec() { + let string = r#"{ + "input_types_per_func": { + "my_func": [ + { + "is_signed": false, + "bit_width": 16, + "carry_width": 2, + "msg_width": 2, + "params": { + "lwe_dimension": 909, + "glwe_dimension": 1, + "polynomial_size": 4096, + "pbs_base_log": 15, + "pbs_level": 2, + "lwe_noise_distribution": 0, + "glwe_noise_distribution": 2.168404344971009e-19, + "encryption_key_choice": 0 + } + }, + { + "is_signed": false, + "bit_width": 16, + "carry_width": 2, + "msg_width": 2, + "params": { + "lwe_dimension": 909, + "glwe_dimension": 1, + "polynomial_size": 4096, + "pbs_base_log": 15, + "pbs_level": 2, + "lwe_noise_distribution": 0, + "glwe_noise_distribution": 2.168404344971009e-19, + "encryption_key_choice": 0 + } + } + ] + }, + "output_types_per_func": { + "my_func": [ + { + "is_signed": false, + "bit_width": 16, + "carry_width": 2, + "msg_width": 2, + "params": { + "lwe_dimension": 909, + "glwe_dimension": 1, + "polynomial_size": 4096, + "pbs_base_log": 15, + "pbs_level": 2, + "lwe_noise_distribution": 0, + "glwe_noise_distribution": 2.168404344971009e-19, + "encryption_key_choice": 0 + } + } + ] + }, + "input_shapes_per_func": { "my_func": [[], []] }, + "output_shapes_per_func": { "my_func": [[]] } + } + "#; + let cp: ModuleSpec = serde_json::from_str(string).unwrap(); + let a = quote! {#cp}; + + dbg!(a.to_string()); + } +} diff --git a/frontends/concrete-rust/concrete/src/tfhe/types.rs b/frontends/concrete-rust/concrete/src/tfhe/types.rs new file mode 100644 index 000000000..b0afd636b --- /dev/null +++ b/frontends/concrete-rust/concrete/src/tfhe/types.rs @@ -0,0 +1,133 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Clone, PartialEq, Copy, Debug)] +pub enum EncryptionKeyChoice { + BIG = 0, + SMALL = 1, +} + +impl TryFrom for EncryptionKeyChoice { + type Error = &'static str; + + fn try_from(value: i32) -> Result { + match value { + 0 => Ok(EncryptionKeyChoice::BIG), + 1 => Ok(EncryptionKeyChoice::SMALL), + _ => Err("Invalid value for EncryptionKeyChoice"), + } + } +} + +impl Serialize for EncryptionKeyChoice { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + (*self as i32).serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for EncryptionKeyChoice { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = i32::deserialize(deserializer)?; + EncryptionKeyChoice::try_from(value).map_err(serde::de::Error::custom) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CryptoParams { + pub lwe_dimension: usize, + pub glwe_dimension: usize, + pub polynomial_size: usize, + pub pbs_base_log: usize, + pub pbs_level: usize, + pub lwe_noise_distribution: f64, + pub glwe_noise_distribution: f64, + pub encryption_key_choice: EncryptionKeyChoice, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct IntegerType { + pub carry_width: usize, + pub msg_width: usize, + pub is_signed: bool, + pub bit_width: usize, + pub params: CryptoParams, +} + +impl IntegerType { + pub fn n_cts(&self) -> usize { + self.bit_width / self.msg_width + } +} + +#[cfg(feature = "compiler")] +mod to_tokens { + //! This module contains `ToTokens` implementations. This allows protocol + //! values to be interpolated in the `quote!` macro as constructors of the values. + //! Useful to construct static protocol values. + + use super::*; + use proc_macro2::TokenStream; + use quote::{quote, ToTokens}; + + impl ToTokens for EncryptionKeyChoice { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + EncryptionKeyChoice::BIG => { + tokens.extend(quote! {::concrete::tfhe::EncryptionKeyChoice::BIG}) + } + EncryptionKeyChoice::SMALL => { + tokens.extend(quote! {::concrete::tfhe::EncryptionKeyChoice::SMALL}) + } + } + } + } + + impl ToTokens for CryptoParams { + fn to_tokens(&self, tokens: &mut TokenStream) { + let lwe_dimension = self.lwe_dimension; + let glwe_dimension = self.glwe_dimension; + let polynomial_size = self.polynomial_size; + let pbs_base_log = self.pbs_base_log; + let pbs_level = self.pbs_level; + let lwe_noise_distribution = self.lwe_noise_distribution; + let glwe_noise_distribution = self.glwe_noise_distribution; + let encryption_key_choice = &self.encryption_key_choice; + tokens.extend(quote! { + ::concrete::tfhe::CryptoParams { + lwe_dimension: #lwe_dimension, + glwe_dimension: #glwe_dimension, + polynomial_size: #polynomial_size, + pbs_base_log: #pbs_base_log, + pbs_level: #pbs_level, + lwe_noise_distribution: #lwe_noise_distribution, + glwe_noise_distribution: #glwe_noise_distribution, + encryption_key_choice: #encryption_key_choice, + } + }); + } + } + + impl ToTokens for IntegerType { + fn to_tokens(&self, tokens: &mut TokenStream) { + let carry_width = self.carry_width; + let msg_width = self.msg_width; + let params = &self.params; + let bit_width = &self.bit_width; + let is_signed = &self.is_signed; + tokens.extend(quote! { + ::concrete::tfhe::IntegerType { + carry_width: #carry_width, + msg_width: #msg_width, + params: #params, + bit_width: #bit_width, + is_signed: #is_signed + } + }); + } + } +} diff --git a/frontends/concrete-rust/concrete-macro/src/configuration.rs b/frontends/concrete-rust/concrete/src/utils/configuration.rs similarity index 61% rename from frontends/concrete-rust/concrete-macro/src/configuration.rs rename to frontends/concrete-rust/concrete/src/utils/configuration.rs index 9446fce09..6a959fcdf 100644 --- a/frontends/concrete-rust/concrete-macro/src/configuration.rs +++ b/frontends/concrete-rust/concrete/src/utils/configuration.rs @@ -1,98 +1,7 @@ +use super::python::{PythonPickledEnum, PythonPickledObject}; use serde::{Deserialize, Deserializer}; use serde_json::Value; -struct PythonPickledObject { - py_object: String, - py_serialized: String, -} - -impl<'de> Deserialize<'de> for PythonPickledObject { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let json = Value::deserialize(deserializer)?; - let Value::Object(obj) = json else { - return Err(::custom("Missing object")); - }; - let Some(Value::String(py_object)) = obj.get("py/object") else { - return Err(::custom( - "Missing field \"py/object\"", - )); - }; - let Some(py_serialized) = obj.get("serialized") else { - return Err(::custom( - "Missing field \"serialized\"", - )); - }; - Ok(PythonPickledObject { - py_object: py_object.clone(), - py_serialized: py_serialized.to_string(), - }) - } -} - -struct PythonPickledEnum { - py_type: String, - py_tuple: Value, -} - -impl<'de> Deserialize<'de> for PythonPickledEnum { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let json = Value::deserialize(deserializer)?; - let Value::Object(obj) = json else { - return Err(::custom("Missing object")); - }; - let Some(py_reduce) = obj.get("py/reduce") else { - return Err(::custom( - "Missing field \"py/reduce\"", - )); - }; - let Value::Array(arr) = py_reduce else { - return Err(::custom("Missing array")); - }; - if arr.len() != 2 { - return Err(::custom( - "Unexpected py_reduce array length", - )); - } - let Some(Value::Object(py_type_obj)) = arr.get(0) else { - return Err(::custom("Missing object")); - }; - let Some(Value::String(py_type)) = py_type_obj.get("py/type") else { - return Err(::custom( - "Missing \"py/type\" field.", - )); - }; - let Some(Value::Object(py_tuple_obj)) = arr.get(1) else { - return Err(::custom("Missing object")); - }; - let Some(py_tuple_value) = py_tuple_obj.get("py/tuple") else { - return Err(::custom( - "Missing \"py/tuple\" field.", - )); - }; - let Value::Array(py_tuple_arr) = py_tuple_value else { - return Err(::custom( - "\"py/tuple\" is not an array.", - )); - }; - if py_tuple_arr.len() != 1 { - return Err(::custom( - "Unexpected py_tuple array length", - )); - } - let py_tuple = py_tuple_arr.get(0).unwrap(); - Ok(PythonPickledEnum { - py_type: py_type.clone(), - py_tuple: py_tuple.clone(), - }) - } -} - #[derive(Debug)] pub enum ParameterSelectionStrategy { V0, diff --git a/frontends/concrete-rust/concrete/src/utils/from_value.rs b/frontends/concrete-rust/concrete/src/utils/from_value.rs new file mode 100644 index 000000000..43cf4a7e6 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/utils/from_value.rs @@ -0,0 +1,19 @@ +use cxx::UniquePtr; + +use crate::ffi::{GetTensor, Tensor, Value}; + +pub trait FromValue { + type Spec; + fn from_value(s: Self::Spec, v: UniquePtr) -> Self; +} + +impl FromValue for Tensor +where + Value: GetTensor, +{ + type Spec = (); + + fn from_value(_s: Self::Spec, v: UniquePtr) -> Self { + v.get_tensor().unwrap() + } +} diff --git a/frontends/concrete-rust/concrete/src/utils/into_value.rs b/frontends/concrete-rust/concrete/src/utils/into_value.rs new file mode 100644 index 000000000..2892b96ec --- /dev/null +++ b/frontends/concrete-rust/concrete/src/utils/into_value.rs @@ -0,0 +1,12 @@ +use crate::ffi::{Tensor, Value}; +use cxx::UniquePtr; + +pub trait IntoValue { + fn into_value(self) -> UniquePtr; +} + +impl IntoValue for Tensor { + fn into_value(self) -> UniquePtr { + Value::from_tensor(self) + } +} diff --git a/frontends/concrete-rust/concrete/src/utils/mod.rs b/frontends/concrete-rust/concrete/src/utils/mod.rs new file mode 100644 index 000000000..69c7c2463 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/utils/mod.rs @@ -0,0 +1,4 @@ +pub mod configuration; +pub mod from_value; +pub mod into_value; +pub mod python; diff --git a/frontends/concrete-rust/concrete/src/utils/python.rs b/frontends/concrete-rust/concrete/src/utils/python.rs new file mode 100644 index 000000000..af3f8b261 --- /dev/null +++ b/frontends/concrete-rust/concrete/src/utils/python.rs @@ -0,0 +1,94 @@ +use serde::{Deserialize, Deserializer}; +use serde_json::Value; + +pub struct PythonPickledObject { + pub py_object: String, + pub py_serialized: String, +} + +impl<'de> Deserialize<'de> for PythonPickledObject { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = Value::deserialize(deserializer)?; + let Value::Object(obj) = json else { + return Err(::custom("Missing object")); + }; + let Some(Value::String(py_object)) = obj.get("py/object") else { + return Err(::custom( + "Missing field \"py/object\"", + )); + }; + let Some(py_serialized) = obj.get("serialized") else { + return Err(::custom( + "Missing field \"serialized\"", + )); + }; + Ok(PythonPickledObject { + py_object: py_object.clone(), + py_serialized: py_serialized.to_string(), + }) + } +} + +pub struct PythonPickledEnum { + pub py_type: String, + pub py_tuple: Value, +} + +impl<'de> Deserialize<'de> for PythonPickledEnum { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let json = Value::deserialize(deserializer)?; + let Value::Object(obj) = json else { + return Err(::custom("Missing object")); + }; + let Some(py_reduce) = obj.get("py/reduce") else { + return Err(::custom( + "Missing field \"py/reduce\"", + )); + }; + let Value::Array(arr) = py_reduce else { + return Err(::custom("Missing array")); + }; + if arr.len() != 2 { + return Err(::custom( + "Unexpected py_reduce array length", + )); + } + let Some(Value::Object(py_type_obj)) = arr.get(0) else { + return Err(::custom("Missing object")); + }; + let Some(Value::String(py_type)) = py_type_obj.get("py/type") else { + return Err(::custom( + "Missing \"py/type\" field.", + )); + }; + let Some(Value::Object(py_tuple_obj)) = arr.get(1) else { + return Err(::custom("Missing object")); + }; + let Some(py_tuple_value) = py_tuple_obj.get("py/tuple") else { + return Err(::custom( + "Missing \"py/tuple\" field.", + )); + }; + let Value::Array(py_tuple_arr) = py_tuple_value else { + return Err(::custom( + "\"py/tuple\" is not an array.", + )); + }; + if py_tuple_arr.len() != 1 { + return Err(::custom( + "Unexpected py_tuple array length", + )); + } + let py_tuple = py_tuple_arr.get(0).unwrap(); + Ok(PythonPickledEnum { + py_type: py_type.clone(), + py_tuple: py_tuple.clone(), + }) + } +} diff --git a/frontends/concrete-rust/test/Cargo.toml b/frontends/concrete-rust/test/Cargo.toml index 2af6181c7..6b504e9a7 100644 --- a/frontends/concrete-rust/test/Cargo.toml +++ b/frontends/concrete-rust/test/Cargo.toml @@ -4,5 +4,6 @@ version = "0.0.0" edition = "2021" [dependencies] +tfhe = "1.0" concrete-macro = { path = "../concrete-macro" } concrete = { path = "../concrete" } diff --git a/frontends/concrete-rust/test/python/__pycache__/test_tfhers.cpython-310.pyc b/frontends/concrete-rust/test/python/__pycache__/test_tfhers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d7dba3e7dac1fb394a734f01c889910755eb404 GIT binary patch literal 1132 zcmZ8f&2G~`5Z+lQj_cS-`XeE391v0tG@=Ji6{@I8sZcA_mdjqO;N3JE;vd$VfUEY@ zUi$_JCC9!558^8q`UZ#-vvxxx#@cUYcSfF>Zzl15m!NF?S{?qNg#0v@Wij3zJlS^y zPB@H71Ps8~bW9_QSwMs}t8t6hc%9q4fgVR#qx#66JFI~|M>OWZRs_u4IoT(Cg}c{; zyI1rISd%xocTJclJWj5l!U09?mp8-#J)_JQ%~^}Lgf9XA)fVkhXSB-Oyu(*-C|0v} zLb~gBSR?5|VGV`?m90^$Ox{?8+yS~I>7dm{>kfM*XLwLL+hJuqU>(gJJlRtOK^Snt zXxSxj;Pe}vgX#AgD`s_G$J}<|9Oj1!pNs`|8iy)Rwcxk3fUgDQp-Hz8W7qKP;wcko zGJ^{;hf7jg*x)6gMK8z&JR;hf(-U$FN8OsjMhIKcNf7t7oC?+)s3f|-6kgSCz;&Jq zrN;8GVlaNb$avUwkHWVi~j=tu-X;RZoe zvWB)71fM7Uaph5lCgzx5YMI2!Os(u@30mf@fq-MJk~*YiwV-?p&BPaC?0=klcri?C zWY}=4n)uMQDt7(TH*fcjPY$sEGoh4;#Z`joe`yQQkI|KxDm_pes3h6Ja!j4)Z9~+8 zoO79;c7Zul*$<+zR7c%4<}7%x+VmFJO|!A(l)YWELdpBS>PHz{->X&%nRqAi-C-g_ z!K~K@`^>+edRhg#6B+9lFN3XMtE|1SD|$;x$}*WF<8-1kp^HYCM5#0zi?gZXn2nn< u6^hleJ}R?5OJE2uob4*~l_52IhTtIuHA-Rmv!V9eb#2-K$9aM?*?_--<_tOj literal 0 HcmV?d00001 diff --git a/frontends/concrete-rust/test/python/test.py b/frontends/concrete-rust/test/python/test.py new file mode 100644 index 000000000..3e718dbc9 --- /dev/null +++ b/frontends/concrete-rust/test/python/test.py @@ -0,0 +1,16 @@ +from concrete import fhe + +@fhe.module() +class MyModule: + + @fhe.function({"x": "encrypted"}) + def inc(x): + return (x + 1) % 20 + + @fhe.function({"x": "encrypted"}) + def dec(x): + return (x - 1) % 20 + +inputset = list(range(20)) +my_module = MyModule.compile({"inc": inputset, "dec": inputset}) +my_module.server.save("test.zip", via_mlir=True) diff --git a/frontends/concrete-rust/test/python/test_tfhers.py b/frontends/concrete-rust/test/python/test_tfhers.py new file mode 100644 index 000000000..4bc4d8390 --- /dev/null +++ b/frontends/concrete-rust/test/python/test_tfhers.py @@ -0,0 +1,35 @@ +from concrete import fhe +from concrete.fhe import tfhers + +TFHERS_UINT_8_3_2_4096 = tfhers.TFHERSIntegerType( + False, + bit_width=8, + carry_width=3, + msg_width=2, + params=tfhers.CryptoParams( + lwe_dimension=909, + glwe_dimension=1, + polynomial_size=4096, + pbs_base_log=15, + pbs_level=2, + lwe_noise_distribution=0, + glwe_noise_distribution=2.168404344971009e-19, + encryption_key_choice=tfhers.EncryptionKeyChoice.BIG, + ), +) + +@fhe.module() +class MyModule: + + @fhe.function({"x": "encrypted", "y":"encrypted"}) + def my_func(x, y): + x = tfhers.to_native(x) + y = tfhers.to_native(y) + return tfhers.from_native(x + y, TFHERS_UINT_8_3_2_4096) + +def t(v): + return tfhers.TFHERSInteger(TFHERS_UINT_8_3_2_4096, v) + +inputset = [(t(0), t(0)), (t(2**6), t(2**6))] +my_module = MyModule.compile({"my_func": inputset}) +my_module.server.save("test_tfhers.zip", via_mlir=True) diff --git a/frontends/concrete-rust/test/src/default.rs b/frontends/concrete-rust/test/src/default.rs new file mode 100644 index 000000000..a83f8b77a --- /dev/null +++ b/frontends/concrete-rust/test/src/default.rs @@ -0,0 +1,41 @@ +mod precompile { + use concrete_macro::from_concrete_python_export_zip; + from_concrete_python_export_zip!("src/test.zip"); +} + +#[cfg(test)] +mod test { + use super::precompile; + use concrete::common::{ClientKeyset, ServerKeyset, Tensor, TransportValue}; + + #[test] + fn test() { + let mut secret_csprng = concrete::common::SecretCsprng::new(0u128); + let mut encryption_csprng = concrete::common::EncryptionCsprng::new(0u128); + let keyset = precompile::KeysetBuilder::new() + .generate(secret_csprng.pin_mut(), encryption_csprng.pin_mut()); + let client_keyset = keyset.get_client(); + let serialized_client_keyset = client_keyset.serialize(); + let deserialized_client_keyset = + ClientKeyset::deserialize(serialized_client_keyset.as_slice()); + let server_keyset = keyset.get_server(); + let serialized_server_keyset = server_keyset.serialize(); + let deserialized_server_keyset = + ServerKeyset::deserialize(serialized_server_keyset.as_slice()); + let mut dec_client = precompile::client::dec::ClientFunction::new( + &deserialized_client_keyset, + encryption_csprng, + ); + let mut dec_server = precompile::server::dec::ServerFunction::new(); + let input = Tensor::new(vec![5], vec![]); + let prepared_input = dec_client.prepare_inputs(input); + let serialized_input = prepared_input.serialize(); + let deserialized_input = TransportValue::deserialize(serialized_input.as_slice()); + let output = dec_server.invoke(&deserialized_server_keyset, deserialized_input); + let serialized_output = output.serialize(); + let deserialized_output = TransportValue::deserialize(serialized_output.as_slice()); + let processed_output = dec_client.process_outputs(deserialized_output); + assert_eq!(processed_output.values(), [4]); + assert_eq!(processed_output.dimensions().len(), 0); + } +} diff --git a/frontends/concrete-rust/test/src/lib.rs b/frontends/concrete-rust/test/src/lib.rs index d29e64b93..b0eef7a12 100644 --- a/frontends/concrete-rust/test/src/lib.rs +++ b/frontends/concrete-rust/test/src/lib.rs @@ -1,36 +1,2 @@ -mod precompile{ - use concrete_macro::from_concrete_python_export_zip; - from_concrete_python_export_zip!("src/test.zip"); -} - -#[cfg(test)] -mod test { - use concrete::common::{ClientKeyset, ServerKeyset, Tensor, TransportValue}; - use crate::precompile; - - #[test] - fn test() { - let mut secret_csprng = concrete::common::SecretCsprng::new(0u128); - let mut encryption_csprng = concrete::common::EncryptionCsprng::new(0u128); - let keyset = precompile::new_keyset(secret_csprng.pin_mut(), encryption_csprng.pin_mut()); - let client_keyset = keyset.get_client(); - let serialized_client_keyset = client_keyset.serialize(); - let deserialized_client_keyset = ClientKeyset::deserialize(serialized_client_keyset.as_slice()); - let server_keyset = keyset.get_server(); - let serialized_server_keyset = server_keyset.serialize(); - let deserialized_server_keyset = ServerKeyset::deserialize(serialized_server_keyset.as_slice()); - let mut dec_client = - precompile::client::dec::ClientFunction::new(&deserialized_client_keyset, encryption_csprng); - let mut dec_server = precompile::server::dec::ServerFunction::new(); - let input = Tensor::new(vec![5], vec![]); - let prepared_input = dec_client.prepare_inputs(input); - let serialized_input = prepared_input.serialize(); - let deserialized_input = TransportValue::deserialize(serialized_input.as_slice()); - let output = dec_server.invoke(&deserialized_server_keyset, deserialized_input); - let serialized_output = output.serialize(); - let deserialized_output = TransportValue::deserialize(serialized_output.as_slice()); - let processed_output = dec_client.process_outputs(deserialized_output); - assert_eq!(processed_output.values(), [4]); - assert_eq!(processed_output.dimensions().len(), 0); - } -} +mod default; +mod tfhers; diff --git a/frontends/concrete-rust/test/src/main.rs b/frontends/concrete-rust/test/src/main.rs index 16eafa253..defbde161 100644 --- a/frontends/concrete-rust/test/src/main.rs +++ b/frontends/concrete-rust/test/src/main.rs @@ -1,22 +1,27 @@ -use concrete::common::Tensor; +use tfhe::prelude::{FheDecrypt, FheEncrypt}; +use tfhe::shortint::parameters::v0_10::classic::gaussian::p_fail_2_minus_64::ks_pbs::V0_10_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M64; +use tfhe::{generate_keys, FheUint8}; mod precompile { use concrete_macro::from_concrete_python_export_zip; - from_concrete_python_export_zip!("src/test.zip"); + from_concrete_python_export_zip!("src/test_tfhers.zip"); } fn main() { let mut secret_csprng = concrete::common::SecretCsprng::new(0u128); let mut encryption_csprng = concrete::common::EncryptionCsprng::new(0u128); - let keyset = precompile::new_keyset(secret_csprng.pin_mut(), encryption_csprng.pin_mut()); - let client_keyset = keyset.get_client(); + let config = tfhe::ConfigBuilder::with_custom_parameters( + V0_10_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M64, + ); + let (client_key, _) = generate_keys(config); + let keyset = precompile::KeysetBuilder::new() + .with_key_for_my_func_0_arg(&client_key) + .generate(secret_csprng.pin_mut(), encryption_csprng.pin_mut()); let server_keyset = keyset.get_server(); - let mut dec_client = - precompile::client::dec::ClientFunction::new(&client_keyset, encryption_csprng); - let mut dec_server = precompile::server::dec::ServerFunction::new(); - let input = Tensor::new(vec![5], vec![]); - let prepared_input = dec_client.prepare_inputs(input); - let output = dec_server.invoke(&server_keyset, prepared_input); - let processed_output = dec_client.process_outputs(output); - println!("{:?}", processed_output.values()); + let mut server = precompile::server::my_func::ServerFunction::new(); + let arg_0 = FheUint8::encrypt(6u8, &client_key); + let arg_1 = FheUint8::encrypt(4u8, &client_key); + let output = server.invoke(&server_keyset, arg_0, arg_1); + let decrypted: u8 = output.decrypt(&client_key); + assert_eq!(decrypted, 10); } diff --git a/frontends/concrete-rust/test/src/test.zip b/frontends/concrete-rust/test/src/test.zip index 393f0be539af0341e094804089cc34601619e154..3b52de2983cdcd0d9321f6fc0790efc2335bb291 100644 GIT binary patch delta 2184 zcmZXVc{~%28^@RB%rr;kEXQ2AS7DTU2+5IrjYXNh%*~K<{G{B+$dIIIj=4w55tAHq z7P6e9Tv;PUjYNi7*@#7cB`J`FN4=@vW0J+4OU4&XJCV1PpW#MQq_-i{87roJUd}x| zv-gpbv4NX<>~8f!4PK8#9&EBvS{$t%)?&G|(4JHR?u1@BJmh4js3U$^D+$|vZGxlX z^4p?@Mj1uHVd+!B64)fDC|jaI__!7;Ty5tS(kZ@2N-{o@%EUR*#r zqgvG9hknMOky9icv=NoZG-GQ;c1Na;&}q^gOx$d6!T}=;YbFkRFO1c=t( zU78aY>T%*A_9|xjehVxC6&R`#Ed~M2qQ5l6b(|?NK8M-@jFCijA7~VIum&W8KTVM) zv(5eCHnquZ8p|D!6fQr;tVBUz8cDsBU97!su{f%&{mcHo=xW3an&dWy0hC@!BZ>^U z6#sFIk3MB+e<1bav1h+;U{70d{ATCp@o2Cs?C2lZG3t8*hYGh7MzAmkA2f{%#;6@e z@_r$5c$Ld&aUA0!PIz%PK}+zd%f^UkGx2N;qu;>ui2mozVtV_p&?z3F<3`k{(!^vd`6 zm*=tJ)sA~X2QcXw?v0<|8gVi(W;0v?$_%pr>sNWDhM>lg5WIKLGYG)_+_up&^SV3( zzfh=n2^T4XoElU+D3chfUtixEs#8W+XI&ZC!;LHUwd}F~OMi&qBrzuD2mo-66#(Eq zRHds&fNPKkM%mlT0~;`qX>H!23ht)w3oN;)N~Fq6BgE>q3EN^u2DC1r&p6!E!k*S< z{Ig6SW}*&vg!}2$pFijetF)BCasvvDmEuVxm+8*xoEk^v=LOY7_jQ-tFH7%LDw^l2 z%8aF~o~CFPIYD2CqMU%jyX9|Yy6Rd$h$j4xaF=k85Kl-XBokIcl$XAU28Uku^`Qc4P$ME1ZhrQ53l45G+q@fW}>kPr8;Q5hnc42}H z_~*5|YRVr^r{G_k%=(fDz}B`I+O-Hm+2c*2Gbf8*X9Qk^dJwNi#Tx29J>H*ru;E(| z134s(_@ &DBQBUoQ4aKT za}88R2m1Ojv#rwzhgWJ#oEyJTA5GDjndw{nKF7D_?P?Kc{V*>34PlA5b@v&oQxGpw zMGc?Q3iyb*y*+?Mm@z;NEKfHzpkIP8pzCTsZnx(=$vs+hp(J_Uv7B@-RtCC~B6{chvlxG07aC7D z%{VwrYWL`{4YwIR!K9y){DkOM zP)Ml*8t;-+&Y%kSr3!At;=A0{1}R$ChP7pv(no5e!!=>}05qztTVAkHxvL-+%p)cm z_Jk$Qhtn)!4N_`+8-Wx#TVz+=|Jg92+qNlfNSm@G%GX|`husH%72q2RiM*q?ovoL5 z`d#(cHPS2l^4{H{tYPj@SwRrAXPS0d8`W6MKHF|NE}SV zStwyf&M_*dw#)(kt+bFVruC?)afrH@tMsDGu_C^VT>nA6cHPD;PNrWTe9dAR7gO9O zo&aDfyc1cI#Fo!t3FBM=oNbJc1^t^b{8VZ}e$aUSClwaW1x$mg72|nsg%Z){wWZoA z3!|;)U2vrCJ5rO(>2k|l_?5ZEC_~rA;Rhk$%5MfH1d>jt=fKp%9b#UuDLObskZ^me z`RXd(^J`>4n~Uuvm8eKc`?JyNMW+P{oFMaHWyj+A1$aNS=g%aAM&aVZ>gozp+^3q< zH#*m~8=Y!vEZvQQ?7D`kDoAZs=0{irfd4pp<`Y(pUzVOZ$*TXGr)QpJgJOT#{O`~? w{OZ5sY0S~1|3Xl}Z_EE9hW}3z4gvNbX;_)Fvi)jdJxq;5%9(Q*2qi-G;GM}k#1XKK9wgFvyuq+%JsVKQq`4|0%aR!Z@sq?<{{s;5z#P&#ZU2%qz_ zz(5>k*hk{N4tEAUk@ch>V&gI6m8~Qu7Q9Q1t{C4jZE-hhWGe6NcvkL1Gs+RV=gVOU zWA~fborKtx1KUvBWCat)8aG^K?}BgWTGFQ05+bvobxXD=n>B@^=;b7D0n_k=^BdPb zqV*DX3&=~V{Q-7*{xDR*DM6VAzw$mzgi6})_V7IMy1cG@i-1z^dDq_FE{TBb08Ha) zc8z-vVy?T{RG1i2W{Y^Sr|Cku=Gn%p{ng>-UmSRR1`}J~gX1sr$*(m-7W_?b(U@p| zwt(a%NNi~YzTlBTzrlL}t_WuYEGn(YDv}_{NB^g2(t4}Y+8xfS34VT}+H2@pJZaAb z<810)Qc+0$bI2dJjEZm9w&)6j%Rb7*z>M%vi;%L6it9LA4Xfl!8uk}DzHqcMjLey4 zI$-Vy!%*N5m2G8aI8C{{S~)|;%~1$992d5}ZA;9D2(35o0Kl$3f6tlm72two@!%S3 zNvMj24uDeC$?fAD;kB^q>VfCT2H;bBS(dF!M5)vOVrt3J2j`($Ml_@z9f~vMJ>CP(R?^6yO(3)+^VAPWV+Ngd) zQH-WV3|sWxH=}s&tG$>)x5F!uHL?xogcA`26X3y})eOyCv;Io>r}V+ia)-=bXS&HN zL)x}s&Eba>Oa?-cTscU*C^gVQOffaTpo^S;PGg~v^B~{%ccmsVC3PY$lrwIoLgL+Q zyN1>`iu9`=WP(WSBzjI=j4`dTuVeA?7&<<#`OP97&S~_AJAtx|P9xrGC zbY1Ek_X9Z>PlpTm8wtpzETB9N+T@WegHWiaYvyB~+hJl{p7F7Eo{q=Lt8}K7RPZig zZ@V7wa1@iZ&WL}hy<@eqLTR{4?UL{6%Ni0J9(`-bxOIM5fE5ybb7Q$OYUI7j_3yQ% zqfap%%yUmK8>45^Uf-4@lH}L03S!Ri*iT=ft6Bwr4N~%c!f@D?!|(p_aur;`uc{uxKaJ(eooM>p z8m^Ei>9F>wa0%nPy4NJUn#RB_DMFAPo_S;&8KTmVSIg5fw(_>)%obu7zfAc)fmG8&Q TZ7lf(j!%3?yK|)3{KxKJYAcY{ diff --git a/frontends/concrete-rust/test/src/test_tfhers.zip b/frontends/concrete-rust/test/src/test_tfhers.zip new file mode 100644 index 0000000000000000000000000000000000000000..ad361e63380d169eba32f123fe86c0a739d5e141 GIT binary patch literal 3633 zcmZ`+XH=6}*A2Z(2bB(C0BI^xLXnP0ucAT-LO>Hb3B_b+g3>{wOeg_n=!RkdMQMT} z(i8*?g(O7lRo>$AHslJ?y#%6n7c zz=;~l6?)FDTxEyZX{n4I#I||OdQupqTHBu=3|UE)U~o;wm;AO3-&`?W`h9&UpZU0W zue?3R7+VGIJ1eyHu8OlK0tGy#xy+hz5NFf>_|7!Y zPlIh-m${@Us4NXQm)Da{jZKw=l#P8e?*MzMyuf@;lidMMK3X*I^@DVEkrTkuafmGg z7Vqb0>izOWGPju9nqBR0P~zCN;lAC47(s&y({P~&_5C z56Ky#d`;p|2F&!O^&gz{vgyFnJtjgQ&viUu<6q2a$PR2lf|NLF$`wLyL z&jo{@(xn@c;&mdX1tUD==LNBsf>>fNU&v?o5IGFax7d@pIDO=M%Pb5&v_Im;4lPLT zvqerOG=q|$h`XvWVn8l3e69X|yF{6p8gv^lj2G3<)cmuZDI}Ta`xpGgAC|$U4mGK* znhWi~H1?n+vQ$1I1FuoeBKD^4TWM75o3+D3(beBcaJ=_05>T#}Q6kdsQF@(&Z|aM& z%dyOEf$xVvL|1ER;#SAilwz)_mn2v5&968#H{sH2;}hrZ`rQZs_kESw zb~Bs3+y1p@zSgkqIqzJ}$0g}lNfq`+JuSW3*p`w=|0#oShIf^ir>aJ=Etk;GTn|Ez zL9!&y%_AOyWGzGN7FzM*3d)Mdu-Y#@96AODV*Nv300DQ@c1#vlt|=^G=l)dE!$gXJ zzw})=ek#@fe0_bpzfJ}I?1}Zq1I(CGZ}S1mzgzmFo(g9i69BLX{Ja0XU?Eg7= zP*5FxjJ$;(TXcYcO+SFY4b35_y@J88ne>{vtivAW)wc_;ag*7e_9`w!d`>^JYybRm zjeGhsHYyA<4r#$+%8664F_60MIecD%}|)$NlBdugoenX)_UYVQ&) zQuikh1Dz(GV-}o$1v1uygM%AjS*^e77=WZ)% z)H8qEK!`R}B_?*Ox-8546aDC$^pK)H))u!1zG)>pYKs!Vc(Nzf{z}w_I6<`Iq93$RUanrzJ*ZS{O$; z2*Ms2$R{4EZ`|+cJ*XhZ)1#TF=HXn_v&-tEUTC04K};k)b<{>-w9KQxv$btq-#0ou z2*9NTuxw)|FyjC=bhl#=Z4|m_fG*hb9ufBsNlznASZcUTpPgt+aUwG{&}nD)NrN(jktm# z%`lfV_4K-a9r~@R{9EwAuvksEQDCvSChv67NF)i_igfuHQgf-|ZB!~AHDNc#CRMs$ zR$N)WL3On8!B%dhTL|JRiD6HRP|mwH+%rDF&s)&PU7oogF2O-;mSv%za!vZ$YrAIn z2Vc3cs5B;KIa-#|Wlm$D1w@0Tfw~B}xqCQ#n_oA-k!AW-EN^jdqq$sjaqrpU2{z6V z_rc%+Pq2qsg_2~(RS)qkw!6B7Q?nHFY*4u9)pDZx8(y)m>Pug>U7kDx#_;OOb@yea z#0xaUCj#G9Br;T(Xw-K0PMbcF$3OoVHa(gjRs$aNn^Fq>Kt&3(Z`bMh*5oadeaxy7 zj&zd=y6g@k#vAfbVG85Z2`M+B5qKrP!QR-$j{g%xm)cn(r z`IE^}8k3t(Hk`lD<~zi4mnB2d=C%rx9Xh7)hpf}!;qIo?S@2>cSvF!`qT@-xdiw9a zq4X3QeD@UzBC?l;3&}`IPBrIg{^3L+tFM{jq(eW=+^0;rSNUpvXDh{i1v+ zgU>Hfr_A)^SeuE#T8mmJ^!v3Ih5i>{^lncB#)4c}hJPYh(&!tp?x3DAiFqtn+cGEX zSbc{>sap?{E6nvURL<4a<#pBReg#tBLf}SoyquD}Pj(HEpEXXBJyPG`R*bYsfKfCw z69oBVqSod+@&lpJ_WkscM-27Tonh3m5sOdM`yWO;-&4hBQX~8yU5;a)__NVjhe@#B zVAqgaew5{0t>dw*&CTIB9PUGEW0;e7u6_2dKWY-JZKVq5&$i*QGJ5z<-mvBjMnwO> zT)JA5=(~${j@7^hKtxGV*I48{$Ux@H$zH5)B@LsBGcOKnlJZg}W?S4i>wb9LyYL3t zL?L**)xdht7r{6ml3cGcdM-Q`S^LZ5J)>HE%T1PkHQ}8p4xa!lPS2q}Y5@I>s(h{% zgD*XkJsP02wwhk`%Gxb}V2z+C=XQnt;a+)4!qwLJd^v7FU8WSE*qr(9gENYFWxi3B zH0Ws@FWtJRBjlb{r{!4p7a?>JmyJGs2Dgd;bTy-=YFZ20cA_@;SC#G8Kj66dYsg_&hGIHf|a@dP0e>?avx3_b!!Ephc z|Ih@`u#f8fa#TA28Wm}3cUOYKZN`zVQ#&%Z>pE*VV4@doN`fl{&4f!Tbnm>~Iufxu zfYH-&N}1Gf#~d-P*<0#96|%|k*Lo2`x7xdL@4M)4O`1#92>VH$>B?=w%4oZm#3nU5 ztKh5t{#&949T$N&HET_75JwNIS!SD7KTxlMlu}$SRuQ`hU-doh1|V&k`ou)TBAW*`JOO F@PE_|nP&h1 literal 0 HcmV?d00001 diff --git a/frontends/concrete-rust/test/src/tfhers.rs b/frontends/concrete-rust/test/src/tfhers.rs new file mode 100644 index 000000000..bd970f598 --- /dev/null +++ b/frontends/concrete-rust/test/src/tfhers.rs @@ -0,0 +1,51 @@ +mod precompile { + use concrete_macro::from_concrete_python_export_zip; + from_concrete_python_export_zip!("src/test_tfhers.zip"); +} + +#[cfg(test)] +mod test { + use super::precompile; + use tfhe::prelude::{FheDecrypt, FheEncrypt}; + use tfhe::shortint::parameters::v0_10::classic::gaussian::p_fail_2_minus_64::ks_pbs::{V0_10_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, V0_10_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M64}; + use tfhe::{generate_keys, FheUint8}; + + #[test] + fn test() { + let mut secret_csprng = concrete::common::SecretCsprng::new(0u128); + let mut encryption_csprng = concrete::common::EncryptionCsprng::new(0u128); + let config = tfhe::ConfigBuilder::with_custom_parameters( + V0_10_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M64, + ); + let (client_key, _) = generate_keys(config); + let keyset = precompile::KeysetBuilder::new() + .with_key_for_my_func_0_arg(&client_key) + .generate(secret_csprng.pin_mut(), encryption_csprng.pin_mut()); + let server_keyset = keyset.get_server(); + let mut server = precompile::server::my_func::ServerFunction::new(); + let arg_0 = FheUint8::encrypt(6u8, &client_key); + let arg_1 = FheUint8::encrypt(4u8, &client_key); + let output = server.invoke(&server_keyset, arg_0, arg_1); + let decrypted: u8 = output.decrypt(&client_key); + assert_eq!(decrypted, 10); + } + + #[test] + #[should_panic] + fn test_reset_key() { + let mut secret_csprng = concrete::common::SecretCsprng::new(0u128); + let mut encryption_csprng = concrete::common::EncryptionCsprng::new(0u128); + let config1 = tfhe::ConfigBuilder::with_custom_parameters( + V0_10_PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M64, + ); + let config2 = tfhe::ConfigBuilder::with_custom_parameters( + V0_10_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + ); + let (client_key1, _) = generate_keys(config1); + let (client_key2, _) = generate_keys(config2); + precompile::KeysetBuilder::new() + .with_key_for_my_func_0_arg(&client_key1) + .with_key_for_my_func_0_arg(&client_key2) + .generate(secret_csprng.pin_mut(), encryption_csprng.pin_mut()); + } +} diff --git a/tools/concrete-protocol/src/concrete-protocol.capnp b/tools/concrete-protocol/src/concrete-protocol.capnp index f13ae3bc6..ee6837544 100644 --- a/tools/concrete-protocol/src/concrete-protocol.capnp +++ b/tools/concrete-protocol/src/concrete-protocol.capnp @@ -1,10 +1,10 @@ # Concrete Protocol -# -# The following document contains a programatic description of a communication protocol to store and -# exchange data with applications of the concrete framework. +# +# The following document contains a programatic description of a communication protocol to store and +# exchange data with applications of the concrete framework. # # Todo: -# + Use `storagePrecision` instead of `integerPrecision` to better differentiate between the +# + Use `storagePrecision` instead of `integerPrecision` to better differentiate between the # message and the storage. # + Use `storageInfo` instead of `rawInfo`. @@ -17,20 +17,20 @@ $Cxx.namespace("concreteprotocol"); ######################################################################################### Commons ## enum KeyType { - # Secret Keys can be drawn from different ranges of values, using different distributions. This + # Secret Keys can be drawn from different ranges of values, using different distributions. This # enumeration encodes the different supported ways. - + binary @0; # Uniform sampling in {0, 1} ternary @1; # Uniform sampling in {-1, 0, 1} } struct Modulus { - # Ciphertext operations are performed using modular arithmetic. Depending on the use, different - # modulus can be used for the operations. This structure encodes the different supported ways. + # Ciphertext operations are performed using modular arithmetic. Depending on the use, different + # modulus can be used for the operations. This structure encodes the different supported ways. modulus :union $Cxx.name("mod") { # The modulus expected to be used. - + native @0 :NativeModulus; powerOfTwo @1 :PowerOfTwoModulus; integer @2 :IntegerModulus; @@ -38,50 +38,50 @@ struct Modulus { } struct NativeModulus{ - # Operations are performed using the modulus of the integers used to store the ciphertexts. - # + # Operations are performed using the modulus of the integers used to store the ciphertexts. + # # Note: - # The bitwidth of the integer storage is represented implicitly here, and must be grabbed from + # The bitwidth of the integer storage is represented implicitly here, and must be grabbed from # the rest of the description. - # + # # Example: # 2^64 when the ciphertext is stored using 64 bits integers. } struct PowerOfTwoModulus{ - # Operations are performed using a modulus that is a power of two. - # - # Example: + # Operations are performed using a modulus that is a power of two. + # + # Example: # 2^n for any n between 0 and the bitwidth of the integer used to store the ciphertext. - + power @0 :UInt32; # The power used to raise 2. } struct IntegerModulus{ - # Operations are performed using a modulus that is an arbitrary integer. + # Operations are performed using a modulus that is an arbitrary integer. # # Example: - # n for any n between 0 and 2^N where N is the bitwidth of the integer used to store the + # n for any n between 0 and 2^N where N is the bitwidth of the integer used to store the # ciphertext. - + modulus @0 :UInt32; # The value used as modulus. } struct Shape{ - # Scalar and tensor values are represented by the same types. This structure contains a - # description of the shape of value. + # Scalar and tensor values are represented by the same types. This structure contains a + # description of the shape of value. # # Note: # If the dimensions vector is empty, the message is interpreted as a scalar. - + dimensions @0 :List(UInt32); # The dimensions of the value. } struct RawInfo{ - # A value exchanged at the boundary between two parties of a computation will be transmitted as a - # binary payload containing a tensor of integers. This payload will first have to be parsed to a - # tensor of proper shape, signedness and precision before being pre-processed and passed to the - # computation. This structure represents the informations needed to parse this payload into the + # A value exchanged at the boundary between two parties of a computation will be transmitted as a + # binary payload containing a tensor of integers. This payload will first have to be parsed to a + # tensor of proper shape, signedness and precision before being pre-processed and passed to the + # computation. This structure represents the informations needed to parse this payload into the # expected tensor. shape @0 :Shape; # The shape of the tensor. @@ -93,8 +93,8 @@ struct Payload{ # A structure carrying a binary payload. # # Note: - # There is a limit to the maximum size of a Data type. For this reason, large payloads must be - # split into several blobs stored sequentially in a list. All but the last blobs store the + # There is a limit to the maximum size of a Data type. For this reason, large payloads must be + # split into several blobs stored sequentially in a list. All but the last blobs store the # maximum amount of data allowed by Data, and the last store the remainder. data @0 :List(Data); # The binary data of the payload } @@ -102,12 +102,12 @@ struct Payload{ ##################################################################################### Compression ## enum Compression{ - # Evaluation keys and ciphertexts can be compressed when transported over the wire. This + # Evaluation keys and ciphertexts can be compressed when transported over the wire. This # enumeration encodes the different compressions that can be used to compress scheme objects. - # + # # Note: # Not all compressions are available for every types of evaluation keys or ciphertexts. - + none @0; # No compression is used. seed @1; # The mask is represented by the seed of a csprng. paillier @2; # An output lwe ciphertext transciphered to the paillier cryptosystem. @@ -116,22 +116,22 @@ enum Compression{ ################################################################################# LWE secret keys ## struct LweSecretKeyParams { - # A secret key is parameterized by a few quantities of cryptographic importance. This structure + # A secret key is parameterized by a few quantities of cryptographic importance. This structure # represents those parameters. - + lweDimension @0 :UInt32; # The LWE dimension, e.g. the length of the key. integerPrecision @1 :UInt32; # The bitwidth of the integers used for storage. keyType @2 :KeyType; # The kind of distribution used to sample the key. } struct LweSecretKeyInfo { - # A secret key value is uniquely described by cryptographic parameters and an identifier. This + # A secret key value is uniquely described by cryptographic parameters and an identifier. This # structure represents this description of a secret key. - # + # # Note: - # Secret keys with same parameters are allowed to co-exist in a program, as long as they + # Secret keys with same parameters are allowed to co-exist in a program, as long as they # have different ids. - + id @0 :UInt32; # The identifier of the key. params @1 :LweSecretKeyParams; # The cryptographic parameters of the keys. } @@ -139,7 +139,7 @@ struct LweSecretKeyInfo { struct LweSecretKey { # A secret key value is a payload and a description to interpret this payload. This structure # can be used to store and communicate a secret key. - + info @0 :LweSecretKeyInfo; # The description of the secret key. payload @1 :Payload; # The payload } @@ -147,12 +147,12 @@ struct LweSecretKey { ############################################################################## LWE bootstrap keys ## struct LweBootstrapKeyParams { - # A bootstrap key is parameterized by a few quantities of cryptographic importance. This structure + # A bootstrap key is parameterized by a few quantities of cryptographic importance. This structure # represents those parameters. # # Note: - # For now, only keys with the same input and output key types can be represented. - + # For now, only keys with the same input and output key types can be represented. + levelCount @0 :UInt32; # The number of levels of the ciphertexts. baseLog @1 :UInt32; # The logarithm of the base of the ciphertext. glweDimension @2 :UInt32; # The dimension of the ciphertexts. @@ -167,7 +167,7 @@ struct LweBootstrapKeyParams { struct LweBootstrapKeyInfo { # A bootstrap key value is uniquely described by cryptographic parameters and a few application # related quantities. This structure represents this description of a bootstrap key. - # + # # Note: # Bootstrap keys with same parameters, compression, input and output id, are allowed to co-exist # in a program as long as they have different ids. @@ -180,9 +180,9 @@ struct LweBootstrapKeyInfo { } struct LweBootstrapKey { - # A bootstrap key value is a payload and a description to interpret this payload. This structure + # A bootstrap key value is a payload and a description to interpret this payload. This structure # can be used to store and communicate a bootstrap key. - + info @0 :LweBootstrapKeyInfo; # The description of the bootstrap key. payload @1 :Payload; # The payload. } @@ -190,11 +190,11 @@ struct LweBootstrapKey { ############################################################################## LWE keyswitch keys ## struct LweKeyswitchKeyParams { - # A keyswitch key is parameterized by a few quantities of cryptographic importance. This structure + # A keyswitch key is parameterized by a few quantities of cryptographic importance. This structure # represents those parameters. # # Note: - # For now, only keys with the same input and output key types can be represented. + # For now, only keys with the same input and output key types can be represented. levelCount @0 :UInt32; # The number of levels of the ciphertexts. baseLog @1 :UInt32; # The logarithm of the base of ciphertexts. @@ -209,7 +209,7 @@ struct LweKeyswitchKeyParams { struct LweKeyswitchKeyInfo { # A keyswitch key value is uniquely described by cryptographic parameters and a few application # related quantities. This structure represents this description of a keyswitch key. - # + # # Note: # Keyswitch keys with same parameters, compression, input and output id, are allowed to co-exist # in a program as long as they have different ids. @@ -222,9 +222,9 @@ struct LweKeyswitchKeyInfo { } struct LweKeyswitchKey { - # A keyswitch key value is a payload and a description to interpret this payload. This structure + # A keyswitch key value is a payload and a description to interpret this payload. This structure # can be used to store and communicate a keyswitch key. - + info @0 :LweKeyswitchKeyInfo; # The description of the keyswitch key. payload @1 :Payload; # The payload. } @@ -232,11 +232,11 @@ struct LweKeyswitchKey { ########################################################################## Packing keyswitch keys ## struct PackingKeyswitchKeyParams { - # A packing keyswitch key is parameterized by a few quantities of cryptographic importance. This + # A packing keyswitch key is parameterized by a few quantities of cryptographic importance. This # structure represents those parameters. - # + # # Note: - # For now, only keys with the same input and output key types can be represented. + # For now, only keys with the same input and output key types can be represented. levelCount @0 :UInt32; # The number of levels of the ciphertexts. baseLog @1 :UInt32; # The logarithm of the base of the ciphertexts. @@ -251,12 +251,12 @@ struct PackingKeyswitchKeyParams { } struct PackingKeyswitchKeyInfo { - # A packing keyswitch key value is uniquely described by cryptographic parameters and a few - # application related quantities. This structure represents this description of a packing + # A packing keyswitch key value is uniquely described by cryptographic parameters and a few + # application related quantities. This structure represents this description of a packing # keyswitch key. - # + # # Note: - # Packing keyswitch keys with same parameters, compression, input and output id, are allowed to + # Packing keyswitch keys with same parameters, compression, input and output id, are allowed to # co-exist in a program as long as they have different ids. id @0 :UInt32; # The identifier of the packing keyswitch key. @@ -267,9 +267,9 @@ struct PackingKeyswitchKeyInfo { } struct PackingKeyswitchKey { - # A packiing keyswitch key value is a payload and a description to interpret this payload. This + # A packiing keyswitch key value is a payload and a description to interpret this payload. This # structure can be used to store and communicate a packing keyswitch key. - + info @0 :PackingKeyswitchKeyInfo; # The description of the packing keyswitch key. payload @1 :Payload; # The payload. } @@ -277,7 +277,7 @@ struct PackingKeyswitchKey { ######################################################################################### Keysets ## struct KeysetInfo { - # The keyset needed for an application can be described by an ensemble of descriptions of the + # The keyset needed for an application can be described by an ensemble of descriptions of the # different keys used in the program. This structure represents such a description. lweSecretKeys @0 :List(LweSecretKeyInfo); # The secret key descriptions. @@ -287,25 +287,25 @@ struct KeysetInfo { } struct ServerKeyset { - # A server keyset is represented by an ensemble of evaluation key values. This structure allows to + # A server keyset is represented by an ensemble of evaluation key values. This structure allows to # store and communicate such a keyset. - + lweBootstrapKeys @0 :List(LweBootstrapKey); # The bootstrap key values. lweKeyswitchKeys @1 :List(LweKeyswitchKey); # The keyswitch key values. packingKeyswitchKeys @2 :List(PackingKeyswitchKey); # The packing keyswitch key values. } struct ClientKeyset { - # A client keyset is represented by an ensemble of secret key values. This structure allows to + # A client keyset is represented by an ensemble of secret key values. This structure allows to # store and communicate such a keyset. - + lweSecretKeys @0 :List(LweSecretKey); # The secret key values. } struct Keyset { - # A complete application keyset is the union of a server keyset, and a client keyset. This + # A complete application keyset is the union of a server keyset, and a client keyset. This # structure allows to store and communicate such a keyset. - + server @0 :ServerKeyset; client @1 :ClientKeyset; } @@ -314,18 +314,18 @@ struct Keyset { struct EncodingInfo { # A value in an fhe program can encode various kind of informations, be it encrypted or not. - # To correctly communicate, the different parties participating in the execution of the program + # To correctly communicate, the different parties participating in the execution of the program # must share informations about what encoding is used for values exchanged at their boundaries. # This structure represents such informations. - # + # # Note: - # The shape field is expected to contain the _abstract_ shape of the value. This means that for - # an encrypted value, the shape must not contain informations about the shape of the - # ciphertext(s) themselves. Said differently, the shape must be the one that would be used if + # The shape field is expected to contain the _abstract_ shape of the value. This means that for + # an encrypted value, the shape must not contain informations about the shape of the + # ciphertext(s) themselves. Said differently, the shape must be the one that would be used if # the value was not encrypted. shape @0 :Shape; # The shape of the value. - encoding :union { + encoding :union { # The encoding for each scalar element of the value. integerCiphertext @1 :IntegerCiphertextEncodingInfo; @@ -336,12 +336,12 @@ struct EncodingInfo { } struct IntegerCiphertextEncodingInfo { - # A ciphertext can be used to represent an integer value. This structure represents the + # A ciphertext can be used to represent an integer value. This structure represents the # informations needed to encode such an integer. width @0 :UInt32; # The bitwidth of the encoded integer. isSigned @1 :Bool; # The signedness of the encoded integer. - mode :union { + mode :union { # The mode used to encode the integer. native @2 :NativeMode; @@ -350,39 +350,39 @@ struct IntegerCiphertextEncodingInfo { } struct NativeMode { - # An integer of width from 1 to 8 bits can be encoded in a single ciphertext natively, by - # being shifted in the most significant bits. This structure represents this integer encoding + # An integer of width from 1 to 8 bits can be encoded in a single ciphertext natively, by + # being shifted in the most significant bits. This structure represents this integer encoding # mode. - } + } struct ChunkedMode { - # An integer of width from 1 to n can be encoded in a set of ciphertexts by chunking the bits + # An integer of width from 1 to n can be encoded in a set of ciphertexts by chunking the bits # of the original integer. This structure represents this integer encoding mode. - + size @0 :UInt32; # The number of chunks to be used. width @1 :UInt32; # The number of bits encoded by each chunks. } struct CrtMode { - # An integer of width 1 to 16 can be encoded in a set of ciphertexts, by decomposing a value + # An integer of width 1 to 16 can be encoded in a set of ciphertexts, by decomposing a value # using a set of pairwise coprimes. This structure represents this integer encoding mode. - + moduli @0 :List(UInt32); # The coprimes used to decompose the original value. } } struct BooleanCiphertextEncodingInfo { - # A ciphertext can be used to represent a boolean value. This structure represents such an + # A ciphertext can be used to represent a boolean value. This structure represents such an # encoding. } struct PlaintextEncodingInfo { - # A cleartext value can be used to represent a plaintext value used in computation with + # A cleartext value can be used to represent a plaintext value used in computation with # ciphertexts. This structure represent such an encoding. } struct IndexEncodingInfo { - # A cleartext value can be used to represent an index value used to index in a tensor of values. + # A cleartext value can be used to represent an index value used to index in a tensor of values. # This structure represent such an encoding. } @@ -391,25 +391,25 @@ struct CircuitEncodingInfo { # name. This structure represents this circuit encoding signature. # # Note: - # The order of the input and output lists matters. The order of values should be the same when + # The order of the input and output lists matters. The order of values should be the same when # executing the circuit. Also, the name is expected to be unique in the program. - + inputs @0 :List(EncodingInfo); # The ordered list of input encoding infos. outputs @1 :List(EncodingInfo); # The ordered list of output encoding infos. name @2 :Text; # The name of the circuit. } struct ProgramEncodingInfo { - # A program encodings is described by the set of circuit encodings. This structure represents + # A program encodings is described by the set of circuit encodings. This structure represents # this ensemble of encoding signatures. - + circuits @0 :List(CircuitEncodingInfo); # The list of the circuit encoding infos. } ###################################################################################### Encryption ## struct LweCiphertextEncryptionInfo { - # The encryption of a cleartext value requires some parameters to operate. This structure + # The encryption of a cleartext value requires some parameters to operate. This structure # represents those parameters. keyId @0 :UInt32; # The identifier of the secret key used to perform the encryption. @@ -435,9 +435,9 @@ struct LweCiphertextTypeInfo { # needed to verify and pre-or-post process this value. # # Note: - # Two shape information are carried in this type. The abstract shape is the shape the tensor - # would have if the values were cleartext. That is, it does not take into account the encryption - # process. The concrete shape is the final shape of the object accounting for the encryption, + # Two shape information are carried in this type. The abstract shape is the shape the tensor + # would have if the values were cleartext. That is, it does not take into account the encryption + # process. The concrete shape is the final shape of the object accounting for the encryption, # that usually add one or more dimension to the object. abstractShape @0 :Shape; # The abstract shape of the value. @@ -448,13 +448,13 @@ struct LweCiphertextTypeInfo { encoding :union { # The encoding of the value stored inside the ciphertext. - integer @5 :IntegerCiphertextEncodingInfo; + integer @5 :IntegerCiphertextEncodingInfo; boolean @6 :BooleanCiphertextEncodingInfo; } } struct PlaintextTypeInfo { - # A plaintext value can flow in and out of a circuit. This structure represents the informations + # A plaintext value can flow in and out of a circuit. This structure represents the informations # needed to verify and pre-or-post process this value. shape @0 :Shape; # The shape of the value. @@ -470,23 +470,23 @@ struct IndexTypeInfo { integerPrecision @1 :UInt32; # The precision of the indexes. isSigned @2 :Bool; # The signedness of the indexes. } - + ############################################################################### Circuit signature ## struct GateInfo { - # A value flowing in or out of a circuit is expected to be of a given type, according to the + # A value flowing in or out of a circuit is expected to be of a given type, according to the # signature of this circuit. This structure represents such a type in a circuit signature. - + rawInfo @0 :RawInfo; # The raw information that raw data must be possible to parse with. typeInfo @1 :TypeInfo; # The type of the value expected at the gate. } struct CircuitInfo { - # A circuit signature can be described completely by the type informations for its input and + # A circuit signature can be described completely by the type informations for its input and # outputs, as well as its name. This structure regroup those informations. - # + # # Note: - # The order of the input and output lists matters. The order of values should be the same when + # The order of the input and output lists matters. The order of values should be the same when # executing the circuit. Also, the name is expected to be unique in the program. inputs @0 :List(GateInfo); # The ordered list of input types. @@ -497,7 +497,7 @@ struct CircuitInfo { struct ProgramInfo { # A complete program can be described by the ensemble of circuit signatures, and the description # of the keyset that go with it. This structure regroup those informations. - + keyset @0 :KeysetInfo; # The informations on the keyset of the program. circuits @1 :List(CircuitInfo); # The informations for the different circuits of the program. } @@ -506,14 +506,14 @@ struct ProgramInfo { struct Value { # A value is the union of a binary payload, raw informations to turn this payload into an integer - # tensor, and typ informations to check and pre-post process values at the boundary of a - # circuit. This structure can be used to store, or communicate a value used during a program + # tensor, and typ informations to check and pre-post process values at the boundary of a + # circuit. This structure can be used to store, or communicate a value used during a program # execution. - # + # # Note: - # The value info is a smaller runtime equivalent of the gate types used in the circuit + # The value info is a smaller runtime equivalent of the gate types used in the circuit # signatures. - + payload @0 :Payload; # The binary payload containing a raw integer tensor. rawInfo @1 :RawInfo; # The informations to parse the binary payload. typeInfo @2 :TypeInfo; # The type of the value. @@ -522,11 +522,9 @@ struct Value { ################################################################################### Public values ## struct PublicArguments { - args @0 :List(Value); + args @0 :List(Value); } struct PublicResults { results @0 :List(Value); } - -