mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add rust bindings
The rust bindings are intented to access both LLVM/MLIR CAPI as well as the concrete-compiler one. This initial commit provide the API for LLVM/MLIR only. Tests should be used as an example to how to generate a valid DAG of operations in MLIR.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -51,3 +51,5 @@ _build/
|
||||
|
||||
|
||||
compiler/tests/TestLib/out/
|
||||
compiler/lib/Bindings/Rust/target/
|
||||
compiler/lib/Bindings/Rust/Cargo.lock
|
||||
|
||||
@@ -144,6 +144,13 @@ python-bindings: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
|
||||
|
||||
# MLIRCAPIRegistration is currently required for linking rust bindings
|
||||
# This may fade if we can represent it somewhere else, mainly, we may be able to define a single
|
||||
# lib that the rust bindings need to link to, while that lib contains all necessary libs
|
||||
rust-bindings: build-initialized concretecompiler
|
||||
cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration
|
||||
cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo build --release
|
||||
|
||||
clientlib: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
|
||||
|
||||
@@ -198,6 +205,10 @@ run-python-tests: python-bindings concretecompiler
|
||||
test-compiler-file-output: concretecompiler
|
||||
pytest -vs tests/test_compiler_file_output
|
||||
|
||||
## rust-tests
|
||||
run-rust-tests: rust-bindings
|
||||
cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo test --release
|
||||
|
||||
## end-to-end-tests
|
||||
|
||||
build-end-to-end-tests: build-end-to-end-jit-test build-end-to-end-jit-fhe build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-fhelinalg build-end-to-end-jit-lambda
|
||||
|
||||
11
compiler/lib/Bindings/Rust/Cargo.toml
Normal file
11
compiler/lib/Bindings/Rust/Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "concrete_compiler_rust"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.60.1"
|
||||
|
||||
[dev-dependencies]
|
||||
stdio-override = "0.1.3"
|
||||
tempfile = "3.3.0"
|
||||
9
compiler/lib/Bindings/Rust/README.md
Normal file
9
compiler/lib/Bindings/Rust/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Rust Bindings
|
||||
|
||||
A Rust library providing an API to the Concrete Compiler.
|
||||
|
||||
### Build
|
||||
|
||||
```bash
|
||||
$ cargo build --release
|
||||
```
|
||||
32
compiler/lib/Bindings/Rust/api.h
Normal file
32
compiler/lib/Bindings/Rust/api.h
Normal file
@@ -0,0 +1,32 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir-c/AffineExpr.h>
|
||||
#include <mlir-c/AffineMap.h>
|
||||
#include <mlir-c/BuiltinAttributes.h>
|
||||
#include <mlir-c/BuiltinTypes.h>
|
||||
#include <mlir-c/Conversion.h>
|
||||
#include <mlir-c/Debug.h>
|
||||
#include <mlir-c/Diagnostics.h>
|
||||
#include <mlir-c/Dialect/Async.h>
|
||||
#include <mlir-c/Dialect/ControlFlow.h>
|
||||
#include <mlir-c/Dialect/Func.h>
|
||||
#include <mlir-c/Dialect/GPU.h>
|
||||
#include <mlir-c/Dialect/LLVM.h>
|
||||
#include <mlir-c/Dialect/Linalg.h>
|
||||
#include <mlir-c/Dialect/PDL.h>
|
||||
#include <mlir-c/Dialect/Quant.h>
|
||||
#include <mlir-c/Dialect/SCF.h>
|
||||
#include <mlir-c/Dialect/Shape.h>
|
||||
#include <mlir-c/Dialect/SparseTensor.h>
|
||||
#include <mlir-c/Dialect/Tensor.h>
|
||||
#include <mlir-c/ExecutionEngine.h>
|
||||
#include <mlir-c/IR.h>
|
||||
#include <mlir-c/IntegerSet.h>
|
||||
#include <mlir-c/Interfaces.h>
|
||||
#include <mlir-c/Pass.h>
|
||||
#include <mlir-c/Registration.h>
|
||||
#include <mlir-c/Support.h>
|
||||
#include <mlir-c/Transforms.h>
|
||||
273
compiler/lib/Bindings/Rust/build.rs
Normal file
273
compiler/lib/Bindings/Rust/build.rs
Normal file
@@ -0,0 +1,273 @@
|
||||
extern crate bindgen;
|
||||
|
||||
use std::env;
|
||||
use std::error::Error;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::exit;
|
||||
|
||||
const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRMemRefDialect",
|
||||
"MLIRVectorToSPIRV",
|
||||
"MLIRControlFlowInterfaces",
|
||||
"MLIRLinalgToStandard",
|
||||
"MLIRAnalysis",
|
||||
"MLIRSPIRVDeserialization",
|
||||
"MLIRTransformDialect",
|
||||
"MLIRSparseTensorPipelines",
|
||||
"MLIRVectorToGPU",
|
||||
"MLIRTranslateLib",
|
||||
"MLIRPass",
|
||||
"MLIRComplexToLibm",
|
||||
"MLIRInferTypeOpInterface",
|
||||
"MLIRMemRefToSPIRV",
|
||||
"MLIRAMDGPUToROCDL",
|
||||
"MLIRBufferizationTransformOps",
|
||||
"MLIRExecutionEngineUtils",
|
||||
"MLIRNVVMDialect",
|
||||
"MLIRSCFUtils",
|
||||
"MLIRLinalgTransforms",
|
||||
"MLIRParser",
|
||||
"MLIRFuncTransforms",
|
||||
"MLIRTosaTestPasses",
|
||||
"MLIRTosaToArith",
|
||||
"MLIRTensorDialect",
|
||||
"MLIRGPUTransforms",
|
||||
"MLIRLowerableDialectsToLLVM",
|
||||
"MLIRBufferizationToMemRef",
|
||||
"MLIRPresburger",
|
||||
"MLIRFuncDialect",
|
||||
"MLIRPDLToPDLInterp",
|
||||
"MLIRArithmeticTransforms",
|
||||
"MLIRViewLikeInterface",
|
||||
"MLIRTargetCpp",
|
||||
"MLIROpenMPToLLVM",
|
||||
"MLIRSPIRVConversion",
|
||||
"MLIRNVGPUTransforms",
|
||||
"MLIRSparseTensorTransforms",
|
||||
"MLIRAffineAnalysis",
|
||||
"MLIRArmSVETransforms",
|
||||
"MLIRArmNeon2dToIntr",
|
||||
"MLIRDataLayoutInterfaces",
|
||||
"MLIRAffineTransforms",
|
||||
"MLIROpenACCToLLVMIRTranslation",
|
||||
"MLIRTensorUtils",
|
||||
"MLIRSPIRVSerialization",
|
||||
"MLIRShapeToStandard",
|
||||
"MLIRArithmeticToSPIRV",
|
||||
"MLIRArithmeticDialect",
|
||||
"MLIRFuncToSPIRV",
|
||||
"MLIRQuantUtils",
|
||||
"MLIRTensorTilingInterfaceImpl",
|
||||
"MLIRX86VectorToLLVMIRTranslation",
|
||||
"MLIRCopyOpInterface",
|
||||
"MLIRMathToLibm",
|
||||
"MLIRGPUToGPURuntimeTransforms",
|
||||
"MLIRLLVMDialect",
|
||||
"MLIRAffineDialect",
|
||||
"MLIRTransforms",
|
||||
"MLIRVectorTransforms",
|
||||
"MLIROpenMPDialect",
|
||||
"MLIRControlFlowDialect",
|
||||
"MLIRVectorUtils",
|
||||
"MLIRROCDLDialect",
|
||||
"MLIRPDLDialect",
|
||||
"MLIRAsyncDialect",
|
||||
"MLIRLinalgToLLVM",
|
||||
"MLIROpenACCDialect",
|
||||
"MLIRVectorDialect",
|
||||
"MLIROpenACCToSCF",
|
||||
"MLIRIR",
|
||||
"MLIRCAPIIR",
|
||||
"MLIRTargetLLVMIRImport",
|
||||
"MLIRTensorToLinalg",
|
||||
"MLIRCallInterfaces",
|
||||
"MLIRTensorInferTypeOpInterfaceImpl",
|
||||
"MLIRTransformDialectTransforms",
|
||||
"MLIRComplexDialect",
|
||||
"MLIRAffineUtils",
|
||||
"MLIRLoopLikeInterface",
|
||||
"MLIRDialect",
|
||||
"MLIRLinalgUtils",
|
||||
"MLIRSCFToSPIRV",
|
||||
"MLIRAffineToStandard",
|
||||
"MLIRX86VectorDialect",
|
||||
"MLIRGPUToVulkanTransforms",
|
||||
"MLIRRewrite",
|
||||
"MLIRAMXToLLVMIRTranslation",
|
||||
"MLIRInferIntRangeInterface",
|
||||
"MLIRCAPIRegistration",
|
||||
"MLIRNVVMToLLVMIRTranslation",
|
||||
"MLIRAsyncTransforms",
|
||||
"MLIRPDLInterpDialect",
|
||||
"MLIRTransformUtils",
|
||||
"MLIRLinalgDialect",
|
||||
"MLIRMathDialect",
|
||||
"MLIRMemRefTransforms",
|
||||
"MLIRSPIRVModuleCombiner",
|
||||
"MLIRMathToLLVM",
|
||||
"MLIRControlFlowToLLVM",
|
||||
"MLIRArmSVEDialect",
|
||||
"MLIRSPIRVTranslateRegistration",
|
||||
"MLIRToLLVMIRTranslationRegistration",
|
||||
"MLIRSCFDialect",
|
||||
"MLIRTilingInterface",
|
||||
"MLIREmitCDialect",
|
||||
"MLIRTableGen",
|
||||
"MLIRTosaToSCF",
|
||||
"MLIROpenMPToLLVMIRTranslation",
|
||||
"MLIRSupport",
|
||||
"MLIROpenACCToLLVM",
|
||||
"MLIRAMDGPUDialect",
|
||||
"MLIRTosaToLinalg",
|
||||
"MLIRSparseTensorUtils",
|
||||
"MLIRFuncToLLVM",
|
||||
"MLIRTargetLLVMIRExport",
|
||||
"MLIRControlFlowToSPIRV",
|
||||
"MLIRReconcileUnrealizedCasts",
|
||||
"MLIRComplexToStandard",
|
||||
"MLIRMathTransforms",
|
||||
"MLIRSPIRVUtils",
|
||||
"MLIRCastInterfaces",
|
||||
"MLIRTosaToTensor",
|
||||
"MLIRMemRefUtils",
|
||||
"MLIRGPUToSPIRV",
|
||||
"MLIRBufferizationDialect",
|
||||
"MLIRSCFToControlFlow",
|
||||
"MLIRArmSVEToLLVMIRTranslation",
|
||||
"MLIRExecutionEngine",
|
||||
"MLIRBufferizationTransforms",
|
||||
"MLIRSparseTensorDialect",
|
||||
"MLIRTensorToSPIRV",
|
||||
"MLIRVectorToSCF",
|
||||
"MLIRQuantTransforms",
|
||||
"MLIRLLVMToLLVMIRTranslation",
|
||||
"MLIRNVGPUDialect",
|
||||
"MLIRAsyncToLLVM",
|
||||
"MLIRAMXDialect",
|
||||
"MLIRLinalgTransformOps",
|
||||
"MLIRMathToSPIRV",
|
||||
"MLIRSCFToOpenMP",
|
||||
"MLIRShapeDialect",
|
||||
"MLIRGPUToROCDLTransforms",
|
||||
"MLIRGPUToNVVMTransforms",
|
||||
"MLIRTensorTransforms",
|
||||
"MLIRSCFToGPU",
|
||||
"MLIRDialectUtils",
|
||||
"MLIRNVGPUToNVVM",
|
||||
"MLIRTosaDialect",
|
||||
"MLIRVectorToLLVM",
|
||||
"MLIRSPIRVDialect",
|
||||
"MLIRSideEffectInterfaces",
|
||||
"MLIRVectorToROCDL",
|
||||
"MLIRQuantDialect",
|
||||
"MLIRSCFTransforms",
|
||||
"MLIRMLProgramDialect",
|
||||
"MLIRLinalgToSPIRV",
|
||||
"MLIRDLTIDialect",
|
||||
"MLIRLinalgFrontend",
|
||||
"MLIRROCDLToLLVMIRTranslation",
|
||||
"MLIRArmNeonDialect",
|
||||
"MLIRSPIRVToLLVM",
|
||||
"MLIRLLVMIRTransforms",
|
||||
"MLIRTosaTransforms",
|
||||
"MLIRLLVMCommonConversion",
|
||||
"MLIRSCFTransformOps",
|
||||
"MLIRArmNeonToLLVMIRTranslation",
|
||||
"MLIRAMXTransforms",
|
||||
"MLIRSPIRVTransforms",
|
||||
"MLIRMemRefToLLVM",
|
||||
"MLIRSPIRVBinaryUtils",
|
||||
"MLIRLinalgAnalysis",
|
||||
"MLIRArithmeticUtils",
|
||||
"MLIRVectorInterfaces",
|
||||
"MLIRGPUOps",
|
||||
"MLIRComplexToLLVM",
|
||||
"MLIRShapeOpsTransforms",
|
||||
"MLIRX86VectorTransforms",
|
||||
"MLIRArithmeticToLLVM",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
if let Err(error) = run() {
|
||||
eprintln!("{}", error);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn run() -> Result<(), Box<dyn Error>> {
|
||||
// library paths
|
||||
let build_dir = get_build_dir();
|
||||
let lib_dir = build_dir.join("lib");
|
||||
println!("cargo:rustc-link-search={}", lib_dir.to_str().unwrap());
|
||||
|
||||
// include paths
|
||||
let root = std::fs::canonicalize("../../../../")?;
|
||||
let include_paths = [
|
||||
// compiler build
|
||||
build_dir.join("tools/concretelang/include/"),
|
||||
// mlir build
|
||||
build_dir.join("tools/mlir/include"),
|
||||
// llvm build
|
||||
build_dir.join("include"),
|
||||
// compiler
|
||||
root.join("compiler/include/"),
|
||||
// mlir
|
||||
root.join("llvm-project/mlir/include/"),
|
||||
// llvm
|
||||
root.join("llvm-project/llvm/include/"),
|
||||
// concrete-optimizer
|
||||
root.join("compiler/concrete-optimizer/concrete-optimizer-cpp/src/cpp/"),
|
||||
];
|
||||
|
||||
// linking
|
||||
for mlir_static_lib in MLIR_STATIC_LIBS {
|
||||
println!("cargo:rustc-link-lib=static={}", mlir_static_lib);
|
||||
}
|
||||
println!("cargo:rustc-link-lib=static=LLVMSupport");
|
||||
println!("cargo:rustc-link-lib=static=LLVMCore");
|
||||
// required by llvm
|
||||
println!("cargo:rustc-link-lib=tinfo");
|
||||
if let Some(name) = get_system_libcpp() {
|
||||
println!("cargo:rustc-link-lib={}", name);
|
||||
}
|
||||
|
||||
println!("cargo:rerun-if-changed=api.h");
|
||||
bindgen::builder()
|
||||
.header("api.h")
|
||||
.clang_args(
|
||||
include_paths
|
||||
.into_iter()
|
||||
.map(|path| format!("-I{}", path.to_str().unwrap())),
|
||||
)
|
||||
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
|
||||
.generate()
|
||||
.unwrap()
|
||||
.write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_system_libcpp() -> Option<&'static str> {
|
||||
if cfg!(target_env = "msvc") {
|
||||
None
|
||||
} else if cfg!(target_os = "macos") {
|
||||
Some("c++")
|
||||
} else {
|
||||
Some("stdc++")
|
||||
}
|
||||
}
|
||||
|
||||
fn get_build_dir() -> PathBuf {
|
||||
// this env variable can be used to point to a different build directory
|
||||
let build_dir = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
|
||||
Ok(val) => std::path::Path::new(&val).to_path_buf(),
|
||||
Err(_e) => std::path::Path::new(".")
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("build"),
|
||||
};
|
||||
return build_dir;
|
||||
}
|
||||
172
compiler/lib/Bindings/Rust/src/lib.rs
Normal file
172
compiler/lib/Bindings/Rust/src/lib.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
|
||||
#[cfg(test)]
|
||||
mod concrete_compiler_tests {
|
||||
use super::*;
|
||||
use std::ffi::CString;
|
||||
use std::fs;
|
||||
use stdio_override::StderrOverride;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_function_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
let func_type = mlirFunctionTypeGet(context, 0, std::ptr::null(), 0, std::ptr::null());
|
||||
assert!(mlirTypeIsAFunction(func_type));
|
||||
mlirContextDestroy(context);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_module_parsing() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
let module_string = "
|
||||
module{
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%1 = arith.addi %arg0, %arg1 : i64
|
||||
return %1: i64
|
||||
}
|
||||
}";
|
||||
let module_cstring = CString::new(module_string).unwrap();
|
||||
let module_reference = mlirStringRefCreateFromCString(module_cstring.as_ptr());
|
||||
|
||||
let parsed_module = mlirModuleCreateParse(context, module_reference);
|
||||
let parsed_func = mlirBlockGetFirstOperation(mlirModuleGetBody(parsed_module));
|
||||
|
||||
let func_type_str = CString::new("function_type").unwrap();
|
||||
// just check that we do have a function here, which should be enough to know that parsing worked well
|
||||
assert!(mlirTypeIsAFunction(mlirTypeAttrGetValue(
|
||||
mlirOperationGetAttributeByName(
|
||||
parsed_func,
|
||||
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
|
||||
)
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_module_creation() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
|
||||
// input/output types
|
||||
let func_input_types = [
|
||||
mlirIntegerTypeGet(context, 64),
|
||||
mlirIntegerTypeGet(context, 64),
|
||||
];
|
||||
let func_output_type = [mlirIntegerTypeGet(context, 64)];
|
||||
|
||||
// create the main block of the function
|
||||
let func_block = mlirBlockCreate(2, func_input_types.as_ptr(), &location);
|
||||
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
// create addi operation and append it to the block
|
||||
let addi_str = CString::new("arith.addi").unwrap();
|
||||
let mut addi_op_state =
|
||||
mlirOperationStateGet(mlirStringRefCreateFromCString(addi_str.as_ptr()), location);
|
||||
mlirOperationStateAddOperands(
|
||||
&mut addi_op_state,
|
||||
2,
|
||||
[
|
||||
mlirBlockGetArgument(func_block, 0),
|
||||
mlirBlockGetArgument(func_block, 1),
|
||||
]
|
||||
.as_ptr(),
|
||||
);
|
||||
mlirOperationStateEnableResultTypeInference(&mut addi_op_state);
|
||||
let addi_op = mlirOperationCreate(&mut addi_op_state);
|
||||
mlirBlockAppendOwnedOperation(func_block, addi_op);
|
||||
|
||||
// create return operation and append it to the block
|
||||
let ret_str = CString::new("func.return").unwrap();
|
||||
let mut ret_op_state =
|
||||
mlirOperationStateGet(mlirStringRefCreateFromCString(ret_str.as_ptr()), location);
|
||||
mlirOperationStateAddOperands(
|
||||
&mut ret_op_state,
|
||||
1,
|
||||
[mlirOperationGetResult(addi_op, 0)].as_ptr(),
|
||||
);
|
||||
let ret_op = mlirOperationCreate(&mut ret_op_state);
|
||||
mlirBlockAppendOwnedOperation(func_block, ret_op);
|
||||
|
||||
// create region to hold the previously created block
|
||||
let func_region = mlirRegionCreate();
|
||||
mlirRegionAppendOwnedBlock(func_region, func_block);
|
||||
|
||||
// create function to hold the previously created region
|
||||
let func_str = CString::new("func.func").unwrap();
|
||||
let mut func_op_state =
|
||||
mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location);
|
||||
mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr());
|
||||
// set function attributes
|
||||
let func_type_str = CString::new("function_type").unwrap();
|
||||
let sym_name_str = CString::new("sym_name").unwrap();
|
||||
let func_name_str = CString::new("test").unwrap();
|
||||
let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||
context,
|
||||
2,
|
||||
func_input_types.as_ptr(),
|
||||
1,
|
||||
func_output_type.as_ptr(),
|
||||
));
|
||||
let sym_name_attr = mlirStringAttrGet(
|
||||
context,
|
||||
mlirStringRefCreateFromCString(func_name_str.as_ptr()),
|
||||
);
|
||||
mlirOperationStateAddAttributes(
|
||||
&mut func_op_state,
|
||||
2,
|
||||
[
|
||||
// func type
|
||||
mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(
|
||||
context,
|
||||
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
|
||||
),
|
||||
func_type_attr,
|
||||
),
|
||||
// func name
|
||||
mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(
|
||||
context,
|
||||
mlirStringRefCreateFromCString(sym_name_str.as_ptr()),
|
||||
),
|
||||
sym_name_attr,
|
||||
),
|
||||
]
|
||||
.as_ptr(),
|
||||
);
|
||||
let func_op = mlirOperationCreate(&mut func_op_state);
|
||||
|
||||
// create module to hold the previously created function
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
|
||||
|
||||
// dump module to stderr and capture it in a temp file
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let temp_file = temp_dir.path().join("concrete_compiler_test_stderr");
|
||||
let guard = StderrOverride::override_file(&temp_file).unwrap();
|
||||
mlirOperationDump(mlirModuleGetOperation(module));
|
||||
let printed_module = fs::read_to_string(&temp_file).unwrap();
|
||||
drop(guard);
|
||||
|
||||
// assert that textual representation of the created module is as expected
|
||||
let expected_filename = std::path::Path::new(".")
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("tests")
|
||||
.join("test_module_creation_expected.mlir");
|
||||
let expected_module = fs::read_to_string(expected_filename).unwrap();
|
||||
assert!(printed_module.eq(expected_module.as_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
module {
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%0 = arith.addi %arg0, %arg1 : i64
|
||||
return %0 : i64
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user