feat(python): CompilerEngine to compile and run

This commit is contained in:
youben11
2021-08-13 15:53:14 +01:00
committed by Quentin Bourgerie
parent 5613c69602
commit f948db1228
12 changed files with 298 additions and 10 deletions

View File

@@ -38,4 +38,4 @@ jobs:
username: ${{ secrets.GHCR_LOGIN }}
password: ${{ secrets.GHCR_PASSWORD }}
options: -v ${{ github.workspace }}:/workspace -e PYTHONPATH=/llvm-project/build/python:/workspace/compiler/build/python
run: cd /workspace/compiler && mkdir build && cmake -B build . -DCONCRETE_FFI_RELEASE=/workspace/concrete/target/release -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ all zamacompiler && make test && pip install pytest && make test_python
run: cd /workspace/compiler && mkdir build && cmake -B build . -DCONCRETE_FFI_RELEASE=/workspace/concrete/target/release -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ all zamacompiler && make test && pip install pytest && LD_PRELOAD=/workspace/concrete/target/release/libconcrete_ffi.so make test_python

View File

@@ -20,11 +20,11 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: build
run: docker build -t $IMAGE -f builders/Dockerfile.zamalang-env .
- name: login
run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io
- name: build
run: docker build -t $IMAGE -f builders/Dockerfile.zamalang-env .
- name: tag and publish
run: |

View File

@@ -0,0 +1,49 @@
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <string>
namespace mlir {
namespace zamalang {
class CompilerEngine {
public:
CompilerEngine() {
context = new mlir::MLIRContext();
loadDialects();
}
~CompilerEngine() {
if (context != nullptr)
delete context;
}
// Compile an MLIR input
llvm::Expected<mlir::LogicalResult> compileFHE(std::string mlir_input);
// Run the compiled module
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
// Get a printable representation of the compiled module
std::string getCompiledModule();
private:
// Load the necessary dialects into the engine's context
void loadDialects();
mlir::OwningModuleRef module_ref;
mlir::MLIRContext *context;
std::unique_ptr<mlir::zamalang::KeySet> keySet;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1,5 +1,6 @@
add_mlir_library(ZamalangSupport
CompilerTools.cpp
CompilerEngine.cpp
V0Parameters.cpp
ClientParameters.cpp
KeySet.cpp

View File

@@ -0,0 +1,106 @@
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Conversion/Passes.h"
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
namespace mlir {
namespace zamalang {
void CompilerEngine::loadDialects() {
context->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
context->getOrLoadDialect<mlir::StandardOpsDialect>();
context->getOrLoadDialect<mlir::memref::MemRefDialect>();
context->getOrLoadDialect<mlir::linalg::LinalgDialect>();
context->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
}
std::string CompilerEngine::getCompiledModule() {
std::string compiledModule;
llvm::raw_string_ostream os(compiledModule);
module_ref->print(os);
return os.str();
}
llvm::Expected<mlir::LogicalResult>
CompilerEngine::compileFHE(std::string mlir_input) {
module_ref = mlir::parseSourceString(mlir_input, context);
if (!module_ref) {
return llvm::make_error<llvm::StringError>("mlir parsing failed",
llvm::inconvertibleErrorCode());
}
mlir::zamalang::FHECircuitConstraint constraint;
mlir::zamalang::V0Parameter v0Parameter;
// Lower to MLIR Std
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
*context, module_ref.get(), constraint, v0Parameter)
.failed()) {
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
llvm::inconvertibleErrorCode());
}
// Create the client parameters
auto clientParameter = mlir::zamalang::createClientParametersForV0(
v0Parameter, constraint.p, "main", module_ref.get());
if (auto err = clientParameter.takeError()) {
return llvm::make_error<llvm::StringError>(
"cannot generate client parameters", llvm::inconvertibleErrorCode());
}
auto maybeKeySet =
mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0);
if (auto err = maybeKeySet.takeError()) {
return llvm::make_error<llvm::StringError>("cannot generate keyset",
llvm::inconvertibleErrorCode());
}
keySet = std::move(maybeKeySet.get());
// Lower to MLIR LLVM Dialect
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
*context, module_ref.get())
.failed()) {
return llvm::make_error<llvm::StringError>(
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
}
return mlir::success();
}
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
// Create the JIT lambda
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
auto module = module_ref.get();
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
if (!maybeLambda) {
return llvm::make_error<llvm::StringError>("couldn't create lambda",
llvm::inconvertibleErrorCode());
}
auto lambda = std::move(maybeLambda.get());
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet);
if (auto err = maybeArguments.takeError()) {
return llvm::make_error<llvm::StringError>("cannot create lambda args",
llvm::inconvertibleErrorCode());
}
// Set the arguments
auto arguments = std::move(maybeArguments.get());
for (auto i = 0; i < args.size(); i++) {
if (auto err = arguments->setArg(i, args[i])) {
return llvm::make_error<llvm::StringError>(
"cannot push argument", llvm::inconvertibleErrorCode());
}
}
// Invoke the lambda
if (lambda->invoke(*arguments)) {
return llvm::make_error<llvm::StringError>("failed execution",
llvm::inconvertibleErrorCode());
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
return llvm::make_error<llvm::StringError>("cannot get result",
llvm::inconvertibleErrorCode());
}
return res;
}
} // namespace zamalang
} // namespace mlir

