diff --git a/compiler/python/CMakeLists.txt b/compiler/python/CMakeLists.txt index eaf260c5e..3f6ae83fe 100644 --- a/compiler/python/CMakeLists.txt +++ b/compiler/python/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang SOURCES ZamalangModule.cpp HLFHEModule.cpp + CompilerAPIModule.cpp LINK_LIBS ZAMALANGCAPIHLFHE ) diff --git a/compiler/python/CompilerAPIModule.cpp b/compiler/python/CompilerAPIModule.cpp new file mode 100644 index 000000000..c294b0af2 --- /dev/null +++ b/compiler/python/CompilerAPIModule.cpp @@ -0,0 +1,31 @@ +#include "CompilerAPIModule.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" +#include +#include + +#include +#include +#include +#include + +using namespace zamalang; + +/// Populate the compiler API python module. +void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { + m.doc() = "Zamalang compiler python API"; + + m.def("round_trip", [](std::string mlir_input) { + mlir::MLIRContext context; + context.getOrLoadDialect(); + context.getOrLoadDialect(); + auto mlir_module = mlir::parseSourceString(mlir_input, &context); + if (!mlir_module) { + throw std::logic_error("mlir parsing failed"); + } + std::string result; + llvm::raw_string_ostream os(result); + mlir_module->print(os); + return os.str(); + }); +} \ No newline at end of file diff --git a/compiler/python/CompilerAPIModule.h b/compiler/python/CompilerAPIModule.h new file mode 100644 index 000000000..5989938d0 --- /dev/null +++ b/compiler/python/CompilerAPIModule.h @@ -0,0 +1,14 @@ +#ifndef ZAMALANG_PYTHON_COMPILER_API_MODULE_H +#define ZAMALANG_PYTHON_COMPILER_API_MODULE_H + +#include + +namespace zamalang { +namespace python { + +void populateCompilerAPISubmodule(pybind11::module &m); + +} // namespace python +} // namespace zamalang + +#endif // ZAMALANG_PYTHON_DIALECTMODULES_H \ No newline at end of file diff --git a/compiler/python/ZamalangModule.cpp b/compiler/python/ZamalangModule.cpp index e2755ef35..9d6437ff9 100644 --- a/compiler/python/ZamalangModule.cpp +++ b/compiler/python/ZamalangModule.cpp @@ -1,4 +1,5 @@ #include "DialectModules.h" +#include "CompilerAPIModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Registration.h" @@ -32,4 +33,7 @@ PYBIND11_MODULE(_zamalang, m) { py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API"); zamalang::python::populateDialectHLFHESubmodule(hlfhe); + + py::module api = m.def_submodule("_compiler", "Compiler API"); + zamalang::python::populateCompilerAPISubmodule(api); } \ No newline at end of file diff --git a/compiler/python/zamalang/__init__.py b/compiler/python/zamalang/__init__.py index 4e4982057..abb0d8426 100644 --- a/compiler/python/zamalang/__init__.py +++ b/compiler/python/zamalang/__init__.py @@ -1 +1,2 @@ from _zamalang import * +import _zamalang._compiler as compiler