From ee006729963cb8dce023e943f87f41f90a9fa875 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Tue, 24 Jan 2023 15:49:44 +0100 Subject: [PATCH] feat(rust): partially bind CircuitGate, EncryptionGate This adds partial bindings of the CircuitGate, EncrytionGate and Encoding types for the rust frontend --- .../concretelang-c/Support/CompilerEngine.h | 46 ++++++++ compiler/include/concretelang/CAPI/Wrappers.h | 6 + compiler/lib/Bindings/Rust/src/compiler.rs | 108 ++++++++++++++++++ compiler/lib/CAPI/Support/CompilerEngine.cpp | 51 +++++++++ 4 files changed, 211 insertions(+) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 696244e0b..77df44681 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -37,6 +37,9 @@ DEFINE_C_API_STRUCT(LibrarySupport, void); DEFINE_C_API_STRUCT(CompilationOptions, void); DEFINE_C_API_STRUCT(OptimizerConfig, void); DEFINE_C_API_STRUCT(ServerLambda, void); +DEFINE_C_API_STRUCT(Encoding, void); +DEFINE_C_API_STRUCT(EncryptionGate, void); +DEFINE_C_API_STRUCT(CircuitGate, void); DEFINE_C_API_STRUCT(ClientParameters, void); DEFINE_C_API_STRUCT(KeySet, void); DEFINE_C_API_STRUCT(KeySetCache, void); @@ -63,6 +66,9 @@ DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport) DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions) DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig) DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda) +DEFINE_NULL_PTR_CHECKER(circuitGateIsNull, CircuitGate) +DEFINE_NULL_PTR_CHECKER(encodingIsNull, Encoding) +DEFINE_NULL_PTR_CHECKER(encryptionGateIsNull, EncryptionGate) DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters) DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet) DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache) @@ -240,6 +246,46 @@ clientParametersCopy(ClientParameters params); MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params); +/// Returns the number of output circuit gates +MLIR_CAPI_EXPORTED size_t clientParametersOutputsSize(ClientParameters params); + +/// Returns the number of input circuit gates +MLIR_CAPI_EXPORTED size_t clientParametersInputsSize(ClientParameters params); + +/// Returns the output circuit gate corresponding to the index +/// +/// - `index` must be valid. +MLIR_CAPI_EXPORTED CircuitGate +clientParametersOutputCircuitGate(ClientParameters params, size_t index); + +/// Returns the input circuit gate corresponding to the index +/// +/// - `index` must be valid. +MLIR_CAPI_EXPORTED CircuitGate +clientParametersInputCircuitGate(ClientParameters params, size_t index); + +/// Returns the EncryptionGate of the circuit gate. +/// +/// - The returned gate will be null if the gate does not represent encrypted +/// data +MLIR_CAPI_EXPORTED EncryptionGate +circuitGateEncryptionGate(CircuitGate circuit_gate); + +/// Returns the variance of the encryption gate +MLIR_CAPI_EXPORTED double +encryptionGateVariance(EncryptionGate encryption_gate); + +/// Returns the Encoding of the encryption gate. +MLIR_CAPI_EXPORTED Encoding +encryptionGateEncoding(EncryptionGate encryption_gate); + +/// Returns the precision (bit width) of the encoding +MLIR_CAPI_EXPORTED uint64_t encodingPrecision(Encoding encoding); + +MLIR_CAPI_EXPORTED void circuitGateDestroy(CircuitGate gate); +MLIR_CAPI_EXPORTED void encryptionGateDestroy(EncryptionGate gate); +MLIR_CAPI_EXPORTED void encodingDestroy(Encoding encoding); + /// ********** KeySet CAPI ***************************************************** MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params, diff --git a/compiler/include/concretelang/CAPI/Wrappers.h b/compiler/include/concretelang/CAPI/Wrappers.h index 2506209d9..9c637ac39 100644 --- a/compiler/include/concretelang/CAPI/Wrappers.h +++ b/compiler/include/concretelang/CAPI/Wrappers.h @@ -58,6 +58,12 @@ DEFINE_C_API_PTR_METHODS_WITH_ERROR(PublicResult, mlir::concretelang::clientlib::PublicResult) DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationFeedback, mlir::concretelang::CompilationFeedback) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(Encoding, + mlir::concretelang::clientlib::Encoding) +DEFINE_C_API_PTR_METHODS_WITH_ERROR( + EncryptionGate, mlir::concretelang::clientlib::EncryptionGate) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(CircuitGate, + mlir::concretelang::clientlib::CircuitGate) #undef DEFINE_C_API_PTR_METHODS_WITH_ERROR diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 991dcc286..6afa414fb 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -49,6 +49,9 @@ impl_CStructErrorMsg! {[ ffi::LibraryCompilationResult, ffi::LibrarySupport, ffi::ServerLambda, + ffi::CircuitGate, + ffi::EncryptionGate, + ffi::Encoding, ffi::ClientParameters, ffi::KeySet, ffi::KeySetCache, @@ -219,6 +222,18 @@ def_CStructWrapper! { serverLambdaIsNull, serverLambdaDestroy, }, + CircuitGate => { + circuitGateIsNull, + circuitGateDestroy, + }, + EncryptionGate => { + encryptionGateIsNull, + encryptionGateDestroy, + }, + Encoding => { + encodingIsNull, + encodingDestroy, + }, ClientParameters => { clientParametersIsNull, clientParametersDestroy, @@ -578,7 +593,71 @@ impl LibrarySupport { impl ServerLambda {} +impl CircuitGate { + pub fn encryption_gate(self) -> Option { + let inner = unsafe { ffi::circuitGateEncryptionGate(self._c) }; + let gate = EncryptionGate::wrap(inner); + if gate.is_null() { + None + } else { + Some(gate) + } + } +} + +impl EncryptionGate { + pub fn encoding(self) -> Encoding { + let inner = unsafe { ffi::encryptionGateEncoding(self._c) }; + + Encoding::wrap(inner) + } + + pub fn variance(&self) -> f64 { + unsafe { ffi::encryptionGateVariance(self._c) } + } +} + +impl Encoding { + pub fn precision(&self) -> u64 { + unsafe { ffi::encodingPrecision(self._c) } + } +} + impl ClientParameters { + pub fn num_inputs(&self) -> usize { + unsafe { ffi::clientParametersInputsSize(self._c) } + .try_into() + .unwrap() + } + + pub fn input(&self, index: usize) -> Option { + if index >= self.num_inputs() { + None + } else { + let gate = unsafe { + ffi::clientParametersInputCircuitGate(self._c, index.try_into().unwrap()) + }; + Some(CircuitGate::wrap(gate)) + } + } + + pub fn num_outputs(&self) -> usize { + unsafe { ffi::clientParametersOutputsSize(self._c) } + .try_into() + .unwrap() + } + + pub fn output(&self, index: usize) -> Option { + if index >= self.num_outputs() { + None + } else { + let gate = unsafe { + ffi::clientParametersOutputCircuitGate(self._c, index.try_into().unwrap()) + }; + Some(CircuitGate::wrap(gate)) + } + } + pub fn serialize(&self) -> Result, CompilerError> { unsafe { let serialized_ref = BufferRef::wrap(ffi::clientParametersSerialize(self._c)); @@ -1116,6 +1195,35 @@ mod test { assert!(!server.is_null()); let client_params = support.load_client_parameters(&result).unwrap(); assert!(!client_params.is_null()); + + assert_eq!(client_params.num_inputs(), 2); + let input_bitwidth_0 = client_params + .input(0) + .unwrap() + .encryption_gate() + .unwrap() + .encoding() + .precision(); + let input_bitwidth_1 = client_params + .input(1) + .unwrap() + .encryption_gate() + .unwrap() + .encoding() + .precision(); + + assert_eq!(input_bitwidth_0, 5); + assert_eq!(input_bitwidth_1, 5); + + assert_eq!(client_params.num_outputs(), 1); + let output_bitwidth = client_params + .output(0) + .unwrap() + .encryption_gate() + .unwrap() + .encoding() + .precision(); + assert_eq!(output_bitwidth, 5); } #[test] diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index c90840328..cd59eee8e 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -341,6 +341,57 @@ ClientParameters clientParametersCopy(ClientParameters params) { void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)} +size_t clientParametersOutputsSize(ClientParameters params) { + return unwrap(params)->outputs.size(); +} + +size_t clientParametersInputsSize(ClientParameters params) { + return unwrap(params)->inputs.size(); +} + +CircuitGate clientParametersOutputCircuitGate(ClientParameters params, + size_t index) { + auto &cppGate = unwrap(params)->outputs[index]; + auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate); + return wrap(cppGateCopy); +} + +CircuitGate clientParametersInputCircuitGate(ClientParameters params, + size_t index) { + auto &cppGate = unwrap(params)->inputs[index]; + auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate); + return wrap(cppGateCopy); +} + +EncryptionGate circuitGateEncryptionGate(CircuitGate circuit_gate) { + auto &maybe_gate = unwrap(circuit_gate)->encryption; + + if (maybe_gate) { + auto *copy = new mlir::concretelang::clientlib::EncryptionGate(*maybe_gate); + return wrap(copy); + } + return (static_cast(wrap))(nullptr); +} + +double encryptionGateVariance(EncryptionGate encryption_gate) { + return unwrap(encryption_gate)->variance; +} + +Encoding encryptionGateEncoding(EncryptionGate encryption_gate) { + auto &cppEncoding = unwrap(encryption_gate)->encoding; + auto *copy = new mlir::concretelang::clientlib::Encoding(cppEncoding); + return wrap(copy); +} + +uint64_t encodingPrecision(Encoding encoding) { + return unwrap(encoding)->precision; +} + +void circuitGateDestroy(CircuitGate gate) { C_STRUCT_CLEANER(gate) } +void encryptionGateDestroy(EncryptionGate gate) { C_STRUCT_CLEANER(gate) } +void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)} + /// ********** KeySet CAPI ***************************************************** KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,