View File

@@ -14,6 +14,11 @@ add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
CompilerAPIModule.cpp
LINK_LIBS
ZAMALANGCAPIHLFHE
ZamalangSupport
LowLFHEDialect
MidLFHEDialect
HLFHEDialect
Concrete
)
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)

View File

@@ -1,15 +1,26 @@
#include "CompilerAPIModule.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include <mlir/Parser.h>
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerTools.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <stdexcept>
#include <string>
using namespace zamalang;
using mlir::zamalang::CompilerEngine;
/// Populate the compiler API python module.
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
@@ -19,13 +30,39 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
auto mlir_module = mlir::parseSourceString(mlir_input, &context);
if (!mlir_module) {
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
auto module_ref = mlir::parseSourceString(mlir_input, &context);
if (!module_ref) {
throw std::logic_error("mlir parsing failed");
}
std::string result;
llvm::raw_string_ostream os(result);
mlir_module->print(os);
module_ref->print(os);
return os.str();
});
}
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
.def(pybind11::init())
.def("run",
[](CompilerEngine &engine, std::vector<uint64_t> args) {
auto result = engine.run(args);
if (!result) {
llvm::errs()
<< "Execution failed: " << result.takeError() << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
return result.get();
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
auto result = engine.compileFHE(mlir_input);
if (!result) {
llvm::errs()
<< "Compilation failed: " << result.takeError() << "\n";
throw std::runtime_error(
"failed compiling, see previous logs for more info");
}
})
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
}

View File

@@ -1,2 +1,3 @@
"""Zamalang python module"""
from _zamalang import *
import _zamalang._compiler as compiler
from .compiler import CompilerEngine

View File

@@ -0,0 +1,64 @@
"""Compiler submodule"""
from typing import List
from _zamalang._compiler import CompilerEngine as _CompilerEngine
from _zamalang._compiler import round_trip as _round_trip
def round_trip(mlir_str: str) -> str:
"""Parse the MLIR input, then return it back.
Args:
mlir_str (str): MLIR code to parse.
Raises:
TypeError: if the argument is not an str.
Returns:
str: parsed MLIR input.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
return _round_trip(mlir_str)
class CompilerEngine:
def __init__(self, mlir_str: str = None):
self._engine = _CompilerEngine()
if mlir_str is not None:
self.compile_fhe(mlir_str)
def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
"""Compile the MLIR input and build a CompilerEngine.
Args:
mlir_str (str): MLIR to compile.
Raises:
TypeError: if the argument is not an str.
Returns:
CompilerEngine: engine used for execution.
"""
if not isinstance(mlir_str, str):
raise TypeError("input must be an `str`")
return self._engine.compile_fhe(mlir_str)
def run(self, *args: List[int]) -> int:
"""Run the compiled code.
Raises:
TypeError: if arguments aren't of type int
Returns:
int: result of execution.
"""
if not all(isinstance(arg, int) for arg in args):
raise TypeError("arguments must be of type int")
return self._engine.run(args)
def get_compiled_module(self) -> str:
"""Compiled module in printable form.
Returns:
str: Compiled module in printable form.
"""
return self._engine.get_compiled_module()

View File

@@ -1 +1,2 @@
# We need this helpers from the mlir bindings, they are used in the generated files
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context

View File

@@ -1,2 +1,3 @@
"""HLFHE dialect module"""
from ._HLFHE_ops_gen import *
from _zamalang._hlfhe import *

View File

@@ -0,0 +1,23 @@
import pytest
from zamalang import CompilerEngine
@pytest.mark.parametrize(
"mlir_input, args, expected_result",
[
(
"""
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
""",
(5, 7),
12,
),
],
)
def test_compile_and_run(mlir_input, args, expected_result):
engine = CompilerEngine()
engine.compile_fhe(mlir_input)
assert engine.run(*args) == expected_result