feat(python): CompilerEngine to compile and run

This commit is contained in:
youben11
2021-08-13 15:53:14 +01:00
committed by Quentin Bourgerie
parent 5613c69602
commit f948db1228
12 changed files with 298 additions and 10 deletions

View File

@@ -14,6 +14,11 @@ add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
CompilerAPIModule.cpp
LINK_LIBS
ZAMALANGCAPIHLFHE
ZamalangSupport
LowLFHEDialect
MidLFHEDialect
HLFHEDialect
Concrete
)
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)

View File

@@ -1,15 +1,26 @@
#include "CompilerAPIModule.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include <mlir/Parser.h>
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerTools.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <stdexcept>
#include <string>
using namespace zamalang;
using mlir::zamalang::CompilerEngine;
/// Populate the compiler API python module.
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
@@ -19,13 +30,39 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
auto mlir_module = mlir::parseSourceString(mlir_input, &context);
if (!mlir_module) {
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
auto module_ref = mlir::parseSourceString(mlir_input, &context);
if (!module_ref) {
throw std::logic_error("mlir parsing failed");
}
std::string result;
llvm::raw_string_ostream os(result);
mlir_module->print(os);
module_ref->print(os);
return os.str();
});
}
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
.def(pybind11::init())
.def("run",
[](CompilerEngine &engine, std::vector<uint64_t> args) {
auto result = engine.run(args);
if (!result) {
llvm::errs()
<< "Execution failed: " << result.takeError() << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
return result.get();
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
auto result = engine.compileFHE(mlir_input);
if (!result) {
llvm::errs()
<< "Compilation failed: " << result.takeError() << "\n";
throw std::runtime_error(
"failed compiling, see previous logs for more info");
}
})
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
}

View File

@@ -1,2 +1,3 @@
"""Zamalang python module"""
from _zamalang import *
import _zamalang._compiler as compiler
from .compiler import CompilerEngine

View File

@@ -0,0 +1,64 @@
"""Compiler submodule"""
from typing import List
from _zamalang._compiler import CompilerEngine as _CompilerEngine
from _zamalang._compiler import round_trip as _round_trip
def round_trip(mlir_str: str) -> str:
"""Parse the MLIR input, then return it back.
Args:
mlir_str (str): MLIR code to parse.
Raises:
TypeError: if the argument is not an str.
Returns:
str: parsed MLIR input.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
return _round_trip(mlir_str)
class CompilerEngine:
def __init__(self, mlir_str: str = None):
self._engine = _CompilerEngine()
if mlir_str is not None:
self.compile_fhe(mlir_str)
def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
"""Compile the MLIR input and build a CompilerEngine.
Args:
mlir_str (str): MLIR to compile.
Raises:
TypeError: if the argument is not an str.
Returns:
CompilerEngine: engine used for execution.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
return self._engine.compile_fhe(mlir_str)
def run(self, *args: List[int]) -> int:
"""Run the compiled code.
Raises:
TypeError: if arguments aren't of type int
Returns:
int: result of execution.
"""
if not all(isinstance(arg, int) for arg in args):
raise TypeError("arguments must be of type int")
return self._engine.run(args)
def get_compiled_module(self) -> str:
"""Compiled module in printable form.
Returns:
str: Compiled module in printable form.
"""
return self._engine.get_compiled_module()

View File

@@ -1 +1,2 @@
# We need this helpers from the mlir bindings, they are used in the generated files
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context

View File

@@ -1,2 +1,3 @@
"""HLFHE dialect module"""
from ._HLFHE_ops_gen import *
from _zamalang._hlfhe import *