From f948db122831bce5005de9aacb139fcb11062bb9 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 13 Aug 2021 15:53:14 +0100 Subject: [PATCH] feat(python): CompilerEngine to compile and run --- .github/workflows/conformance.yml | 2 +- .github/workflows/docker-zamalang.yml | 6 +- .../include/zamalang/Support/CompilerEngine.h | 49 ++++++++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/CompilerEngine.cpp | 106 ++++++++++++++++++ compiler/python/CMakeLists.txt | 5 + compiler/python/CompilerAPIModule.cpp | 47 +++++++- compiler/python/zamalang/__init__.py | 3 +- compiler/python/zamalang/compiler.py | 64 +++++++++++ .../python/zamalang/dialects/_ods_common.py | 1 + compiler/python/zamalang/dialects/hlfhe.py | 1 + compiler/tests/python/test_compiler_engine.py | 23 ++++ 12 files changed, 298 insertions(+), 10 deletions(-) create mode 100644 compiler/include/zamalang/Support/CompilerEngine.h create mode 100644 compiler/lib/Support/CompilerEngine.cpp create mode 100644 compiler/python/zamalang/compiler.py create mode 100644 compiler/tests/python/test_compiler_engine.py diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 2f47d872f..9c5fba914 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -38,4 +38,4 @@ jobs: username: ${{ secrets.GHCR_LOGIN }} password: ${{ secrets.GHCR_PASSWORD }} options: -v ${{ github.workspace }}:/workspace -e PYTHONPATH=/llvm-project/build/python:/workspace/compiler/build/python - run: cd /workspace/compiler && mkdir build && cmake -B build . -DCONCRETE_FFI_RELEASE=/workspace/concrete/target/release -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ all zamacompiler && make test && pip install pytest && make test_python + run: cd /workspace/compiler && mkdir build && cmake -B build . -DCONCRETE_FFI_RELEASE=/workspace/concrete/target/release -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ all zamacompiler && make test && pip install pytest && LD_PRELOAD=/workspace/concrete/target/release/libconcrete_ffi.so make test_python diff --git a/.github/workflows/docker-zamalang.yml b/.github/workflows/docker-zamalang.yml index 851449b66..b249ac47d 100644 --- a/.github/workflows/docker-zamalang.yml +++ b/.github/workflows/docker-zamalang.yml @@ -20,11 +20,11 @@ jobs: steps: - uses: actions/checkout@v2 - - name: build - run: docker build -t $IMAGE -f builders/Dockerfile.zamalang-env . - - name: login run: echo "${{ secrets.GHCR_PASSWORD }}" | docker login -u ${{ secrets.GHCR_LOGIN }} --password-stdin ghcr.io + + - name: build + run: docker build -t $IMAGE -f builders/Dockerfile.zamalang-env . - name: tag and publish run: | diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h new file mode 100644 index 000000000..5bdf7171d --- /dev/null +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -0,0 +1,49 @@ +#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H +#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H + +#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" +#include "zamalang/Support/CompilerTools.h" +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { +class CompilerEngine { +public: + CompilerEngine() { + context = new mlir::MLIRContext(); + loadDialects(); + } + ~CompilerEngine() { + if (context != nullptr) + delete context; + } + + // Compile an MLIR input + llvm::Expected compileFHE(std::string mlir_input); + + // Run the compiled module + llvm::Expected run(std::vector args); + + // Get a printable representation of the compiled module + std::string getCompiledModule(); + +private: + // Load the necessary dialects into the engine's context + void loadDialects(); + + mlir::OwningModuleRef module_ref; + mlir::MLIRContext *context; + std::unique_ptr keySet; +}; +} // namespace zamalang +} // namespace mlir + +#endif \ No newline at end of file diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 0f677b3dc..42b8e2ee9 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(ZamalangSupport CompilerTools.cpp + CompilerEngine.cpp V0Parameters.cpp ClientParameters.cpp KeySet.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp new file mode 100644 index 000000000..6e708fabb --- /dev/null +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -0,0 +1,106 @@ +#include "zamalang/Support/CompilerEngine.h" +#include "zamalang/Conversion/Passes.h" +#include +#include + +namespace mlir { +namespace zamalang { + +void CompilerEngine::loadDialects() { + context->getOrLoadDialect(); + context->getOrLoadDialect(); + context->getOrLoadDialect(); + context->getOrLoadDialect(); + context->getOrLoadDialect(); + context->getOrLoadDialect(); + context->getOrLoadDialect(); +} + +std::string CompilerEngine::getCompiledModule() { + std::string compiledModule; + llvm::raw_string_ostream os(compiledModule); + module_ref->print(os); + return os.str(); +} + +llvm::Expected +CompilerEngine::compileFHE(std::string mlir_input) { + module_ref = mlir::parseSourceString(mlir_input, context); + if (!module_ref) { + return llvm::make_error("mlir parsing failed", + llvm::inconvertibleErrorCode()); + } + mlir::zamalang::FHECircuitConstraint constraint; + mlir::zamalang::V0Parameter v0Parameter; + // Lower to MLIR Std + if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( + *context, module_ref.get(), constraint, v0Parameter) + .failed()) { + return llvm::make_error("failed to lower to MLIR Std", + llvm::inconvertibleErrorCode()); + } + // Create the client parameters + auto clientParameter = mlir::zamalang::createClientParametersForV0( + v0Parameter, constraint.p, "main", module_ref.get()); + if (auto err = clientParameter.takeError()) { + return llvm::make_error( + "cannot generate client parameters", llvm::inconvertibleErrorCode()); + } + auto maybeKeySet = + mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0); + if (auto err = maybeKeySet.takeError()) { + return llvm::make_error("cannot generate keyset", + llvm::inconvertibleErrorCode()); + } + keySet = std::move(maybeKeySet.get()); + + // Lower to MLIR LLVM Dialect + if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( + *context, module_ref.get()) + .failed()) { + return llvm::make_error( + "failed to lower to LLVM dialect", llvm::inconvertibleErrorCode()); + } + return mlir::success(); +} + +llvm::Expected CompilerEngine::run(std::vector args) { + // Create the JIT lambda + auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); + auto module = module_ref.get(); + auto maybeLambda = + mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline); + if (!maybeLambda) { + return llvm::make_error("couldn't create lambda", + llvm::inconvertibleErrorCode()); + } + auto lambda = std::move(maybeLambda.get()); + + // Create the arguments of the JIT lambda + auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(*keySet); + if (auto err = maybeArguments.takeError()) { + return llvm::make_error("cannot create lambda args", + llvm::inconvertibleErrorCode()); + } + // Set the arguments + auto arguments = std::move(maybeArguments.get()); + for (auto i = 0; i < args.size(); i++) { + if (auto err = arguments->setArg(i, args[i])) { + return llvm::make_error( + "cannot push argument", llvm::inconvertibleErrorCode()); + } + } + // Invoke the lambda + if (lambda->invoke(*arguments)) { + return llvm::make_error("failed execution", + llvm::inconvertibleErrorCode()); + } + uint64_t res = 0; + if (auto err = arguments->getResult(0, res)) { + return llvm::make_error("cannot get result", + llvm::inconvertibleErrorCode()); + } + return res; +} +} // namespace zamalang +} // namespace mlir \ No newline at end of file diff --git a/compiler/python/CMakeLists.txt b/compiler/python/CMakeLists.txt index 3f6ae83fe..ab7e6b263 100644 --- a/compiler/python/CMakeLists.txt +++ b/compiler/python/CMakeLists.txt @@ -14,6 +14,11 @@ add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang CompilerAPIModule.cpp LINK_LIBS ZAMALANGCAPIHLFHE + ZamalangSupport + LowLFHEDialect + MidLFHEDialect + HLFHEDialect + Concrete ) add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension) diff --git a/compiler/python/CompilerAPIModule.cpp b/compiler/python/CompilerAPIModule.cpp index c294b0af2..a39e98346 100644 --- a/compiler/python/CompilerAPIModule.cpp +++ b/compiler/python/CompilerAPIModule.cpp @@ -1,15 +1,26 @@ #include "CompilerAPIModule.h" +#include "zamalang/Conversion/Passes.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" -#include +#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" +#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" +#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" +#include "zamalang/Support/CompilerEngine.h" +#include "zamalang/Support/CompilerTools.h" +#include #include +#include +#include #include #include #include +#include #include using namespace zamalang; +using mlir::zamalang::CompilerEngine; /// Populate the compiler API python module. void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { @@ -19,13 +30,39 @@ void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { mlir::MLIRContext context; context.getOrLoadDialect(); context.getOrLoadDialect(); - auto mlir_module = mlir::parseSourceString(mlir_input, &context); - if (!mlir_module) { + context.getOrLoadDialect(); + auto module_ref = mlir::parseSourceString(mlir_input, &context); + if (!module_ref) { throw std::logic_error("mlir parsing failed"); } std::string result; llvm::raw_string_ostream os(result); - mlir_module->print(os); + module_ref->print(os); return os.str(); }); -} \ No newline at end of file + + pybind11::class_(m, "CompilerEngine") + .def(pybind11::init()) + .def("run", + [](CompilerEngine &engine, std::vector args) { + auto result = engine.run(args); + if (!result) { + llvm::errs() + << "Execution failed: " << result.takeError() << "\n"; + throw std::runtime_error( + "failed running, see previous logs for more info"); + } + return result.get(); + }) + .def("compile_fhe", + [](CompilerEngine &engine, std::string mlir_input) { + auto result = engine.compileFHE(mlir_input); + if (!result) { + llvm::errs() + << "Compilation failed: " << result.takeError() << "\n"; + throw std::runtime_error( + "failed compiling, see previous logs for more info"); + } + }) + .def("get_compiled_module", &CompilerEngine::getCompiledModule); +} diff --git a/compiler/python/zamalang/__init__.py b/compiler/python/zamalang/__init__.py index abb0d8426..647b40273 100644 --- a/compiler/python/zamalang/__init__.py +++ b/compiler/python/zamalang/__init__.py @@ -1,2 +1,3 @@ +"""Zamalang python module""" from _zamalang import * -import _zamalang._compiler as compiler +from .compiler import CompilerEngine diff --git a/compiler/python/zamalang/compiler.py b/compiler/python/zamalang/compiler.py new file mode 100644 index 000000000..5e6e9ab2e --- /dev/null +++ b/compiler/python/zamalang/compiler.py @@ -0,0 +1,64 @@ +"""Compiler submodule""" +from typing import List +from _zamalang._compiler import CompilerEngine as _CompilerEngine +from _zamalang._compiler import round_trip as _round_trip + + +def round_trip(mlir_str: str) -> str: + """Parse the MLIR input, then return it back. + + Args: + mlir_str (str): MLIR code to parse. + + Raises: + TypeError: if the argument is not an str. + + Returns: + str: parsed MLIR input. + """ + if not isinstance(mlir_str, str): + raise TypeError("input must be an `str`") + return _round_trip(mlir_str) + +class CompilerEngine: + def __init__(self, mlir_str: str = None): + self._engine = _CompilerEngine() + if mlir_str is not None: + self.compile_fhe(mlir_str) + + def compile_fhe(self, mlir_str: str) -> "CompilerEngine": + """Compile the MLIR input and build a CompilerEngine. + + Args: + mlir_str (str): MLIR to compile. + + Raises: + TypeError: if the argument is not an str. + + Returns: + CompilerEngine: engine used for execution. + """ + if not isinstance(mlir_str, str): + raise TypeError("input must be an `str`") + return self._engine.compile_fhe(mlir_str) + + def run(self, *args: List[int]) -> int: + """Run the compiled code. + + Raises: + TypeError: if arguments aren't of type int + + Returns: + int: result of execution. + """ + if not all(isinstance(arg, int) for arg in args): + raise TypeError("arguments must be of type int") + return self._engine.run(args) + + def get_compiled_module(self) -> str: + """Compiled module in printable form. + + Returns: + str: Compiled module in printable form. + """ + return self._engine.get_compiled_module() diff --git a/compiler/python/zamalang/dialects/_ods_common.py b/compiler/python/zamalang/dialects/_ods_common.py index 58451fd03..2682f6bf5 100644 --- a/compiler/python/zamalang/dialects/_ods_common.py +++ b/compiler/python/zamalang/dialects/_ods_common.py @@ -1 +1,2 @@ +# We need this helpers from the mlir bindings, they are used in the generated files from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context diff --git a/compiler/python/zamalang/dialects/hlfhe.py b/compiler/python/zamalang/dialects/hlfhe.py index c52f2c3ad..3d996e829 100644 --- a/compiler/python/zamalang/dialects/hlfhe.py +++ b/compiler/python/zamalang/dialects/hlfhe.py @@ -1,2 +1,3 @@ +"""HLFHE dialect module""" from ._HLFHE_ops_gen import * from _zamalang._hlfhe import * diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py new file mode 100644 index 000000000..14787edb4 --- /dev/null +++ b/compiler/tests/python/test_compiler_engine.py @@ -0,0 +1,23 @@ +import pytest +from zamalang import CompilerEngine + + +@pytest.mark.parametrize( + "mlir_input, args, expected_result", + [ + ( + """ + func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> { + %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> + } + """, + (5, 7), + 12, + ), + ], +) +def test_compile_and_run(mlir_input, args, expected_result): + engine = CompilerEngine() + engine.compile_fhe(mlir_input) + assert engine.run(*args) == expected_result