mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(python): CompilerEngine to compile and run
This commit is contained in:
committed by
Quentin Bourgerie
parent
5613c69602
commit
f948db1228
@@ -14,6 +14,11 @@ add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
|
||||
CompilerAPIModule.cpp
|
||||
LINK_LIBS
|
||||
ZAMALANGCAPIHLFHE
|
||||
ZamalangSupport
|
||||
LowLFHEDialect
|
||||
MidLFHEDialect
|
||||
HLFHEDialect
|
||||
Concrete
|
||||
)
|
||||
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)
|
||||
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include <mlir/Parser.h>
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using namespace zamalang;
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
@@ -19,13 +30,39 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
auto mlir_module = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!mlir_module) {
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
auto module_ref = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!module_ref) {
|
||||
throw std::logic_error("mlir parsing failed");
|
||||
}
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
mlir_module->print(os);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
});
|
||||
}
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def("run",
|
||||
[](CompilerEngine &engine, std::vector<uint64_t> args) {
|
||||
auto result = engine.run(args);
|
||||
if (!result) {
|
||||
llvm::errs()
|
||||
<< "Execution failed: " << result.takeError() << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
return result.get();
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
auto result = engine.compileFHE(mlir_input);
|
||||
if (!result) {
|
||||
llvm::errs()
|
||||
<< "Compilation failed: " << result.takeError() << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed compiling, see previous logs for more info");
|
||||
}
|
||||
})
|
||||
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
|
||||
}
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
"""Zamalang python module"""
|
||||
from _zamalang import *
|
||||
import _zamalang._compiler as compiler
|
||||
from .compiler import CompilerEngine
|
||||
|
||||
64
compiler/python/zamalang/compiler.py
Normal file
64
compiler/python/zamalang/compiler.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Compiler submodule"""
|
||||
from typing import List
|
||||
from _zamalang._compiler import CompilerEngine as _CompilerEngine
|
||||
from _zamalang._compiler import round_trip as _round_trip
|
||||
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
"""Parse the MLIR input, then return it back.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR code to parse.
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
|
||||
Returns:
|
||||
str: parsed MLIR input.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = _CompilerEngine()
|
||||
if mlir_str is not None:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
|
||||
"""Compile the MLIR input and build a CompilerEngine.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR to compile.
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine used for execution.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
return self._engine.compile_fhe(mlir_str)
|
||||
|
||||
def run(self, *args: List[int]) -> int:
|
||||
"""Run the compiled code.
|
||||
|
||||
Raises:
|
||||
TypeError: if arguments aren't of type int
|
||||
|
||||
Returns:
|
||||
int: result of execution.
|
||||
"""
|
||||
if not all(isinstance(arg, int) for arg in args):
|
||||
raise TypeError("arguments must be of type int")
|
||||
return self._engine.run(args)
|
||||
|
||||
def get_compiled_module(self) -> str:
|
||||
"""Compiled module in printable form.
|
||||
|
||||
Returns:
|
||||
str: Compiled module in printable form.
|
||||
"""
|
||||
return self._engine.get_compiled_module()
|
||||
@@ -1 +1,2 @@
|
||||
# We need this helpers from the mlir bindings, they are used in the generated files
|
||||
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
"""HLFHE dialect module"""
|
||||
from ._HLFHE_ops_gen import *
|
||||
from _zamalang._hlfhe import *
|
||||
|
||||
Reference in New Issue
Block a user