feat(rust): partially bind CircuitGate, EncryptionGate

This adds partial bindings of the CircuitGate, EncrytionGate
and Encoding types for the rust frontend
This commit is contained in:
tmontaigu
2023-01-24 15:49:44 +01:00
parent 8e8651c6a6
commit ee00672996
4 changed files with 211 additions and 0 deletions

View File

@@ -37,6 +37,9 @@ DEFINE_C_API_STRUCT(LibrarySupport, void);
DEFINE_C_API_STRUCT(CompilationOptions, void);
DEFINE_C_API_STRUCT(OptimizerConfig, void);
DEFINE_C_API_STRUCT(ServerLambda, void);
DEFINE_C_API_STRUCT(Encoding, void);
DEFINE_C_API_STRUCT(EncryptionGate, void);
DEFINE_C_API_STRUCT(CircuitGate, void);
DEFINE_C_API_STRUCT(ClientParameters, void);
DEFINE_C_API_STRUCT(KeySet, void);
DEFINE_C_API_STRUCT(KeySetCache, void);
@@ -63,6 +66,9 @@ DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport)
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions)
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig)
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda)
DEFINE_NULL_PTR_CHECKER(circuitGateIsNull, CircuitGate)
DEFINE_NULL_PTR_CHECKER(encodingIsNull, Encoding)
DEFINE_NULL_PTR_CHECKER(encryptionGateIsNull, EncryptionGate)
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters)
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet)
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache)
@@ -240,6 +246,46 @@ clientParametersCopy(ClientParameters params);
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
/// Returns the number of output circuit gates
MLIR_CAPI_EXPORTED size_t clientParametersOutputsSize(ClientParameters params);
/// Returns the number of input circuit gates
MLIR_CAPI_EXPORTED size_t clientParametersInputsSize(ClientParameters params);
/// Returns the output circuit gate corresponding to the index
///
/// - `index` must be valid.
MLIR_CAPI_EXPORTED CircuitGate
clientParametersOutputCircuitGate(ClientParameters params, size_t index);
/// Returns the input circuit gate corresponding to the index
///
/// - `index` must be valid.
MLIR_CAPI_EXPORTED CircuitGate
clientParametersInputCircuitGate(ClientParameters params, size_t index);
/// Returns the EncryptionGate of the circuit gate.
///
/// - The returned gate will be null if the gate does not represent encrypted
/// data
MLIR_CAPI_EXPORTED EncryptionGate
circuitGateEncryptionGate(CircuitGate circuit_gate);
/// Returns the variance of the encryption gate
MLIR_CAPI_EXPORTED double
encryptionGateVariance(EncryptionGate encryption_gate);
/// Returns the Encoding of the encryption gate.
MLIR_CAPI_EXPORTED Encoding
encryptionGateEncoding(EncryptionGate encryption_gate);
/// Returns the precision (bit width) of the encoding
MLIR_CAPI_EXPORTED uint64_t encodingPrecision(Encoding encoding);
MLIR_CAPI_EXPORTED void circuitGateDestroy(CircuitGate gate);
MLIR_CAPI_EXPORTED void encryptionGateDestroy(EncryptionGate gate);
MLIR_CAPI_EXPORTED void encodingDestroy(Encoding encoding);
/// ********** KeySet CAPI *****************************************************
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,

View File

@@ -58,6 +58,12 @@ DEFINE_C_API_PTR_METHODS_WITH_ERROR(PublicResult,
mlir::concretelang::clientlib::PublicResult)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationFeedback,
mlir::concretelang::CompilationFeedback)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(Encoding,
mlir::concretelang::clientlib::Encoding)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
EncryptionGate, mlir::concretelang::clientlib::EncryptionGate)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CircuitGate,
mlir::concretelang::clientlib::CircuitGate)
#undef DEFINE_C_API_PTR_METHODS_WITH_ERROR

View File

