mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
we want to wrap CStructs in Rust structs to own them, and free memory when they are no longer used. Users won't have to deal with the directly bound C API, only with the new wrappers.
1211 lines
41 KiB
Rust
1211 lines
41 KiB
Rust
//! Compiler module
|
|
|
|
use crate::mlir::ffi;
|
|
use std::os::raw::c_char;
|
|
use std::{ffi::CStr, path::Path};
|
|
|
|
/// Error returned by the compiler wrappers, carrying a human-readable message.
///
/// The message usually comes from the error buffer of a C API struct, sometimes prefixed
/// with context added on the Rust side.
#[derive(Debug)]
pub struct CompilerError(String);

/// Displays the raw error message.
impl std::fmt::Display for CompilerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

/// Lets `CompilerError` be used with `?` in functions returning
/// `Box<dyn std::error::Error>` and with generic error-handling code.
impl std::error::Error for CompilerError {}
|
|
|
|
/// Retrieve the buffer holding the error message of a C struct.
trait CStructErrorMsg {
    /// Raw pointer to the error message stored in the C struct.
    fn error_msg(&self) -> *const i8;
}

/// Implement [`CStructErrorMsg`] for a bracketed, comma-separated list of C structs.
///
/// Every listed type must have an `error` field; the generated `error_msg` returns that
/// field as-is.
macro_rules! impl_CStructErrorMsg {
    ([$($c_struct:ty),+]) => {
        $(
            impl CStructErrorMsg for $c_struct {
                fn error_msg(&self) -> *const i8 {
                    self.error
                }
            }
        )*
    };
}
|
|
// Each of these C API structs has an `error` field exposed through `CStructErrorMsg`
// (kept in alphabetical order).
impl_CStructErrorMsg! {[
    ffi::BufferRef,
    ffi::ClientParameters,
    ffi::CompilationFeedback,
    ffi::CompilationOptions,
    ffi::CompilationResult,
    ffi::CompilerEngine,
    ffi::EvaluationKeys,
    ffi::KeySet,
    ffi::KeySetCache,
    ffi::LambdaArgument,
    ffi::Library,
    ffi::LibraryCompilationResult,
    ffi::LibrarySupport,
    ffi::OptimizerConfig,
    ffi::PublicArguments,
    ffi::PublicResult,
    ffi::ServerLambda
]}
|
|
|
|
/// Construct a rust error message from a buffer in the C struct.
|
|
fn get_error_msg_from_ctype<T: CStructErrorMsg>(c_struct: &T) -> String {
|
|
unsafe {
|
|
let error_msg_cstr = CStr::from_ptr(c_struct.error_msg());
|
|
String::from(error_msg_cstr.to_str().unwrap())
|
|
}
|
|
}
|
|
|
|
/// Wrapper to own MlirStringRef coming from the compiler and destroy them on drop
|
|
struct MlirStringRef(ffi::MlirStringRef);
|
|
|
|
impl MlirStringRef {
|
|
pub fn to_string(&self) -> Result<String, CompilerError> {
|
|
unsafe {
|
|
if self.0.data.is_null() {
|
|
return Err(CompilerError("string ref points to null".to_string()));
|
|
}
|
|
let result = String::from_utf8_lossy(std::slice::from_raw_parts(
|
|
self.0.data as *const u8,
|
|
self.0.length as usize,
|
|
))
|
|
.to_string();
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
/// Create an ffi MlirStringRef for a rust str.
|
|
///
|
|
/// The reason behind not returning a wrapper is that it would lead to freeing rust memory
|
|
/// using a custom destructor in C.
|
|
///
|
|
/// # SAFETY
|
|
/// The caller has to make sure the &str outlive the ffi::MlirStringRef
|
|
pub unsafe fn from_rust_str(s: &str) -> ffi::MlirStringRef {
|
|
ffi::MlirStringRef {
|
|
data: s.as_ptr() as *const c_char,
|
|
length: s.len() as ffi::size_t,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Drop for MlirStringRef {
|
|
fn drop(&mut self) {
|
|
unsafe { ffi::mlirStringRefDestroy(self.0) }
|
|
}
|
|
}
|
|
|
|
/// Common interface of the Rust wrappers that own a C struct of type `T`.
trait CStructWrapper<T> {
    /// Take ownership of `c_struct` by wrapping it in the Rust struct.
    fn wrap(c_struct: T) -> Self;

    /// Free the memory allocated for the wrapped C struct.
    ///
    /// The wrappers call this from `Drop`; calling it manually as well would presumably
    /// free the C struct twice.
    fn destroy(&mut self);

    /// Whether the wrapped C struct is null (used throughout this module to detect a
    /// failed C API call).
    fn is_null(&self) -> bool;

    /// Error message stored in the wrapped C struct.
    fn error_msg(&self) -> String;
}
|
|
|
|
/// Define owning Rust wrappers for C structs.
///
/// For every `Name => { is_null_fn, destroy_fn }` entry this generates `pub struct Name`
/// wrapping `ffi::Name`, implements [`CStructWrapper`] for it on top of the given `ffi`
/// functions, and implements `Drop` so the C struct is destroyed once the wrapper is no
/// longer used. Trailing commas are accepted inside each entry and after the last entry.
macro_rules! def_CStructWrapper {
    (
        $(
            $name:ident => {
                $is_null_fn:ident,
                $destroy_fn:ident
                $(,)?
            }
        ),+
        $(,)?
    ) => {
        $(
            pub struct $name {
                _c: ffi::$name,
            }

            impl CStructWrapper<ffi::$name> for $name {
                fn wrap(c_struct: ffi::$name) -> Self {
                    $name { _c: c_struct }
                }

                fn is_null(&self) -> bool {
                    unsafe { ffi::$is_null_fn(self._c) }
                }

                fn error_msg(&self) -> String {
                    get_error_msg_from_ctype(&self._c)
                }

                fn destroy(&mut self) {
                    unsafe { ffi::$destroy_fn(self._c) }
                }
            }

            impl Drop for $name {
                fn drop(&mut self) {
                    self.destroy();
                }
            }
        )+
    };
}
|
|
// Owning wrappers for the C API structs (alphabetical order). `KeySet` is not listed: its
// raw wrapper is written by hand as `KeySet_` because the public `KeySet` also keeps the
// client parameters.
def_CStructWrapper! {
    BufferRef => { bufferRefIsNull, bufferRefDestroy },
    ClientParameters => { clientParametersIsNull, clientParametersDestroy },
    CompilationFeedback => { compilationFeedbackIsNull, compilationFeedbackDestroy },
    CompilationOptions => { compilationOptionsIsNull, compilationOptionsDestroy },
    CompilationResult => { compilationResultIsNull, compilationResultDestroy },
    CompilerEngine => { compilerEngineIsNull, compilerEngineDestroy },
    EvaluationKeys => { evaluationKeysIsNull, evaluationKeysDestroy },
    KeySetCache => { keySetCacheIsNull, keySetCacheDestroy },
    LambdaArgument => { lambdaArgumentIsNull, lambdaArgumentDestroy },
    Library => { libraryIsNull, libraryDestroy },
    LibraryCompilationResult => { libraryCompilationResultIsNull, libraryCompilationResultDestroy },
    LibrarySupport => { librarySupportIsNull, librarySupportDestroy },
    OptimizerConfig => { optimizerConfigIsNull, optimizerConfigDestroy },
    PublicArguments => { publicArgumentsIsNull, publicArgumentsDestroy },
    PublicResult => { publicResultIsNull, publicResultDestroy },
    ServerLambda => { serverLambdaIsNull, serverLambdaDestroy },
}
|
|
|
|
impl BufferRef {
|
|
/// Create a reference to a buffer in memory.
|
|
///
|
|
/// The pointed memory will not get owned. The caller must make sure the pointer points
|
|
/// to a valid memory region of the provided length, and that the pointed memory outlive
|
|
/// the buffer reference.
|
|
pub fn new(ptr: *const c_char, length: ffi::size_t) -> Result<ffi::BufferRef, CompilerError> {
|
|
unsafe {
|
|
let buffer_ref = ffi::bufferRefCreate(ptr, length);
|
|
if ffi::bufferRefIsNull(buffer_ref) {
|
|
let error_msg = get_error_msg_from_ctype(&buffer_ref);
|
|
ffi::bufferRefDestroy(buffer_ref);
|
|
return Err(CompilerError(error_msg));
|
|
}
|
|
return Ok(buffer_ref);
|
|
}
|
|
}
|
|
|
|
/// Copy the content of the buffer into a new vector of bytes.
|
|
///
|
|
/// Returns an empty vector if the buffer reference a null pointer.
|
|
pub fn to_bytes(&self) -> Vec<c_char> {
|
|
if self.is_null() {
|
|
return Vec::new();
|
|
}
|
|
let buffer_ref_c = self._c;
|
|
unsafe {
|
|
let result = std::slice::from_raw_parts(
|
|
buffer_ref_c.data as *const c_char,
|
|
buffer_ref_c.length as usize,
|
|
)
|
|
.to_vec();
|
|
result
|
|
}
|
|
}
|
|
}
|
|
|
|
impl CompilationOptions {
|
|
pub fn new(
|
|
func_name: &str,
|
|
auto_parallelize: bool,
|
|
batch_concrete_ops: bool,
|
|
dataflow_parallelize: bool,
|
|
emit_gpu_ops: bool,
|
|
loop_parallelize: bool,
|
|
optimize_concrete: bool,
|
|
optimizer_config: &OptimizerConfig,
|
|
verify_diagnostics: bool,
|
|
) -> Result<CompilationOptions, CompilerError> {
|
|
unsafe {
|
|
let options = CompilationOptions::wrap(ffi::compilationOptionsCreate(
|
|
MlirStringRef::from_rust_str(func_name),
|
|
auto_parallelize,
|
|
batch_concrete_ops,
|
|
dataflow_parallelize,
|
|
emit_gpu_ops,
|
|
loop_parallelize,
|
|
optimize_concrete,
|
|
optimizer_config._c,
|
|
verify_diagnostics,
|
|
));
|
|
if options.is_null() {
|
|
return Err(CompilerError(options.error_msg()));
|
|
}
|
|
Ok(options)
|
|
}
|
|
}
|
|
|
|
pub fn get_default() -> Result<CompilationOptions, CompilerError> {
|
|
unsafe {
|
|
let options = CompilationOptions::wrap(ffi::compilationOptionsCreateDefault());
|
|
if options.is_null() {
|
|
return Err(CompilerError(options.error_msg()));
|
|
}
|
|
Ok(options)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl OptimizerConfig {
|
|
pub fn new(
|
|
display: bool,
|
|
fallback_log_norm_woppbs: f64,
|
|
global_p_error: f64,
|
|
p_error: f64,
|
|
security: u64,
|
|
strategy_v0: bool,
|
|
use_gpu_constraints: bool,
|
|
) -> Result<OptimizerConfig, CompilerError> {
|
|
unsafe {
|
|
let config = OptimizerConfig::wrap(ffi::optimizerConfigCreate(
|
|
display,
|
|
fallback_log_norm_woppbs,
|
|
global_p_error,
|
|
p_error,
|
|
security,
|
|
strategy_v0,
|
|
use_gpu_constraints,
|
|
));
|
|
if config.is_null() {
|
|
return Err(CompilerError(config.error_msg()));
|
|
}
|
|
Ok(config)
|
|
}
|
|
}
|
|
|
|
pub fn get_default() -> Result<OptimizerConfig, CompilerError> {
|
|
unsafe {
|
|
let config = OptimizerConfig::wrap(ffi::optimizerConfigCreateDefault());
|
|
if config.is_null() {
|
|
return Err(CompilerError(config.error_msg()));
|
|
}
|
|
Ok(config)
|
|
}
|
|
}
|
|
}
|
|
impl CompilerEngine {
|
|
pub fn new(options: Option<&CompilationOptions>) -> Result<CompilerEngine, CompilerError> {
|
|
unsafe {
|
|
let engine = CompilerEngine::wrap(ffi::compilerEngineCreate());
|
|
if engine.is_null() {
|
|
return Err(CompilerError(engine.error_msg()));
|
|
}
|
|
if let Some(o) = options {
|
|
engine.set_options(o)
|
|
}
|
|
Ok(engine)
|
|
}
|
|
}
|
|
|
|
pub fn set_options(&self, options: &CompilationOptions) {
|
|
unsafe {
|
|
ffi::compilerEngineCompileSetOptions(self._c, options._c);
|
|
}
|
|
}
|
|
|
|
pub fn compile(
|
|
&self,
|
|
module: &str,
|
|
target: ffi::CompilationTarget,
|
|
) -> Result<CompilationResult, CompilerError> {
|
|
unsafe {
|
|
let module_string_ref = MlirStringRef::from_rust_str(module);
|
|
let result = CompilationResult::wrap(ffi::compilerEngineCompile(
|
|
self._c,
|
|
module_string_ref,
|
|
target,
|
|
));
|
|
if result.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
result.error_msg()
|
|
)));
|
|
}
|
|
Ok(result)
|
|
}
|
|
}
|
|
}
|
|
impl CompilationResult {
|
|
pub fn get_module_string(&self) -> Result<String, CompilerError> {
|
|
unsafe { MlirStringRef(ffi::compilationResultGetModuleString(self._c)).to_string() }
|
|
}
|
|
}
|
|
impl Library {
|
|
pub fn new(
|
|
output_dir_path: &str,
|
|
runtime_library_path: Option<&str>,
|
|
clean_up: bool,
|
|
) -> Result<Library, CompilerError> {
|
|
unsafe {
|
|
let lib = Library::wrap(ffi::libraryCreate(
|
|
MlirStringRef::from_rust_str(output_dir_path),
|
|
MlirStringRef::from_rust_str(runtime_library_path.unwrap_or("")),
|
|
clean_up,
|
|
));
|
|
if lib.is_null() {
|
|
return Err(CompilerError(lib.error_msg()));
|
|
}
|
|
Ok(lib)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Owning wrapper of the result of `LibrarySupport::compile`; no extra methods are exposed yet.
impl LibraryCompilationResult {}
|
|
|
|
/// Support for compiling and executing libraries.
|
|
impl LibrarySupport {
|
|
/// LibrarySupport manages build files generated by the compiler under the `output_dir_path`.
|
|
///
|
|
/// The compiled library needs to link to the runtime for proper execution.
|
|
pub fn new(
|
|
output_dir_path: &str,
|
|
runtime_library_path: Option<String>,
|
|
) -> Result<LibrarySupport, CompilerError> {
|
|
unsafe {
|
|
let runtime_library_path = match runtime_library_path {
|
|
Some(val) => val.to_string(),
|
|
None => "".to_string(),
|
|
};
|
|
let runtime_library_path_buffer = runtime_library_path.as_str();
|
|
let support = LibrarySupport::wrap(ffi::librarySupportCreateDefault(
|
|
MlirStringRef::from_rust_str(output_dir_path),
|
|
MlirStringRef::from_rust_str(runtime_library_path_buffer),
|
|
));
|
|
if support.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
support.error_msg()
|
|
)));
|
|
}
|
|
Ok(support)
|
|
}
|
|
}
|
|
|
|
/// Compile an MLIR into a library.
|
|
pub fn compile(
|
|
&self,
|
|
mlir_code: &str,
|
|
options: Option<CompilationOptions>,
|
|
) -> Result<LibraryCompilationResult, CompilerError> {
|
|
unsafe {
|
|
let options = options.unwrap_or_else(|| CompilationOptions::get_default().unwrap());
|
|
let result = LibraryCompilationResult::wrap(ffi::librarySupportCompile(
|
|
self._c,
|
|
MlirStringRef::from_rust_str(mlir_code),
|
|
options._c,
|
|
));
|
|
if result.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
result.error_msg()
|
|
)));
|
|
}
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
/// Load server lambda from a compilation result.
|
|
///
|
|
/// This can be used for executing the compiled function.
|
|
pub fn load_server_lambda(
|
|
&self,
|
|
result: &LibraryCompilationResult,
|
|
) -> Result<ServerLambda, CompilerError> {
|
|
unsafe {
|
|
let server =
|
|
ServerLambda::wrap(ffi::librarySupportLoadServerLambda(self._c, result._c));
|
|
if server.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
server.error_msg()
|
|
)));
|
|
}
|
|
Ok(server)
|
|
}
|
|
}
|
|
|
|
/// Load client parameters from a compilation result.
|
|
///
|
|
/// This can be used for creating keys for the compiled library.
|
|
pub fn load_client_parameters(
|
|
&self,
|
|
result: &LibraryCompilationResult,
|
|
) -> Result<ClientParameters, CompilerError> {
|
|
unsafe {
|
|
let params =
|
|
ClientParameters::wrap(ffi::librarySupportLoadClientParameters(self._c, result._c));
|
|
if params.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
params.error_msg()
|
|
)));
|
|
}
|
|
Ok(params)
|
|
}
|
|
}
|
|
|
|
/// Run a compiled circuit.
|
|
pub fn server_lambda_call(
|
|
&self,
|
|
server_lambda: &ServerLambda,
|
|
args: &PublicArguments,
|
|
eval_keys: &EvaluationKeys,
|
|
) -> Result<PublicResult, CompilerError> {
|
|
unsafe {
|
|
let result = PublicResult::wrap(ffi::librarySupportServerCall(
|
|
self._c,
|
|
server_lambda._c,
|
|
args._c,
|
|
eval_keys._c,
|
|
));
|
|
if result.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
result.error_msg()
|
|
)));
|
|
}
|
|
Ok(result)
|
|
}
|
|
}
|
|
|
|
/// Get path to the compiled shared library
|
|
pub fn get_shared_lib_path(&self) -> String {
|
|
unsafe {
|
|
MlirStringRef(ffi::librarySupportGetSharedLibPath(self._c))
|
|
.to_string()
|
|
.unwrap()
|
|
}
|
|
}
|
|
|
|
/// Get path to the client parameters
|
|
pub fn get_client_parameters_path(&self) -> String {
|
|
unsafe {
|
|
MlirStringRef(ffi::librarySupportGetClientParametersPath(self._c))
|
|
.to_string()
|
|
.unwrap()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Owning wrapper of a server lambda (loaded with `LibrarySupport::load_server_lambda` and used
// to execute the compiled function); no extra methods are exposed yet.
impl ServerLambda {}
|
|
|
|
impl ClientParameters {
|
|
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::wrap(ffi::clientParametersSerialize(self._c));
|
|
if serialized_ref.is_null() {
|
|
return Err(CompilerError(serialized_ref.error_msg()));
|
|
}
|
|
Ok(serialized_ref.to_bytes())
|
|
}
|
|
}
|
|
pub fn unserialize(serialized: &Vec<c_char>) -> Result<ClientParameters, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::new(
|
|
serialized.as_ptr() as *const c_char,
|
|
serialized.len().try_into().unwrap(),
|
|
)
|
|
.unwrap();
|
|
let params = ClientParameters::wrap(ffi::clientParametersUnserialize(serialized_ref));
|
|
if params.is_null() {
|
|
return Err(CompilerError(params.error_msg()));
|
|
}
|
|
Ok(params)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Clone for ClientParameters {
|
|
fn clone(&self) -> Self {
|
|
unsafe { ClientParameters::wrap(ffi::clientParametersCopy(self._c)) }
|
|
}
|
|
}
|
|
|
|
struct KeySet_ {
|
|
_c: ffi::KeySet,
|
|
}
|
|
|
|
impl CStructWrapper<ffi::KeySet> for KeySet_ {
|
|
// wrap a c-struct inside a rust-struct
|
|
fn wrap(c_struct: ffi::KeySet) -> KeySet_ {
|
|
KeySet_ { _c: c_struct }
|
|
}
|
|
// check if the wrapped C-struct is null
|
|
fn is_null(&self) -> bool {
|
|
unsafe { ffi::keySetIsNull(self._c) }
|
|
}
|
|
// get error message
|
|
fn error_msg(&self) -> String {
|
|
get_error_msg_from_ctype(&self._c)
|
|
}
|
|
// free memory allocated for the C-struct
|
|
fn destroy(&mut self) {
|
|
unsafe { ffi::keySetDestroy(self._c) }
|
|
}
|
|
}
|
|
|
|
impl Drop for KeySet_ {
|
|
fn drop(&mut self) {
|
|
self.destroy();
|
|
}
|
|
}
|
|
pub struct KeySet {
|
|
key_set: KeySet_,
|
|
client_params: ClientParameters,
|
|
}
|
|
|
|
impl KeySet {
|
|
/// Get a keyset based on the client parameters, and the different seeds.
|
|
///
|
|
/// If a cache is set, this operation would first try to load an existing key,
|
|
/// otherwise, a new keyset will be generated.
|
|
pub fn new(
|
|
client_params: &ClientParameters,
|
|
seed_msb: Option<u64>,
|
|
seed_lsb: Option<u64>,
|
|
key_set_cache: Option<&KeySetCache>,
|
|
) -> Result<KeySet, CompilerError> {
|
|
unsafe {
|
|
let key_set = match key_set_cache {
|
|
Some(cache) => KeySet_::wrap(ffi::keySetCacheLoadOrGenerateKeySet(
|
|
cache._c,
|
|
client_params._c,
|
|
seed_msb.unwrap_or(0),
|
|
seed_lsb.unwrap_or(0),
|
|
)),
|
|
None => KeySet_::wrap(ffi::keySetGenerate(
|
|
client_params._c,
|
|
seed_msb.unwrap_or(0),
|
|
seed_lsb.unwrap_or(0),
|
|
)),
|
|
};
|
|
if key_set.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
key_set.error_msg()
|
|
)));
|
|
}
|
|
Ok(KeySet {
|
|
key_set,
|
|
client_params: client_params.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
pub fn get_evaluation_keys(&self) -> Result<EvaluationKeys, CompilerError> {
|
|
unsafe {
|
|
let eval_keys = EvaluationKeys::wrap(ffi::keySetGetEvaluationKeys(self.key_set._c));
|
|
if eval_keys.is_null() {
|
|
return Err(CompilerError(eval_keys.error_msg()));
|
|
}
|
|
Ok(eval_keys)
|
|
}
|
|
}
|
|
|
|
/// Encrypt arguments of a compiled circuit.
|
|
pub fn encrypt_args(&self, args: &[LambdaArgument]) -> Result<PublicArguments, CompilerError> {
|
|
LambdaArgument::encrypt_args(args, self)
|
|
}
|
|
|
|
pub fn decrypt_result(&self, result: &PublicResult) -> Result<LambdaArgument, CompilerError> {
|
|
result.decrypt(self)
|
|
}
|
|
}
|
|
|
|
impl KeySetCache {
|
|
pub fn new(path: &Path) -> Result<KeySetCache, CompilerError> {
|
|
unsafe {
|
|
let cache_path_buffer = path.to_str().unwrap();
|
|
let cache = KeySetCache::wrap(ffi::keySetCacheCreate(MlirStringRef::from_rust_str(
|
|
cache_path_buffer,
|
|
)));
|
|
if cache.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
cache.error_msg()
|
|
)));
|
|
}
|
|
Ok(cache)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl EvaluationKeys {
|
|
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::wrap(ffi::evaluationKeysSerialize(self._c));
|
|
if serialized_ref.is_null() {
|
|
return Err(CompilerError(serialized_ref.error_msg()));
|
|
}
|
|
Ok(serialized_ref.to_bytes())
|
|
}
|
|
}
|
|
pub fn unserialize(serialized: &Vec<c_char>) -> Result<EvaluationKeys, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::new(
|
|
serialized.as_ptr() as *const c_char,
|
|
serialized.len().try_into().unwrap(),
|
|
)
|
|
.unwrap();
|
|
let eval_keys = EvaluationKeys::wrap(ffi::evaluationKeysUnserialize(serialized_ref));
|
|
if eval_keys.is_null() {
|
|
return Err(CompilerError(eval_keys.error_msg()));
|
|
}
|
|
Ok(eval_keys)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl LambdaArgument {
|
|
pub fn encrypt_args(
|
|
args: &[LambdaArgument],
|
|
key_set: &KeySet,
|
|
) -> Result<PublicArguments, CompilerError> {
|
|
unsafe {
|
|
let args: Vec<ffi::LambdaArgument> = args.into_iter().map(|a| a._c).collect();
|
|
let public_args = PublicArguments::wrap(ffi::lambdaArgumentEncrypt(
|
|
args.as_ptr(),
|
|
args.len() as u64,
|
|
key_set.client_params._c,
|
|
key_set.key_set._c,
|
|
));
|
|
if public_args.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
public_args.error_msg()
|
|
)));
|
|
}
|
|
Ok(public_args)
|
|
}
|
|
}
|
|
|
|
pub fn from_scalar(scalar: u64) -> Result<LambdaArgument, CompilerError> {
|
|
unsafe {
|
|
let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromScalar(scalar));
|
|
if arg.is_null() {
|
|
return Err(CompilerError(arg.error_msg()));
|
|
}
|
|
Ok(arg)
|
|
}
|
|
}
|
|
|
|
pub fn is_scalar(&self) -> bool {
|
|
unsafe { ffi::lambdaArgumentIsScalar(self._c) }
|
|
}
|
|
|
|
pub fn get_scalar(&self) -> Result<u64, CompilerError> {
|
|
unsafe {
|
|
if !self.is_scalar() {
|
|
return Err(CompilerError("argument is not a scalar".to_string()));
|
|
}
|
|
Ok(ffi::lambdaArgumentGetScalar(self._c))
|
|
}
|
|
}
|
|
|
|
pub fn from_tensor_u8(data: &[u8], dims: &[i64]) -> Result<LambdaArgument, CompilerError> {
|
|
unsafe {
|
|
let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromTensorU8(
|
|
data.as_ptr(),
|
|
dims.as_ptr(),
|
|
dims.len().try_into().unwrap(),
|
|
));
|
|
if arg.is_null() {
|
|
return Err(CompilerError(arg.error_msg()));
|
|
}
|
|
Ok(arg)
|
|
}
|
|
}
|
|
|
|
pub fn from_tensor_u64(data: &[u64], dims: &[i64]) -> Result<LambdaArgument, CompilerError> {
|
|
unsafe {
|
|
let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromTensorU64(
|
|
data.as_ptr(),
|
|
dims.as_ptr(),
|
|
dims.len().try_into().unwrap(),
|
|
));
|
|
if arg.is_null() {
|
|
return Err(CompilerError(arg.error_msg()));
|
|
}
|
|
Ok(arg)
|
|
}
|
|
}
|
|
|
|
pub fn is_tensor(&self) -> bool {
|
|
unsafe { ffi::lambdaArgumentIsTensor(self._c) }
|
|
}
|
|
|
|
pub fn get_data_size(&self) -> Result<i64, CompilerError> {
|
|
unsafe {
|
|
if !self.is_tensor() {
|
|
return Err(CompilerError("argument is not a tensor".to_string()));
|
|
}
|
|
Ok(ffi::lambdaArgumentGetTensorDataSize(self._c))
|
|
}
|
|
}
|
|
|
|
pub fn get_rank(&self) -> Result<ffi::size_t, CompilerError> {
|
|
unsafe {
|
|
if !self.is_tensor() {
|
|
return Err(CompilerError("argument is not a tensor".to_string()));
|
|
}
|
|
Ok(ffi::lambdaArgumentGetTensorRank(self._c))
|
|
}
|
|
}
|
|
|
|
pub fn get_dims(&self) -> Result<Vec<i64>, CompilerError> {
|
|
unsafe {
|
|
let rank = self.get_rank().unwrap();
|
|
let mut dims = Vec::new();
|
|
dims.resize(rank.try_into().unwrap(), 0);
|
|
if !ffi::lambdaArgumentGetTensorDims(self._c, dims.as_mut_ptr()) {
|
|
return Err(CompilerError("couldn't get dims".to_string()));
|
|
}
|
|
Ok(dims)
|
|
}
|
|
}
|
|
|
|
pub fn get_data(&self) -> Result<Vec<u64>, CompilerError> {
|
|
unsafe {
|
|
let size = self.get_data_size().unwrap();
|
|
let mut data = Vec::new();
|
|
data.resize(size.try_into().unwrap(), 0);
|
|
if !ffi::lambdaArgumentGetTensorData(self._c, data.as_mut_ptr()) {
|
|
return Err(CompilerError("couldn't get data".to_string()));
|
|
}
|
|
Ok(data)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl PublicArguments {
|
|
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::wrap(ffi::publicArgumentsSerialize(self._c));
|
|
if serialized_ref.is_null() {
|
|
return Err(CompilerError(serialized_ref.error_msg()));
|
|
}
|
|
Ok(serialized_ref.to_bytes())
|
|
}
|
|
}
|
|
pub fn unserialize(
|
|
serialized: &Vec<c_char>,
|
|
client_parameters: &ClientParameters,
|
|
) -> Result<PublicArguments, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::new(
|
|
serialized.as_ptr() as *const c_char,
|
|
serialized.len().try_into().unwrap(),
|
|
)
|
|
.unwrap();
|
|
let public_args = PublicArguments::wrap(ffi::publicArgumentsUnserialize(
|
|
serialized_ref,
|
|
client_parameters._c,
|
|
));
|
|
if public_args.is_null() {
|
|
return Err(CompilerError(public_args.error_msg()));
|
|
}
|
|
Ok(public_args)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl PublicResult {
|
|
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::wrap(ffi::publicResultSerialize(self._c));
|
|
if serialized_ref.is_null() {
|
|
return Err(CompilerError(serialized_ref.error_msg()));
|
|
}
|
|
Ok(serialized_ref.to_bytes())
|
|
}
|
|
}
|
|
pub fn unserialize(
|
|
serialized: &Vec<c_char>,
|
|
client_parameters: &ClientParameters,
|
|
) -> Result<PublicResult, CompilerError> {
|
|
unsafe {
|
|
let serialized_ref = BufferRef::new(
|
|
serialized.as_ptr() as *const c_char,
|
|
serialized.len().try_into().unwrap(),
|
|
)
|
|
.unwrap();
|
|
let public_result = PublicResult::wrap(ffi::publicResultUnserialize(
|
|
serialized_ref,
|
|
client_parameters._c,
|
|
));
|
|
if public_result.is_null() {
|
|
return Err(CompilerError(public_result.error_msg()));
|
|
}
|
|
Ok(public_result)
|
|
}
|
|
}
|
|
|
|
pub fn decrypt(&self, key_set: &KeySet) -> Result<LambdaArgument, CompilerError> {
|
|
unsafe {
|
|
let arg = LambdaArgument::wrap(ffi::publicResultDecrypt(self._c, key_set.key_set._c));
|
|
if arg.is_null() {
|
|
return Err(CompilerError(format!(
|
|
"Error in compiler (check logs for more info): {}",
|
|
arg.error_msg()
|
|
)));
|
|
}
|
|
Ok(arg)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl CompilationFeedback {
|
|
pub fn get_complexity(&self) -> f64 {
|
|
unsafe { ffi::compilationFeedbackGetComplexity(self._c) }
|
|
}
|
|
|
|
pub fn get_p_error(&self) -> f64 {
|
|
unsafe { ffi::compilationFeedbackGetPError(self._c) }
|
|
}
|
|
|
|
pub fn get_global_p_error(&self) -> f64 {
|
|
unsafe { ffi::compilationFeedbackGetGlobalPError(self._c) }
|
|
}
|
|
|
|
pub fn get_total_secret_keys_size(&self) -> u64 {
|
|
unsafe { ffi::compilationFeedbackGetTotalSecretKeysSize(self._c) }
|
|
}
|
|
|
|
pub fn get_total_bootstrap_keys_size(&self) -> u64 {
|
|
unsafe { ffi::compilationFeedbackGetTotalBootstrapKeysSize(self._c) }
|
|
}
|
|
|
|
pub fn get_total_keyswitch_keys_size(&self) -> u64 {
|
|
unsafe { ffi::compilationFeedbackGetTotalKeyswitchKeysSize(self._c) }
|
|
}
|
|
|
|
pub fn get_total_inputs_size(&self) -> u64 {
|
|
unsafe { ffi::compilationFeedbackGetTotalInputsSize(self._c) }
|
|
}
|
|
|
|
pub fn get_total_outputs_size(&self) -> u64 {
|
|
unsafe { ffi::compilationFeedbackGetTotalOutputsSize(self._c) }
|
|
}
|
|
}
|
|
|
|
/// Parse the MLIR code and returns it.
|
|
///
|
|
/// The function parse the provided MLIR textual representation and returns it. It would fail with
|
|
/// an error message to stderr reporting what's bad with the parsed IR.
|
|
///
|
|
/// # Examples
|
|
/// ```
|
|
/// use concrete_compiler_rust::compiler::*;
|
|
///
|
|
/// let module_to_compile = "
|
|
/// func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
|
/// %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
|
/// return %0 : !FHE.eint<5>
|
|
/// }";
|
|
/// let result_str = round_trip(module_to_compile);
|
|
/// ```
|
|
///
|
|
pub fn round_trip(mlir_code: &str) -> Result<String, CompilerError> {
|
|
let engine = CompilerEngine::new(None).unwrap();
|
|
let compilation_result = engine.compile(mlir_code, ffi::CompilationTarget_ROUND_TRIP)?;
|
|
compilation_result.get_module_string()
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod test {
|
|
use std::env;
|
|
use tempdir::TempDir;
|
|
|
|
use super::*;
|
|
|
|
fn get_runtime_lib_path() -> Option<String> {
|
|
match env::var("CONCRETE_COMPILER_INSTALL_DIR") {
|
|
Ok(val) => Some(val + "/lib/libConcretelangRuntime.so"),
|
|
Err(_e) => None,
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_compiler_round_trip() {
|
|
let module_to_compile = "
|
|
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
|
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
|
return %0 : !FHE.eint<5>
|
|
}";
|
|
let result_str = round_trip(module_to_compile).unwrap();
|
|
let expected_module = "module {
|
|
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
|
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
|
return %0 : !FHE.eint<5>
|
|
}
|
|
}
|
|
";
|
|
assert_eq!(expected_module, result_str);
|
|
}
|
|
|
|
#[test]
|
|
fn test_compiler_round_trip_invalid_mlir() {
|
|
let module_to_compile = "bla bla bla";
|
|
let result_str = round_trip(module_to_compile);
|
|
assert!(
|
|
matches!(result_str, Err(CompilerError(err)) if err == "Error in compiler (check logs for more info): Could not parse source\n")
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_compiler_compile_lib() {
|
|
let module_to_compile = "
|
|
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
|
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
|
return %0 : !FHE.eint<5>
|
|
}";
|
|
let runtime_library_path = get_runtime_lib_path();
|
|
let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
|
|
let support =
|
|
LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap();
|
|
let lib = support.compile(module_to_compile, None).unwrap();
|
|
assert!(!lib.is_null());
|
|
// the sharedlib should be enough as a sign that the compilation worked
|
|
assert!(Path::new(support.get_shared_lib_path().as_str()).exists());
|
|
assert!(Path::new(support.get_client_parameters_path().as_str()).exists());
|
|
}
|
|
|
|
/// We want to make sure setting a pointer to null in rust passes the nullptr check in C/Cpp
|
|
#[test]
|
|
fn test_compiler_null_ptr_compatibility() {
|
|
unsafe {
|
|
let lib = ffi::Library {
|
|
ptr: std::ptr::null_mut(),
|
|
error: std::ptr::null_mut(),
|
|
};
|
|
assert!(ffi::libraryIsNull(lib));
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_compiler_load_server_lambda_and_client_parameters() {
|
|
let module_to_compile = "
|
|
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
|
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
|
return %0 : !FHE.eint<5>
|
|
}";
|
|
let runtime_library_path = get_runtime_lib_path();
|
|
let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
|
|
let support =
|
|
LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap();
|
|
let result = support.compile(module_to_compile, None).unwrap();
|
|
let server = support.load_server_lambda(&result).unwrap();
|
|
assert!(!server.is_null());
|
|
let client_params = support.load_client_parameters(&result).unwrap();
|
|
assert!(!client_params.is_null());
|
|
}
|
|
|
|
#[test]
|
|
fn test_compiler_compile_and_exec_scalar_args() {
|
|
let module_to_compile = "
|
|
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
|
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
|
return %0 : !FHE.eint<5>
|
|
}";
|
|
let runtime_library_path = get_runtime_lib_path();
|
|
let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
|
|
let lib_support =
|
|
LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap();
|
|
// compile
|
|
let result = lib_support.compile(module_to_compile, None).unwrap();
|
|
// loading materials from compilation
|
|
// - server_lambda: used for execution
|
|
// - client_parameters: used for keygen, encryption, and evaluation keys
|
|
let server_lambda = lib_support.load_server_lambda(&result).unwrap();
|
|
let client_params = lib_support.load_client_parameters(&result).unwrap();
|
|
let key_set = KeySet::new(&client_params, None, None, None).unwrap();
|
|
let eval_keys = key_set.get_evaluation_keys().unwrap();
|
|
// build lambda arguments from scalar and encrypt them
|
|
let args = [
|
|
LambdaArgument::from_scalar(4).unwrap(),
|
|
LambdaArgument::from_scalar(2).unwrap(),
|
|
];
|
|
let encrypted_args = key_set.encrypt_args(&args).unwrap();
|
|
// execute the compiled function on the encrypted arguments
|
|
let encrypted_result = lib_support
|
|
.server_lambda_call(&server_lambda, &encrypted_args, &eval_keys)
|
|
.unwrap();
|
|
// decrypt the result of execution
|
|
let result_arg = key_set.decrypt_result(&encrypted_result).unwrap();
|
|
// get the scalar value from the result lambda argument
|
|
let result = result_arg.get_scalar().unwrap();
|
|
assert_eq!(result, 6);
|
|
}
|
|
|
|
#[test]
fn test_compiler_compile_and_exec_with_serialization() {
    // Encrypted 5-bit addition, with every client/server artifact (client parameters,
    // evaluation keys, encrypted arguments, encrypted result) pushed through its
    // `serialize` / `unserialize` pair before being used.
    let mlir_module = "
            func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
                %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
                return %0 : !FHE.eint<5>
            }";
    let runtime_lib = get_runtime_lib_path();
    // Kept in a named binding (never `_`) so the output directory lives until the test ends.
    let out_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
    let support = LibrarySupport::new(out_dir.path().to_str().unwrap(), runtime_lib).unwrap();

    // Compile, then pull out the server-side lambda and the client parameters.
    let compilation = support.compile(mlir_module, None).unwrap();
    let lambda = support.load_server_lambda(&compilation).unwrap();
    let client_params = support.load_client_parameters(&compilation).unwrap();

    // Round-trip #1: client parameters.
    let params_bytes = client_params.serialize().unwrap();
    let client_params = ClientParameters::unserialize(&params_bytes).unwrap();

    // Key generation from the rebuilt parameters.
    let keys = KeySet::new(&client_params, None, None, None).unwrap();
    let eval_keys = keys.get_evaluation_keys().unwrap();

    // Round-trip #2: evaluation keys.
    let eval_keys_bytes = eval_keys.serialize().unwrap();
    let eval_keys = EvaluationKeys::unserialize(&eval_keys_bytes).unwrap();

    // Two clear scalar inputs, encrypted with the key set.
    let clear_inputs = [
        LambdaArgument::from_scalar(4).unwrap(),
        LambdaArgument::from_scalar(2).unwrap(),
    ];
    let public_args = keys.encrypt_args(&clear_inputs).unwrap();

    // Round-trip #3: encrypted arguments.
    let public_args_bytes = public_args.serialize().unwrap();
    let public_args = PublicArguments::unserialize(&public_args_bytes, &client_params).unwrap();

    // Server-side evaluation on ciphertexts only.
    let public_result = support
        .server_lambda_call(&lambda, &public_args, &eval_keys)
        .unwrap();

    // Round-trip #4: encrypted result.
    let public_result_bytes = public_result.serialize().unwrap();
    let public_result = PublicResult::unserialize(&public_result_bytes, &client_params).unwrap();

    // Back on the client: decrypt and read the scalar out (4 + 2).
    let decrypted = keys.decrypt_result(&public_result).unwrap();
    let sum = decrypted.get_scalar().unwrap();
    assert_eq!(sum, 6);
}
|
|
|
|
#[test]
fn test_tensor_lambda_argument() {
    // Flattened data and shape of a 2x2 tensor argument.
    let values: [u64; 4] = [1, 2, 3, 73];
    let shape: [i64; 2] = [2, 2];
    let arg = LambdaArgument::from_tensor_u64(&values, &shape).unwrap();

    // The argument must classify as a tensor, and only as a tensor.
    assert!(!arg.is_null());
    assert!(!arg.is_scalar());
    assert!(arg.is_tensor());

    // The metadata and the data read back must match what was passed in.
    assert_eq!(arg.get_rank().unwrap(), 2);
    assert_eq!(arg.get_data_size().unwrap(), 4);
    assert_eq!(arg.get_dims().unwrap(), shape);
    assert_eq!(arg.get_data().unwrap(), values);
}
|
|
|
|
/// Compiles an element-wise encrypted tensor addition (`FHELinalg.add_eint` on two
/// `2x3` tensors of `!FHE.eint<5>`), runs it on encrypted tensor inputs, and checks
/// the shape and the values of the decrypted result.
#[test]
fn test_compiler_compile_and_exec_tensor_args() {
    let module_to_compile = "
            func.func @main(%arg0: tensor<2x3x!FHE.eint<5>>, %arg1: tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> {
                %0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<2x3x!FHE.eint<5>>, tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>>
                return %0 : tensor<2x3x!FHE.eint<5>>
            }";
    // Shape shared by both inputs and by the expected output. It must agree with the
    // `tensor<2x3x...>` types declared in `module_to_compile` above.
    let dims = [2, 3];
    let runtime_library_path = get_runtime_lib_path();
    let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
    let lib_support =
        LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap();
    // compile
    let result = lib_support.compile(module_to_compile, None).unwrap();
    // loading materials from compilation
    // - server_lambda: used for execution
    // - client_parameters: used for keygen, encryption, and evaluation keys
    let server_lambda = lib_support.load_server_lambda(&result).unwrap();
    let client_params = lib_support.load_client_parameters(&result).unwrap();
    let key_set = KeySet::new(&client_params, None, None, None).unwrap();
    let eval_keys = key_set.get_evaluation_keys().unwrap();
    // Build the lambda arguments from tensors (flattened data + `dims`) and encrypt them.
    // Every element-wise sum (max 9 + 6 = 15) stays within the 5-bit precision declared
    // by `!FHE.eint<5>`.
    let args = [
        LambdaArgument::from_tensor_u8(&[1, 2, 3, 4, 5, 6], &dims).unwrap(),
        LambdaArgument::from_tensor_u8(&[1, 4, 7, 4, 2, 9], &dims).unwrap(),
    ];
    let encrypted_args = key_set.encrypt_args(&args).unwrap();
    // execute the compiled function on the encrypted arguments
    let encrypted_result = lib_support
        .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys)
        .unwrap();
    // decrypt the result of execution
    let result_arg = key_set.decrypt_result(&encrypted_result).unwrap();
    // The result keeps the input shape: rank 2, 2 * 3 = 6 elements.
    assert_eq!(result_arg.get_rank().unwrap(), 2);
    assert_eq!(result_arg.get_data_size().unwrap(), 6);
    assert_eq!(result_arg.get_dims().unwrap(), dims);
    // Element-wise sums of the two flattened inputs.
    assert_eq!(result_arg.get_data().unwrap(), [2, 6, 10, 8, 7, 15]);
}
|
|
}
|