mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(rust): partially bind CircuitGate, EncryptionGate
This adds partial bindings of the CircuitGate, EncrytionGate and Encoding types for the rust frontend
This commit is contained in:
@@ -37,6 +37,9 @@ DEFINE_C_API_STRUCT(LibrarySupport, void);
|
||||
DEFINE_C_API_STRUCT(CompilationOptions, void);
|
||||
DEFINE_C_API_STRUCT(OptimizerConfig, void);
|
||||
DEFINE_C_API_STRUCT(ServerLambda, void);
|
||||
DEFINE_C_API_STRUCT(Encoding, void);
|
||||
DEFINE_C_API_STRUCT(EncryptionGate, void);
|
||||
DEFINE_C_API_STRUCT(CircuitGate, void);
|
||||
DEFINE_C_API_STRUCT(ClientParameters, void);
|
||||
DEFINE_C_API_STRUCT(KeySet, void);
|
||||
DEFINE_C_API_STRUCT(KeySetCache, void);
|
||||
@@ -63,6 +66,9 @@ DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions)
|
||||
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig)
|
||||
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda)
|
||||
DEFINE_NULL_PTR_CHECKER(circuitGateIsNull, CircuitGate)
|
||||
DEFINE_NULL_PTR_CHECKER(encodingIsNull, Encoding)
|
||||
DEFINE_NULL_PTR_CHECKER(encryptionGateIsNull, EncryptionGate)
|
||||
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters)
|
||||
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet)
|
||||
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache)
|
||||
@@ -240,6 +246,46 @@ clientParametersCopy(ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
|
||||
|
||||
/// Returns the number of output circuit gates
|
||||
MLIR_CAPI_EXPORTED size_t clientParametersOutputsSize(ClientParameters params);
|
||||
|
||||
/// Returns the number of input circuit gates
|
||||
MLIR_CAPI_EXPORTED size_t clientParametersInputsSize(ClientParameters params);
|
||||
|
||||
/// Returns the output circuit gate corresponding to the index
|
||||
///
|
||||
/// - `index` must be valid.
|
||||
MLIR_CAPI_EXPORTED CircuitGate
|
||||
clientParametersOutputCircuitGate(ClientParameters params, size_t index);
|
||||
|
||||
/// Returns the input circuit gate corresponding to the index
|
||||
///
|
||||
/// - `index` must be valid.
|
||||
MLIR_CAPI_EXPORTED CircuitGate
|
||||
clientParametersInputCircuitGate(ClientParameters params, size_t index);
|
||||
|
||||
/// Returns the EncryptionGate of the circuit gate.
|
||||
///
|
||||
/// - The returned gate will be null if the gate does not represent encrypted
|
||||
/// data
|
||||
MLIR_CAPI_EXPORTED EncryptionGate
|
||||
circuitGateEncryptionGate(CircuitGate circuit_gate);
|
||||
|
||||
/// Returns the variance of the encryption gate
|
||||
MLIR_CAPI_EXPORTED double
|
||||
encryptionGateVariance(EncryptionGate encryption_gate);
|
||||
|
||||
/// Returns the Encoding of the encryption gate.
|
||||
MLIR_CAPI_EXPORTED Encoding
|
||||
encryptionGateEncoding(EncryptionGate encryption_gate);
|
||||
|
||||
/// Returns the precision (bit width) of the encoding
|
||||
MLIR_CAPI_EXPORTED uint64_t encodingPrecision(Encoding encoding);
|
||||
|
||||
MLIR_CAPI_EXPORTED void circuitGateDestroy(CircuitGate gate);
|
||||
MLIR_CAPI_EXPORTED void encryptionGateDestroy(EncryptionGate gate);
|
||||
MLIR_CAPI_EXPORTED void encodingDestroy(Encoding encoding);
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,
|
||||
|
||||
@@ -58,6 +58,12 @@ DEFINE_C_API_PTR_METHODS_WITH_ERROR(PublicResult,
|
||||
mlir::concretelang::clientlib::PublicResult)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationFeedback,
|
||||
mlir::concretelang::CompilationFeedback)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(Encoding,
|
||||
mlir::concretelang::clientlib::Encoding)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
|
||||
EncryptionGate, mlir::concretelang::clientlib::EncryptionGate)
|
||||
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CircuitGate,
|
||||
mlir::concretelang::clientlib::CircuitGate)
|
||||
|
||||
#undef DEFINE_C_API_PTR_METHODS_WITH_ERROR
|
||||
|
||||
|
||||
@@ -49,6 +49,9 @@ impl_CStructErrorMsg! {[
|
||||
ffi::LibraryCompilationResult,
|
||||
ffi::LibrarySupport,
|
||||
ffi::ServerLambda,
|
||||
ffi::CircuitGate,
|
||||
ffi::EncryptionGate,
|
||||
ffi::Encoding,
|
||||
ffi::ClientParameters,
|
||||
ffi::KeySet,
|
||||
ffi::KeySetCache,
|
||||
@@ -219,6 +222,18 @@ def_CStructWrapper! {
|
||||
serverLambdaIsNull,
|
||||
serverLambdaDestroy,
|
||||
},
|
||||
CircuitGate => {
|
||||
circuitGateIsNull,
|
||||
circuitGateDestroy,
|
||||
},
|
||||
EncryptionGate => {
|
||||
encryptionGateIsNull,
|
||||
encryptionGateDestroy,
|
||||
},
|
||||
Encoding => {
|
||||
encodingIsNull,
|
||||
encodingDestroy,
|
||||
},
|
||||
ClientParameters => {
|
||||
clientParametersIsNull,
|
||||
clientParametersDestroy,
|
||||
@@ -578,7 +593,71 @@ impl LibrarySupport {
|
||||
|
||||
impl ServerLambda {}
|
||||
|
||||
impl CircuitGate {
|
||||
pub fn encryption_gate(self) -> Option<EncryptionGate> {
|
||||
let inner = unsafe { ffi::circuitGateEncryptionGate(self._c) };
|
||||
let gate = EncryptionGate::wrap(inner);
|
||||
if gate.is_null() {
|
||||
None
|
||||
} else {
|
||||
Some(gate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptionGate {
|
||||
pub fn encoding(self) -> Encoding {
|
||||
let inner = unsafe { ffi::encryptionGateEncoding(self._c) };
|
||||
|
||||
Encoding::wrap(inner)
|
||||
}
|
||||
|
||||
pub fn variance(&self) -> f64 {
|
||||
unsafe { ffi::encryptionGateVariance(self._c) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoding {
|
||||
pub fn precision(&self) -> u64 {
|
||||
unsafe { ffi::encodingPrecision(self._c) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientParameters {
|
||||
pub fn num_inputs(&self) -> usize {
|
||||
unsafe { ffi::clientParametersInputsSize(self._c) }
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn input(&self, index: usize) -> Option<CircuitGate> {
|
||||
if index >= self.num_inputs() {
|
||||
None
|
||||
} else {
|
||||
let gate = unsafe {
|
||||
ffi::clientParametersInputCircuitGate(self._c, index.try_into().unwrap())
|
||||
};
|
||||
Some(CircuitGate::wrap(gate))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_outputs(&self) -> usize {
|
||||
unsafe { ffi::clientParametersOutputsSize(self._c) }
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn output(&self, index: usize) -> Option<CircuitGate> {
|
||||
if index >= self.num_outputs() {
|
||||
None
|
||||
} else {
|
||||
let gate = unsafe {
|
||||
ffi::clientParametersOutputCircuitGate(self._c, index.try_into().unwrap())
|
||||
};
|
||||
Some(CircuitGate::wrap(gate))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize(&self) -> Result<Vec<c_char>, CompilerError> {
|
||||
unsafe {
|
||||
let serialized_ref = BufferRef::wrap(ffi::clientParametersSerialize(self._c));
|
||||
@@ -1116,6 +1195,35 @@ mod test {
|
||||
assert!(!server.is_null());
|
||||
let client_params = support.load_client_parameters(&result).unwrap();
|
||||
assert!(!client_params.is_null());
|
||||
|
||||
assert_eq!(client_params.num_inputs(), 2);
|
||||
let input_bitwidth_0 = client_params
|
||||
.input(0)
|
||||
.unwrap()
|
||||
.encryption_gate()
|
||||
.unwrap()
|
||||
.encoding()
|
||||
.precision();
|
||||
let input_bitwidth_1 = client_params
|
||||
.input(1)
|
||||
.unwrap()
|
||||
.encryption_gate()
|
||||
.unwrap()
|
||||
.encoding()
|
||||
.precision();
|
||||
|
||||
assert_eq!(input_bitwidth_0, 5);
|
||||
assert_eq!(input_bitwidth_1, 5);
|
||||
|
||||
assert_eq!(client_params.num_outputs(), 1);
|
||||
let output_bitwidth = client_params
|
||||
.output(0)
|
||||
.unwrap()
|
||||
.encryption_gate()
|
||||
.unwrap()
|
||||
.encoding()
|
||||
.precision();
|
||||
assert_eq!(output_bitwidth, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -341,6 +341,57 @@ ClientParameters clientParametersCopy(ClientParameters params) {
|
||||
|
||||
void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)}
|
||||
|
||||
size_t clientParametersOutputsSize(ClientParameters params) {
|
||||
return unwrap(params)->outputs.size();
|
||||
}
|
||||
|
||||
size_t clientParametersInputsSize(ClientParameters params) {
|
||||
return unwrap(params)->inputs.size();
|
||||
}
|
||||
|
||||
CircuitGate clientParametersOutputCircuitGate(ClientParameters params,
|
||||
size_t index) {
|
||||
auto &cppGate = unwrap(params)->outputs[index];
|
||||
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
|
||||
return wrap(cppGateCopy);
|
||||
}
|
||||
|
||||
CircuitGate clientParametersInputCircuitGate(ClientParameters params,
|
||||
size_t index) {
|
||||
auto &cppGate = unwrap(params)->inputs[index];
|
||||
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
|
||||
return wrap(cppGateCopy);
|
||||
}
|
||||
|
||||
EncryptionGate circuitGateEncryptionGate(CircuitGate circuit_gate) {
|
||||
auto &maybe_gate = unwrap(circuit_gate)->encryption;
|
||||
|
||||
if (maybe_gate) {
|
||||
auto *copy = new mlir::concretelang::clientlib::EncryptionGate(*maybe_gate);
|
||||
return wrap(copy);
|
||||
}
|
||||
return (static_cast<EncryptionGate (*)(
|
||||
mlir::concretelang::clientlib::EncryptionGate *)>(wrap))(nullptr);
|
||||
}
|
||||
|
||||
double encryptionGateVariance(EncryptionGate encryption_gate) {
|
||||
return unwrap(encryption_gate)->variance;
|
||||
}
|
||||
|
||||
Encoding encryptionGateEncoding(EncryptionGate encryption_gate) {
|
||||
auto &cppEncoding = unwrap(encryption_gate)->encoding;
|
||||
auto *copy = new mlir::concretelang::clientlib::Encoding(cppEncoding);
|
||||
return wrap(copy);
|
||||
}
|
||||
|
||||
uint64_t encodingPrecision(Encoding encoding) {
|
||||
return unwrap(encoding)->precision;
|
||||
}
|
||||
|
||||
void circuitGateDestroy(CircuitGate gate) { C_STRUCT_CLEANER(gate) }
|
||||
void encryptionGateDestroy(EncryptionGate gate) { C_STRUCT_CLEANER(gate) }
|
||||
void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)}
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
|
||||
|
||||
Reference in New Issue
Block a user