@@ -49,6 +49,9 @@ impl_CStructErrorMsg! {[
ffi::LibraryCompilationResult,
ffi::LibrarySupport,
ffi::ServerLambda,
ffi::CircuitGate,
ffi::EncryptionGate,
ffi::Encoding,
ffi::ClientParameters,
ffi::KeySet,
ffi::KeySetCache,
@@ -219,6 +222,18 @@ def_CStructWrapper! {
serverLambdaIsNull,
serverLambdaDestroy,
},
CircuitGate => {
circuitGateIsNull,
circuitGateDestroy,
},
EncryptionGate => {
encryptionGateIsNull,
encryptionGateDestroy,
},
Encoding => {
encodingIsNull,
encodingDestroy,
},
ClientParameters => {
clientParametersIsNull,
clientParametersDestroy,
@@ -578,7 +593,71 @@ impl LibrarySupport {
impl ServerLambda {}
impl CircuitGate {
pub fn encryption_gate(self) -> Option<EncryptionGate> {
let inner = unsafe { ffi::circuitGateEncryptionGate(self._c) };
let gate = EncryptionGate::wrap(inner);
if gate.is_null() {
None
} else {
Some(gate)
}
}
}
impl EncryptionGate {
pub fn encoding(self) -> Encoding {
let inner = unsafe { ffi::encryptionGateEncoding(self._c) };
Encoding::wrap(inner)
}
pub fn variance(&self) -> f64 {
unsafe { ffi::encryptionGateVariance(self._c) }
}
}
impl Encoding {
pub fn precision(&self) -> u64 {
unsafe { ffi::encodingPrecision(self._c) }
}
}
impl ClientParameters {
pub fn num_inputs(&self) -> usize {
unsafe { ffi::clientParametersInputsSize(self._c) }
.try_into()
.unwrap()
}
pub fn input(&self, index: usize) -> Option<CircuitGate> {
if index >= self.num_inputs() {
None
} else {
let gate = unsafe {
ffi::clientParametersInputCircuitGate(self._c, index.try_into().unwrap())
};
Some(CircuitGate::wrap(gate))
}
}
pub fn num_outputs(&self) -> usize {
unsafe { ffi::clientParametersOutputsSize(self._c) }
.try_into()
.unwrap()
}
pub fn output(&self, index: usize) -> Option<CircuitGate> {
if index >= self.num_outputs() {
None
} else {
let gate = unsafe {
ffi::clientParametersOutputCircuitGate(self._c, index.try_into().unwrap())
};
Some(CircuitGate::wrap(gate))
}
}
pub fn serialize(&self) -> Result<Vec<c_char>, CompilerError> {
unsafe {
let serialized_ref = BufferRef::wrap(ffi::clientParametersSerialize(self._c));
@@ -1116,6 +1195,35 @@ mod test {
assert!(!server.is_null());
let client_params = support.load_client_parameters(&result).unwrap();
assert!(!client_params.is_null());
assert_eq!(client_params.num_inputs(), 2);
let input_bitwidth_0 = client_params
.input(0)
.unwrap()
.encryption_gate()
.unwrap()
.encoding()
.precision();
let input_bitwidth_1 = client_params
.input(1)
.unwrap()
.encryption_gate()
.unwrap()
.encoding()
.precision();
assert_eq!(input_bitwidth_0, 5);
assert_eq!(input_bitwidth_1, 5);
assert_eq!(client_params.num_outputs(), 1);
let output_bitwidth = client_params
.output(0)
.unwrap()
.encryption_gate()
.unwrap()
.encoding()
.precision();
assert_eq!(output_bitwidth, 5);
}
#[test]

View File

@@ -341,6 +341,57 @@ ClientParameters clientParametersCopy(ClientParameters params) {
void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)}
size_t clientParametersOutputsSize(ClientParameters params) {
return unwrap(params)->outputs.size();
}
size_t clientParametersInputsSize(ClientParameters params) {
return unwrap(params)->inputs.size();
}
CircuitGate clientParametersOutputCircuitGate(ClientParameters params,
size_t index) {
auto &cppGate = unwrap(params)->outputs[index];
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
return wrap(cppGateCopy);
}
CircuitGate clientParametersInputCircuitGate(ClientParameters params,
size_t index) {
auto &cppGate = unwrap(params)->inputs[index];
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
return wrap(cppGateCopy);
}
EncryptionGate circuitGateEncryptionGate(CircuitGate circuit_gate) {
auto &maybe_gate = unwrap(circuit_gate)->encryption;
if (maybe_gate) {
auto *copy = new mlir::concretelang::clientlib::EncryptionGate(*maybe_gate);
return wrap(copy);
}
return (static_cast<EncryptionGate (*)(
mlir::concretelang::clientlib::EncryptionGate *)>(wrap))(nullptr);
}
double encryptionGateVariance(EncryptionGate encryption_gate) {
return unwrap(encryption_gate)->variance;
}
Encoding encryptionGateEncoding(EncryptionGate encryption_gate) {
auto &cppEncoding = unwrap(encryption_gate)->encoding;
auto *copy = new mlir::concretelang::clientlib::Encoding(cppEncoding);
return wrap(copy);
}
uint64_t encodingPrecision(Encoding encoding) {
return unwrap(encoding)->precision;
}
void circuitGateDestroy(CircuitGate gate) { C_STRUCT_CLEANER(gate) }
void encryptionGateDestroy(EncryptionGate gate) { C_STRUCT_CLEANER(gate) }
void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)}
/// ********** KeySet CAPI *****************************************************
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,