mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(compiler): update mlir-c usage in rust bindings
This commit is contained in:
@@ -30,5 +30,6 @@
|
||||
#include <mlir-c/IntegerSet.h>
|
||||
#include <mlir-c/Interfaces.h>
|
||||
#include <mlir-c/Pass.h>
|
||||
#include <mlir-c/RegisterEverything.h>
|
||||
#include <mlir-c/Support.h>
|
||||
#include <mlir-c/Transforms.h>
|
||||
|
||||
@@ -1156,7 +1156,6 @@ mod test {
|
||||
return %0 : !FHE.eint<5>
|
||||
}";
|
||||
let runtime_library_path = runtime_lib_path();
|
||||
let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
|
||||
let temp_dir = TempDir::new("concrete_compiler_test").unwrap();
|
||||
let support =
|
||||
LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap();
|
||||
|
||||
@@ -253,7 +253,7 @@ mod test {
|
||||
fn test_invalid_fhe_eint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
let invalid_eint = fheEncryptedIntegerTypeGetChecked(context, 0);
|
||||
assert!(invalid_eint.isError);
|
||||
}
|
||||
@@ -263,7 +263,7 @@ mod test {
|
||||
fn test_valid_fhe_eint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
|
||||
@@ -282,7 +282,7 @@ mod test {
|
||||
fn test_valid_fhe_esint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5);
|
||||
@@ -301,7 +301,7 @@ mod test {
|
||||
fn test_fhe_func() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -360,7 +360,7 @@ module {
|
||||
fn test_zero_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -382,7 +382,7 @@ module {
|
||||
fn test_zero_tensor_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -413,7 +413,7 @@ module {
|
||||
fn test_add_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -447,7 +447,7 @@ module {
|
||||
fn test_add_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -485,7 +485,7 @@ module {
|
||||
fn test_sub_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -519,7 +519,7 @@ module {
|
||||
fn test_sub_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -557,7 +557,7 @@ module {
|
||||
fn test_sub_int_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -595,7 +595,7 @@ module {
|
||||
fn test_negate_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -628,7 +628,7 @@ module {
|
||||
fn test_mul_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -666,7 +666,7 @@ module {
|
||||
fn test_to_signed_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -699,7 +699,7 @@ module {
|
||||
fn test_to_unsigned_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -732,7 +732,7 @@ module {
|
||||
fn test_apply_lut_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
|
||||
@@ -557,7 +557,7 @@ mod test {
|
||||
fn test_fhelinalg_func() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHELinalg dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
@@ -630,7 +630,7 @@ module {
|
||||
fn test_add_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -668,7 +668,7 @@ module {
|
||||
fn test_add_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -711,7 +711,7 @@ module {
|
||||
fn test_sub_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -749,7 +749,7 @@ module {
|
||||
fn test_sub_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -793,7 +793,7 @@ module {
|
||||
fn test_sub_int_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -836,7 +836,7 @@ module {
|
||||
fn test_neg_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -870,7 +870,7 @@ module {
|
||||
fn test_mul_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -914,7 +914,7 @@ module {
|
||||
fn test_apply_lut_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -954,7 +954,7 @@ module {
|
||||
fn test_apply_multi_lut_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -999,7 +999,7 @@ module {
|
||||
fn test_apply_mapped_lut_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1049,7 +1049,7 @@ module {
|
||||
fn test_dot_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1092,7 +1092,7 @@ module {
|
||||
fn test_matmul_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1135,7 +1135,7 @@ module {
|
||||
fn test_matmul_int_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1178,7 +1178,7 @@ module {
|
||||
fn test_sum_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1218,7 +1218,7 @@ module {
|
||||
fn test_concat_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1258,7 +1258,7 @@ tensor<6x3x!FHE.eint<4>>";
|
||||
fn test_conv2d_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1312,7 +1312,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
|
||||
fn test_transpose_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1351,7 +1351,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
|
||||
fn test_from_element_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1387,7 +1387,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
|
||||
fn test_to_signed_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
@@ -1421,7 +1421,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
|
||||
fn test_to_unsigned_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
@@ -59,7 +59,7 @@ pub fn print_mlir_type_to_string(mlir_type: MlirType) -> String {
|
||||
/// use concrete_compiler::mlir::ffi::*;
|
||||
/// unsafe{
|
||||
/// let context = mlirContextCreate();
|
||||
/// mlirRegisterAllDialects(context);
|
||||
/// register_all_dialects(context);
|
||||
///
|
||||
/// // input/output types
|
||||
/// let func_input_types = [
|
||||
@@ -275,6 +275,13 @@ pub fn create_constant_tensor_op(
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn register_all_dialects(context: MlirContext) {
|
||||
let registry = mlirDialectRegistryCreate();
|
||||
mlirRegisterAllDialects(registry);
|
||||
mlirContextAppendDialectRegistry(context, registry);
|
||||
mlirContextLoadAllAvailableDialects(context);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
@@ -293,7 +300,7 @@ mod test {
|
||||
fn test_module_parsing() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
let module_string = "
|
||||
module{
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
@@ -322,7 +329,7 @@ mod test {
|
||||
fn test_module_creation() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// input/output types
|
||||
let func_input_types = [
|
||||
@@ -378,7 +385,7 @@ module {
|
||||
fn test_constant_flat_tensor() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant flat tensor
|
||||
let contant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64);
|
||||
@@ -393,7 +400,7 @@ module {
|
||||
fn test_constant_tensor() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant tensor
|
||||
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64);
|
||||
@@ -408,7 +415,7 @@ module {
|
||||
fn test_constant_tensor_with_signle_elem() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant tensor
|
||||
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7);
|
||||
@@ -423,7 +430,7 @@ module {
|
||||
fn test_constant_int() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant flat tensor
|
||||
let contant_int_op = create_constant_int_op(context, 73, 10);
|
||||
|
||||
Reference in New Issue
Block a user