fix(compiler): update mlir-c usage in rust bindings

This commit is contained in:
youben11
2023-03-15 16:24:51 +01:00
committed by Ayoub Benaissa
parent 2cdf166b96
commit 529d96f564
5 changed files with 52 additions and 45 deletions

View File

@@ -30,5 +30,6 @@
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Interfaces.h>
#include <mlir-c/Pass.h>
#include <mlir-c/RegisterEverything.h>
#include <mlir-c/Support.h>
#include <mlir-c/Transforms.h>

View File

@@ -1156,7 +1156,6 @@ mod test {
return %0 : !FHE.eint<5>
}";
let runtime_library_path = runtime_lib_path();
let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap();
let temp_dir = TempDir::new("concrete_compiler_test").unwrap();
let support =
LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap();

View File

@@ -253,7 +253,7 @@ mod test {
fn test_invalid_fhe_eint_type() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
let invalid_eint = fheEncryptedIntegerTypeGetChecked(context, 0);
assert!(invalid_eint.isError);
}
@@ -263,7 +263,7 @@ mod test {
fn test_valid_fhe_eint_type() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
@@ -282,7 +282,7 @@ mod test {
fn test_valid_fhe_esint_type() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5);
@@ -301,7 +301,7 @@ mod test {
fn test_fhe_func() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -360,7 +360,7 @@ module {
fn test_zero_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -382,7 +382,7 @@ module {
fn test_zero_tensor_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -413,7 +413,7 @@ module {
fn test_add_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -447,7 +447,7 @@ module {
fn test_add_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -485,7 +485,7 @@ module {
fn test_sub_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -519,7 +519,7 @@ module {
fn test_sub_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -557,7 +557,7 @@ module {
fn test_sub_int_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -595,7 +595,7 @@ module {
fn test_negate_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -628,7 +628,7 @@ module {
fn test_mul_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -666,7 +666,7 @@ module {
fn test_to_signed_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -699,7 +699,7 @@ module {
fn test_to_unsigned_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -732,7 +732,7 @@ module {
fn test_apply_lut_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();

View File

@@ -557,7 +557,7 @@ mod test {
fn test_fhelinalg_func() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHELinalg dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
@@ -630,7 +630,7 @@ module {
fn test_add_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -668,7 +668,7 @@ module {
fn test_add_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -711,7 +711,7 @@ module {
fn test_sub_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -749,7 +749,7 @@ module {
fn test_sub_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -793,7 +793,7 @@ module {
fn test_sub_int_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -836,7 +836,7 @@ module {
fn test_neg_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -870,7 +870,7 @@ module {
fn test_mul_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -914,7 +914,7 @@ module {
fn test_apply_lut_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -954,7 +954,7 @@ module {
fn test_apply_multi_lut_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -999,7 +999,7 @@ module {
fn test_apply_mapped_lut_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1049,7 +1049,7 @@ module {
fn test_dot_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1092,7 +1092,7 @@ module {
fn test_matmul_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1135,7 +1135,7 @@ module {
fn test_matmul_int_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1178,7 +1178,7 @@ module {
fn test_sum_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1218,7 +1218,7 @@ module {
fn test_concat_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1258,7 +1258,7 @@ tensor<6x3x!FHE.eint<4>>";
fn test_conv2d_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1312,7 +1312,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
fn test_transpose_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1351,7 +1351,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
fn test_from_element_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1387,7 +1387,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
fn test_to_signed_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
@@ -1421,7 +1421,7 @@ padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tenso
fn test_to_unsigned_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);

View File

@@ -59,7 +59,7 @@ pub fn print_mlir_type_to_string(mlir_type: MlirType) -> String {
/// use concrete_compiler::mlir::ffi::*;
/// unsafe{
/// let context = mlirContextCreate();
/// mlirRegisterAllDialects(context);
/// register_all_dialects(context);
///
/// // input/output types
/// let func_input_types = [
@@ -275,6 +275,13 @@ pub fn create_constant_tensor_op(
}
}
pub unsafe fn register_all_dialects(context: MlirContext) {
let registry = mlirDialectRegistryCreate();
mlirRegisterAllDialects(registry);
mlirContextAppendDialectRegistry(context, registry);
mlirContextLoadAllAvailableDialects(context);
}
#[cfg(test)]
mod test {
use super::*;
@@ -293,7 +300,7 @@ mod test {
fn test_module_parsing() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
let module_string = "
module{
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
@@ -322,7 +329,7 @@ mod test {
fn test_module_creation() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// input/output types
let func_input_types = [
@@ -378,7 +385,7 @@ module {
fn test_constant_flat_tensor() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// create a constant flat tensor
let contant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64);
@@ -393,7 +400,7 @@ module {
fn test_constant_tensor() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// create a constant tensor
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64);
@@ -408,7 +415,7 @@ module {
fn test_constant_tensor_with_signle_elem() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// create a constant tensor
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7);
@@ -423,7 +430,7 @@ module {
fn test_constant_int() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
register_all_dialects(context);
// create a constant flat tensor
let contant_int_op = create_constant_int_op(context, 73, 10);