mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(rust): print mlir string repr directly into a String
Instead of overriding the process stderr to get the string representation from mlir we can can directly capture in into a string using mlir's printOperation. Another problem with overriding stderr is that each `#[test]` runs as a different thread meaning that as soon as we have 2+ tests the tests could panic due to conflicts/races between the different overrides. This also moves the expected string directly into the test as a literal.
This commit is contained in:
@@ -6,6 +6,3 @@ edition = "2021"
|
||||
[build-dependencies]
|
||||
bindgen = "0.60.1"
|
||||
|
||||
[dev-dependencies]
|
||||
stdio-override = "0.1.3"
|
||||
tempfile = "3.3.0"
|
||||
|
||||
@@ -2,15 +2,39 @@
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::ops::AddAssign;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
pub(crate) unsafe extern "C" fn mlir_rust_string_receiver_callback(
|
||||
mlirStrRef: MlirStringRef,
|
||||
user_data: *mut ::std::os::raw::c_void,
|
||||
) {
|
||||
let rust_string = &mut *(user_data as *mut String);
|
||||
let slc = std::slice::from_raw_parts(mlirStrRef.data as *const u8, mlirStrRef.length as usize);
|
||||
rust_string.add_assign(&String::from_utf8_lossy(slc));
|
||||
}
|
||||
|
||||
pub fn print_mlir_operation_to_string(op: MlirOperation) -> String {
|
||||
let mut rust_string = String::default();
|
||||
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
|
||||
|
||||
unsafe {
|
||||
mlirOperationPrint(
|
||||
op,
|
||||
Some(mlir_rust_string_receiver_callback),
|
||||
receiver_ptr
|
||||
);
|
||||
}
|
||||
|
||||
rust_string
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod concrete_compiler_tests {
|
||||
use super::*;
|
||||
use std::ffi::CString;
|
||||
use std::fs;
|
||||
use stdio_override::StderrOverride;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_function_type() {
|
||||
@@ -151,22 +175,21 @@ mod concrete_compiler_tests {
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
|
||||
|
||||
// dump module to stderr and capture it in a temp file
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let temp_file = temp_dir.path().join("concrete_compiler_test_stderr");
|
||||
let guard = StderrOverride::override_file(&temp_file).unwrap();
|
||||
mlirOperationDump(mlirModuleGetOperation(module));
|
||||
let printed_module = fs::read_to_string(&temp_file).unwrap();
|
||||
drop(guard);
|
||||
|
||||
// assert that textual representation of the created module is as expected
|
||||
let expected_filename = std::path::Path::new(".")
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("tests")
|
||||
.join("test_module_creation_expected.mlir");
|
||||
let expected_module = fs::read_to_string(expected_filename).unwrap();
|
||||
assert!(printed_module.eq(expected_module.as_str()));
|
||||
let printed_module = super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
|
||||
let expected_module = "\
|
||||
module {
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%0 = arith.addi %arg0, %arg1 : i64
|
||||
return %0 : i64
|
||||
}
|
||||
}
|
||||
";
|
||||
assert_eq!(
|
||||
printed_module,
|
||||
expected_module,
|
||||
"left: \n{}, right: \n{}",
|
||||
printed_module,
|
||||
expected_module);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
module {
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%0 = arith.addi %arg0, %arg1 : i64
|
||||
return %0 : i64
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user