mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(python): CompilerEngine to compile and run
This commit is contained in:
committed by
Quentin Bourgerie
parent
5613c69602
commit
f948db1228
2
.github/workflows/conformance.yml
vendored
2
.github/workflows/conformance.yml
vendored
@@ -38,4 +38,4 @@ jobs:
|
||||
username: ${{ secrets.GHCR_LOGIN }}
|
||||
password: ${{ secrets.GHCR_PASSWORD }}
|
||||
options: -v ${{ github.workspace }}:/workspace -e PYTHONPATH=/llvm-project/build/python:/workspace/compiler/build/python
|
||||
run: cd /workspace/compiler && mkdir build && cmake -B build . -DCONCRETE_FFI_RELEASE=/workspace/concrete/target/release -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ all zamacompiler && make test && pip install pytest && make test_python
|
||||
run: cd /workspace/compiler && mkdir build && cmake -B build . -DCONCRETE_FFI_RELEASE=/workspace/concrete/target/release -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ all zamacompiler && make test && pip install pytest && LD_PRELOAD=/workspace/concrete/target/release/libconcrete_ffi.so make test_python
|
||||
|
||||
6
.github/workflows/docker-zamalang.yml
vendored
6
.github/workflows/docker-zamalang.yml
vendored
@@ -20,11 +20,11 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: build
|
||||
run: docker build -t $IMAGE -f builders/Dockerfile.zamalang-env .
|
||||
|
||||
- name: login
|
||||
run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io
|
||||
|
||||
- name: build
|
||||
run: docker build -t $IMAGE -f builders/Dockerfile.zamalang-env .
|
||||
|
||||
- name: tag and publish
|
||||
run: |
|
||||
|
||||
49
compiler/include/zamalang/Support/CompilerEngine.h
Normal file
49
compiler/include/zamalang/Support/CompilerEngine.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <string>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
class CompilerEngine {
|
||||
public:
|
||||
CompilerEngine() {
|
||||
context = new mlir::MLIRContext();
|
||||
loadDialects();
|
||||
}
|
||||
~CompilerEngine() {
|
||||
if (context != nullptr)
|
||||
delete context;
|
||||
}
|
||||
|
||||
// Compile an MLIR input
|
||||
llvm::Expected<mlir::LogicalResult> compileFHE(std::string mlir_input);
|
||||
|
||||
// Run the compiled module
|
||||
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
|
||||
|
||||
// Get a printable representation of the compiled module
|
||||
std::string getCompiledModule();
|
||||
|
||||
private:
|
||||
// Load the necessary dialects into the engine's context
|
||||
void loadDialects();
|
||||
|
||||
mlir::OwningModuleRef module_ref;
|
||||
mlir::MLIRContext *context;
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
};
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,6 @@
|
||||
add_mlir_library(ZamalangSupport
|
||||
CompilerTools.cpp
|
||||
CompilerEngine.cpp
|
||||
V0Parameters.cpp
|
||||
ClientParameters.cpp
|
||||
KeySet.cpp
|
||||
|
||||
106
compiler/lib/Support/CompilerEngine.cpp
Normal file
106
compiler/lib/Support/CompilerEngine.cpp
Normal file
@@ -0,0 +1,106 @@
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
void CompilerEngine::loadDialects() {
|
||||
context->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
context->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
context->getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context->getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
context->getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
context->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
std::string CompilerEngine::getCompiledModule() {
|
||||
std::string compiledModule;
|
||||
llvm::raw_string_ostream os(compiledModule);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::LogicalResult>
|
||||
CompilerEngine::compileFHE(std::string mlir_input) {
|
||||
module_ref = mlir::parseSourceString(mlir_input, context);
|
||||
if (!module_ref) {
|
||||
return llvm::make_error<llvm::StringError>("mlir parsing failed",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
mlir::zamalang::FHECircuitConstraint constraint;
|
||||
mlir::zamalang::V0Parameter v0Parameter;
|
||||
// Lower to MLIR Std
|
||||
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect(
|
||||
*context, module_ref.get(), constraint, v0Parameter)
|
||||
.failed()) {
|
||||
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
v0Parameter, constraint.p, "main", module_ref.get());
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot generate client parameters", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto maybeKeySet =
|
||||
mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0);
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
return llvm::make_error<llvm::StringError>("cannot generate keyset",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
keySet = std::move(maybeKeySet.get());
|
||||
|
||||
// Lower to MLIR LLVM Dialect
|
||||
if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect(
|
||||
*context, module_ref.get())
|
||||
.failed()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
|
||||
// Create the JIT lambda
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
auto module = module_ref.get();
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
if (!maybeLambda) {
|
||||
return llvm::make_error<llvm::StringError>("couldn't create lambda",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
return llvm::make_error<llvm::StringError>("cannot create lambda args",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (auto i = 0; i < args.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, args[i])) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"cannot push argument", llvm::inconvertibleErrorCode());
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (lambda->invoke(*arguments)) {
|
||||
return llvm::make_error<llvm::StringError>("failed execution",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
return llvm::make_error<llvm::StringError>("cannot get result",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
@@ -14,6 +14,11 @@ add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
|
||||
CompilerAPIModule.cpp
|
||||
LINK_LIBS
|
||||
ZAMALANGCAPIHLFHE
|
||||
ZamalangSupport
|
||||
LowLFHEDialect
|
||||
MidLFHEDialect
|
||||
HLFHEDialect
|
||||
Concrete
|
||||
)
|
||||
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)
|
||||
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include <mlir/Parser.h>
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/CompilerTools.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using namespace zamalang;
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
@@ -19,13 +30,39 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
auto mlir_module = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!mlir_module) {
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
auto module_ref = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!module_ref) {
|
||||
throw std::logic_error("mlir parsing failed");
|
||||
}
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
mlir_module->print(os);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
});
|
||||
}
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def("run",
|
||||
[](CompilerEngine &engine, std::vector<uint64_t> args) {
|
||||
auto result = engine.run(args);
|
||||
if (!result) {
|
||||
llvm::errs()
|
||||
<< "Execution failed: " << result.takeError() << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
return result.get();
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
auto result = engine.compileFHE(mlir_input);
|
||||
if (!result) {
|
||||
llvm::errs()
|
||||
<< "Compilation failed: " << result.takeError() << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed compiling, see previous logs for more info");
|
||||
}
|
||||
})
|
||||
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
|
||||
}
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
"""Zamalang python module"""
|
||||
from _zamalang import *
|
||||
import _zamalang._compiler as compiler
|
||||
from .compiler import CompilerEngine
|
||||
|
||||
64
compiler/python/zamalang/compiler.py
Normal file
64
compiler/python/zamalang/compiler.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Compiler submodule"""
|
||||
from typing import List
|
||||
from _zamalang._compiler import CompilerEngine as _CompilerEngine
|
||||
from _zamalang._compiler import round_trip as _round_trip
|
||||
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
"""Parse the MLIR input, then return it back.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR code to parse.
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
|
||||
Returns:
|
||||
str: parsed MLIR input.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
return _round_trip(mlir_str)
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = _CompilerEngine()
|
||||
if mlir_str is not None:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
|
||||
"""Compile the MLIR input and build a CompilerEngine.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR to compile.
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine used for execution.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
return self._engine.compile_fhe(mlir_str)
|
||||
|
||||
def run(self, *args: List[int]) -> int:
|
||||
"""Run the compiled code.
|
||||
|
||||
Raises:
|
||||
TypeError: if arguments aren't of type int
|
||||
|
||||
Returns:
|
||||
int: result of execution.
|
||||
"""
|
||||
if not all(isinstance(arg, int) for arg in args):
|
||||
raise TypeError("arguments must be of type int")
|
||||
return self._engine.run(args)
|
||||
|
||||
def get_compiled_module(self) -> str:
|
||||
"""Compiled module in printable form.
|
||||
|
||||
Returns:
|
||||
str: Compiled module in printable form.
|
||||
"""
|
||||
return self._engine.get_compiled_module()
|
||||
@@ -1 +1,2 @@
|
||||
# We need this helpers from the mlir bindings, they are used in the generated files
|
||||
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
"""HLFHE dialect module"""
|
||||
from ._HLFHE_ops_gen import *
|
||||
from _zamalang._hlfhe import *
|
||||
|
||||
23
compiler/tests/python/test_compiler_engine.py
Normal file
23
compiler/tests/python/test_compiler_engine.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
from zamalang import CompilerEngine
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result",
|
||||
[
|
||||
(
|
||||
"""
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(5, 7),
|
||||
12,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_compile_and_run(mlir_input, args, expected_result):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
assert engine.run(*args) == expected_result
|
||||
Reference in New Issue
Block a user