diff --git a/compiler/lib/Bindings/Rust/Cargo.toml b/compiler/lib/Bindings/Rust/Cargo.toml index 4ced9e254..fed09d820 100644 --- a/compiler/lib/Bindings/Rust/Cargo.toml +++ b/compiler/lib/Bindings/Rust/Cargo.toml @@ -6,6 +6,3 @@ edition = "2021" [build-dependencies] bindgen = "0.60.1" -[dev-dependencies] -stdio-override = "0.1.3" -tempfile = "3.3.0" diff --git a/compiler/lib/Bindings/Rust/src/lib.rs b/compiler/lib/Bindings/Rust/src/lib.rs index c64843e25..f8b2818e8 100644 --- a/compiler/lib/Bindings/Rust/src/lib.rs +++ b/compiler/lib/Bindings/Rust/src/lib.rs @@ -2,15 +2,39 @@ #![allow(non_camel_case_types)] #![allow(non_snake_case)] +use std::ops::AddAssign; + include!(concat!(env!("OUT_DIR"), "/bindings.rs")); +pub(crate) unsafe extern "C" fn mlir_rust_string_receiver_callback( + mlirStrRef: MlirStringRef, + user_data: *mut ::std::os::raw::c_void, +) { + let rust_string = &mut *(user_data as *mut String); + let slc = std::slice::from_raw_parts(mlirStrRef.data as *const u8, mlirStrRef.length as usize); + rust_string.add_assign(&String::from_utf8_lossy(slc)); +} + +pub fn print_mlir_operation_to_string(op: MlirOperation) -> String { + let mut rust_string = String::default(); + let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void; + + unsafe { + mlirOperationPrint( + op, + Some(mlir_rust_string_receiver_callback), + receiver_ptr + ); + } + + rust_string +} + + #[cfg(test)] mod concrete_compiler_tests { use super::*; use std::ffi::CString; - use std::fs; - use stdio_override::StderrOverride; - use tempfile::tempdir; #[test] fn test_function_type() { @@ -151,22 +175,21 @@ mod concrete_compiler_tests { let module = mlirModuleCreateEmpty(location); mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op); - // dump module to stderr and capture it in a temp file - let temp_dir = tempdir().unwrap(); - let temp_file = temp_dir.path().join("concrete_compiler_test_stderr"); - let guard = StderrOverride::override_file(&temp_file).unwrap(); - mlirOperationDump(mlirModuleGetOperation(module)); - let printed_module = fs::read_to_string(&temp_file).unwrap(); - drop(guard); - - // assert that textual representation of the created module is as expected - let expected_filename = std::path::Path::new(".") - .parent() - .unwrap() - .join("tests") - .join("test_module_creation_expected.mlir"); - let expected_module = fs::read_to_string(expected_filename).unwrap(); - assert!(printed_module.eq(expected_module.as_str())); + let printed_module = super::print_mlir_operation_to_string(mlirModuleGetOperation(module)); + let expected_module = "\ +module { + func.func @test(%arg0: i64, %arg1: i64) -> i64 { + %0 = arith.addi %arg0, %arg1 : i64 + return %0 : i64 + } +} +"; + assert_eq!( + printed_module, + expected_module, + "left: \n{}, right: \n{}", + printed_module, + expected_module); } } } diff --git a/compiler/lib/Bindings/Rust/tests/test_module_creation_expected.mlir b/compiler/lib/Bindings/Rust/tests/test_module_creation_expected.mlir deleted file mode 100644 index c8394bad6..000000000 --- a/compiler/lib/Bindings/Rust/tests/test_module_creation_expected.mlir +++ /dev/null @@ -1,6 +0,0 @@ -module { - func.func @test(%arg0: i64, %arg1: i64) -> i64 { - %0 = arith.addi %arg0, %arg1 : i64 - return %0 : i64 - } -}