feat(rust): print mlir string repr directly into a String

Instead of overriding the process stderr to get
the string representation from mlir we can can
directly capture in into a string using mlir's
printOperation.

Another problem with overriding stderr is that
each `#[test]` runs as a different thread meaning that
as soon as we have 2+ tests the tests could panic
due to conflicts/races between the different overrides.

This also moves the expected string directly into the test
as a literal.
This commit is contained in:
tmontaigu
2022-11-02 19:02:12 +01:00
parent d1db4a5e45
commit 6b0f6e9f10
3 changed files with 42 additions and 28 deletions

View File

@@ -6,6 +6,3 @@ edition = "2021"
[build-dependencies]
bindgen = "0.60.1"
[dev-dependencies]
stdio-override = "0.1.3"
tempfile = "3.3.0"

View File

@@ -2,15 +2,39 @@
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use std::ops::AddAssign;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
pub(crate) unsafe extern "C" fn mlir_rust_string_receiver_callback(
mlirStrRef: MlirStringRef,
user_data: *mut ::std::os::raw::c_void,
) {
let rust_string = &mut *(user_data as *mut String);
let slc = std::slice::from_raw_parts(mlirStrRef.data as *const u8, mlirStrRef.length as usize);
rust_string.add_assign(&String::from_utf8_lossy(slc));
}
pub fn print_mlir_operation_to_string(op: MlirOperation) -> String {
let mut rust_string = String::default();
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
unsafe {
mlirOperationPrint(
op,
Some(mlir_rust_string_receiver_callback),
receiver_ptr
);
}
rust_string
}
#[cfg(test)]
mod concrete_compiler_tests {
use super::*;
use std::ffi::CString;
use std::fs;
use stdio_override::StderrOverride;
use tempfile::tempdir;
#[test]
fn test_function_type() {
@@ -151,22 +175,21 @@ mod concrete_compiler_tests {
let module = mlirModuleCreateEmpty(location);
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
// dump module to stderr and capture it in a temp file
let temp_dir = tempdir().unwrap();
let temp_file = temp_dir.path().join("concrete_compiler_test_stderr");
let guard = StderrOverride::override_file(&temp_file).unwrap();
mlirOperationDump(mlirModuleGetOperation(module));
let printed_module = fs::read_to_string(&temp_file).unwrap();
drop(guard);
// assert that textual representation of the created module is as expected
let expected_filename = std::path::Path::new(".")
.parent()
.unwrap()
.join("tests")
.join("test_module_creation_expected.mlir");
let expected_module = fs::read_to_string(expected_filename).unwrap();
assert!(printed_module.eq(expected_module.as_str()));
let printed_module = super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
let expected_module = "\
module {
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.addi %arg0, %arg1 : i64
return %0 : i64
}
}
";
assert_eq!(
printed_module,
expected_module,
"left: \n{}, right: \n{}",
printed_module,
expected_module);
}
}
}

View File

@@ -1,6 +0,0 @@
module {
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.addi %arg0, %arg1 : i64
return %0 : i64
}
}