refactor(python): rework the bindings with latest MLIR version

- Go through CAPI for python bindings
- Consuming LLVM errors in CAPI: fixes previous issue which made this
  impossible in the python bindings
This commit is contained in:
youben11
2021-09-24 15:23:19 +01:00
committed by Ayoub Benaissa
parent 3406b322d5
commit 2972fa4403
24 changed files with 366 additions and 255 deletions

View File

@@ -7,20 +7,33 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}")
find_package(LLVM REQUIRED CONFIG)
message(STATUS "Using LLVM cmake file from: ${LLVM_DIR}")
# If we are trying to build the compiler with LLVM/MLIR as libraries
if( NOT DEFINED LLVM_EXTERNAL_ZAMALANG_SOURCE_DIR )
message(FATAL_ERROR "Concrete compiler requires a unified build with LLVM/MLIR")
endif()
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen)
include(AddLLVM)
include(AddMLIR)
# CMake library generation settings.
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Default to building a static mondo-lib")
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON CACHE BOOL
"Python soname linked libraries are bad")
set(CMAKE_VISIBILITY_INLINES_HIDDEN ON CACHE BOOL "Hide inlines")
# The -fvisibility=hidden option only works for static builds.
if (BUILD_SHARED_LIBS AND (CMAKE_CXX_VISIBILITY_PRESET STREQUAL "hidden"))
message(FATAL_ERROR "CMAKE_CXX_VISIBILITY_PRESET=hidden is incompatible \
with BUILD_SHARED_LIBS.")
endif()
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir ) # --src-root
set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include ) # --includedir
set(MLIR_TABLEGEN_OUTPUT_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
set(MLIR_TABLEGEN_EXE $<TARGET_FILE:mlir-tblgen>)
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
list(APPEND CMAKE_MODULE_PATH "${MLIR_MAIN_SRC_DIR}/cmake/modules")
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
@@ -54,6 +67,8 @@ if(ZAMALANG_BINDINGS_PYTHON_ENABLED)
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
"extension = '${PYTHON_MODULE_EXTENSION}'")
set(ZAMALANG_PYTHON_PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/python_packages)
else()
message(STATUS "ZamaLang Python bindings are disabled.")
endif()
@@ -69,7 +84,3 @@ add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(src)
add_subdirectory(tests)
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
add_subdirectory(python)
endif()

View File

@@ -6,8 +6,7 @@ build:
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm \
-DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir \
-DZAMALANG_BINDINGS_PYTHON_ENABLED=ON \
-DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \
-DLLVM_EXTERNAL_PROJECTS=zamalang \
-DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=.
@@ -19,10 +18,10 @@ zamacompiler: build
cmake --build build --target zamacompiler
python-bindings: build
cmake --build build --target ZamalangBindingsPython
cmake --build build --target ZamalangMLIRPythonModules ZamalangPythonModules
test-check: zamacompiler
${LLVM_PROJECT}/build/bin/llvm-lit -v tests/
test-check: zamacompiler file-check not
./build/bin/llvm-lit -v tests/
test-end-to-end-jit: build-end-to-end-jit
./build/bin/end_to_end_jit_test
@@ -30,4 +29,13 @@ test-end-to-end-jit: build-end-to-end-jit
test: test-check test-end-to-end-jit
test-python: python-bindings
PYTHONPATH=${PYTHONPATH}:./build/tools/zamalang/python:./build/python LD_PRELOAD=./build/lib/libZamalangRuntime.so pytest -v tests/python
PYTHONPATH=${PYTHONPATH}:./build/tools/zamalang/python_packages/zamalang_core:./build/tools/zamalang/python_packages/zamalang_core/mlir/_mlir_libs/ LD_PRELOAD=./build/lib/libZamalangRuntime.so pytest -vs tests/python
# LLVM/MLIR dependencies
all-deps: file-check not
file-check:
cmake --build build/ --target FileCheck
not:
cmake --build build/ --target not

View File

@@ -0,0 +1,36 @@
#ifndef ZAMALANG_C_SUPPORT_COMPILER_ENGINE_H
#define ZAMALANG_C_SUPPORT_COMPILER_ENGINE_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#ifdef __cplusplus
extern "C" {
#endif
struct compilerEngine {
mlir::zamalang::CompilerEngine *ptr;
};
typedef struct compilerEngine compilerEngine;
struct executionArguments {
mlir::zamalang::ExecutionArgument *data;
size_t size;
};
typedef struct executionArguments exectuionArguments;
// Compile an MLIR module
MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine,
const char *module);
// Run the compiled module
MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e,
executionArguments args);
#ifdef __cplusplus
}
#endif
#endif // ZAMALANG_C_SUPPORT_COMPILER_ENGINE_H

View File

@@ -1,10 +1,10 @@
#ifndef ZAMALANG_PYTHON_COMPILER_API_MODULE_H
#define ZAMALANG_PYTHON_COMPILER_API_MODULE_H
#ifndef ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
#define ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
#include <pybind11/pybind11.h>
#include <vector>
namespace mlir {
namespace zamalang {
namespace python {
// Frontend object to abstract the different types of possible arguments,
// namely, integers, and tensors.
@@ -31,8 +31,7 @@ public:
}
private:
ExecutionArgument(int arg)
: isTensorArg(false), intArg(arg) {}
ExecutionArgument(int arg) : isTensorArg(false), intArg(arg) {}
ExecutionArgument(std::vector<uint8_t> tensor)
: isTensorArg(true), tensorArg(tensor) {}
@@ -42,9 +41,7 @@ private:
bool isTensorArg;
};
void populateCompilerAPISubmodule(pybind11::module &m);
} // namespace python
} // namespace zamalang
} // namespace mlir
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H
#endif

View File

@@ -0,0 +1,3 @@
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
add_subdirectory(Python)
endif()

View File

@@ -0,0 +1,94 @@
include(AddMLIRPython)
################################################################################
# Decalare native Python extension
################################################################################
declare_mlir_python_sources(ZamalangBindingsPythonExtension)
declare_mlir_python_extension(ZamalangBindingsPythonExtension.Core
MODULE_NAME _zamalang
ADD_TO_PARENT ZamalangBindingsPythonExtension
SOURCES
ZamalangModule.cpp
HLFHEModule.cpp
CompilerAPIModule.cpp
EMBED_CAPI_LINK_LIBS
ZAMALANGCAPIHLFHE
ZAMALANGCAPISupport
)
################################################################################
# Declare python sources
################################################################################
declare_mlir_python_sources(ZamalangBindingsPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
SOURCES
zamalang/__init__.py
zamalang/compiler.py
zamalang/dialects/_ods_common.py)
################################################################################
# Declare dialect-specific bindings.
################################################################################
declare_mlir_python_sources(ZamalangBindingsPythonSources.Dialects
ADD_TO_PARENT ZamalangBindingsPythonSources)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT ZamalangBindingsPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
ZAMALANGBindingsPythonHLFHEOps
TD_FILE zamalang/dialects/HLFHEOps.td
SOURCES
zamalang/dialects/hlfhe.py
DIALECT_NAME HLFHE)
################################################################################
# Build composite binaries
################################################################################
# Bundle our own, self-contained CAPI library with all of our deps.
add_mlir_python_common_capi_library(ZamalangBindingsPythonCAPI
INSTALL_COMPONENT ZamalangBindingsPythonModules
INSTALL_DESTINATION python_packages/zamalang_core/mlir/_mlir_libs
# NOTE: When the MLIR API is relocated under zamalang, this would change to
# .../zamalang/_mlir_libs
OUTPUT_DIRECTORY "${ZAMALANG_PYTHON_PACKAGES_DIR}/zamalang_core/mlir/_mlir_libs"
RELATIVE_INSTALL_ROOT "../../../.."
DECLARED_SOURCES
# TODO: This can be chopped down significantly for size.
MLIRPythonSources
MLIRPythonExtension.AllPassesRegistration
ZamalangBindingsPythonSources
ZamalangBindingsPythonExtension
)
# Bundle the MLIR python sources into our package.
# The MLIR API is position independent, so we explicitly output it to the mlir/
# folder as a temporary measure. It will eventually migrate under the zamalang/
# folder and be accessible under the unified "import zamalang..." namespace.
add_mlir_python_modules(ZamalangMLIRPythonModules
ROOT_PREFIX "${ZAMALANG_PYTHON_PACKAGES_DIR}/zamalang_core/mlir"
INSTALL_PREFIX "python_packages/zamalang_core/mlir"
DECLARED_SOURCES
MLIRPythonSources
MLIRPythonExtension.AllPassesRegistration
# We need the circt extensions co-located with the MLIR extensions. When
# the namespace is unified, this moves to the below.
ZamalangBindingsPythonExtension
COMMON_CAPI_LINK_LIBS
ZamalangBindingsPythonCAPI
)
# Bundle the ZAMALANG python sources into our package.
add_mlir_python_modules(ZamalangPythonModules
ROOT_PREFIX "${ZAMALANG_PYTHON_PACKAGES_DIR}/zamalang_core"
INSTALL_PREFIX "python_packages/zamalang_core"
DECLARED_SOURCES
ZamalangBindingsPythonSources
COMMON_CAPI_LINK_LIBS
ZamalangBindingsPythonCAPI
)

View File

@@ -0,0 +1,64 @@
#include "CompilerAPIModule.h"
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <stdexcept>
#include <string>
using mlir::zamalang::CompilerEngine;
using mlir::zamalang::ExecutionArgument;
/// Populate the compiler API python module.
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
m.doc() = "Zamalang compiler python API";
m.def("round_trip", [](std::string mlir_input) {
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
auto module_ref = mlir::parseSourceString(mlir_input, &context);
if (!module_ref) {
throw std::logic_error("mlir parsing failed");
}
std::string result;
llvm::raw_string_ostream os(result);
module_ref->print(os);
return os.str();
});
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
m, "ExecutionArgument")
.def("create",
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
&ExecutionArgument::create))
.def("is_tensor", &ExecutionArgument::isTensor)
.def("is_int", &ExecutionArgument::isInt);
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
.def(pybind11::init())
.def("run",
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
// wrap and call CAPI
compilerEngine e{&engine};
exectuionArguments a{args.data(), args.size()};
return compilerEngineRun(e, a);
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
// wrap and call CAPI
compilerEngine e{&engine};
compilerEngineCompile(e, mlir_input.c_str());
})
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
}

View File

@@ -0,0 +1,16 @@
#ifndef ZAMALANG_PYTHON_COMPILER_API_MODULE_H
#define ZAMALANG_PYTHON_COMPILER_API_MODULE_H
#include <pybind11/pybind11.h>
namespace mlir {
namespace zamalang {
namespace python {
void populateCompilerAPISubmodule(pybind11::module &m);
} // namespace python
} // namespace zamalang
} // namespace mlir
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H

View File

@@ -3,6 +3,7 @@
#include <pybind11/pybind11.h>
namespace mlir {
namespace zamalang {
namespace python {
@@ -10,5 +11,6 @@ void populateDialectHLFHESubmodule(pybind11::module &m);
} // namespace python
} // namespace zamalang
} // namespace mlir
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H

View File

@@ -11,11 +11,11 @@
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
using namespace zamalang;
using namespace mlir::zamalang;
using namespace mlir::python::adaptors;
/// Populate the hlfhe python module.
void zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) {
void mlir::zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) {
m.doc() = "HLFHE dialect Python native extension";
mlir_type_subclass(m, "EncryptedIntegerType",

View File

@@ -32,8 +32,8 @@ PYBIND11_MODULE(_zamalang, m) {
"Register Zamalang dialects on a PyMlirContext.");
py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API");
zamalang::python::populateDialectHLFHESubmodule(hlfhe);
mlir::zamalang::python::populateDialectHLFHESubmodule(hlfhe);
py::module api = m.def_submodule("_compiler", "Compiler API");
zamalang::python::populateCompilerAPISubmodule(api);
mlir::zamalang::python::populateCompilerAPISubmodule(api);
}

View File

@@ -1 +1,6 @@
add_subdirectory(Dialect)
# CAPI is mainly used by python and need to throw exceptions
# for proper handling of errors on the python-side
add_compile_options(-fexceptions)
add_subdirectory(Dialect)
add_subdirectory(Support)

View File

@@ -0,0 +1,13 @@
set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp)
add_mlir_library(ZAMALANGCAPISupport
CompilerEngine.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
LINK_LIBS PUBLIC
MLIRCAPIIR
ZamalangSupport
)

View File

@@ -0,0 +1,62 @@
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
using mlir::zamalang::CompilerEngine;
using mlir::zamalang::ExecutionArgument;
void compilerEngineCompile(compilerEngine engine, const char *module) {
auto error = engine.ptr->compile(module);
if (error) {
llvm::errs() << "Compilation failed: " << error << "\n";
llvm::consumeError(std::move(error));
throw std::runtime_error(
"failed compiling, see previous logs for more info");
}
}
uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) {
auto args_size = args.size;
auto maybeArgument = engine.ptr->buildArgument();
if (auto err = maybeArgument.takeError()) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error(
"failed building arguments, see previous logs for more info");
}
// Set the integer/tensor arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args_size; i++) {
if (args.data[i].isInt()) { // integer argument
if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed pushing integer argument, see "
"previous logs for more info");
}
} else { // tensor argument
assert(args.data[i].isTensor() && "should be tensor argument");
if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(),
args.data[i].getTensorSize())) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed pushing tensor argument, see "
"previous logs for more info");
}
}
}
// Invoke the lambda
if (auto err = engine.ptr->invoke(*arguments)) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed running, see previous logs for more info");
}
uint64_t result = 0;
if (auto err = arguments->getResult(0, result)) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error(
"failed getting result, see previous logs for more info");
}
return result;
}

View File

@@ -2,6 +2,7 @@ add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Support)
add_subdirectory(Runtime)
add_subdirectory(Bindings)
# CAPI needed only for python bindings
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)

View File

@@ -1,75 +0,0 @@
include(AddMLIRPython)
add_custom_target(ZamalangBindingsPython)
################################################################################
# Build native Python extension
################################################################################
add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
INSTALL_DIR
python
SOURCES
ZamalangModule.cpp
HLFHEModule.cpp
CompilerAPIModule.cpp
LINK_LIBS
ZAMALANGCAPIHLFHE
ZamalangSupport
LowLFHEDialect
MidLFHEDialect
HLFHEDialect
ZamalangRuntime
)
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)
################################################################################
# Copy python source tree.
################################################################################
file(GLOB_RECURSE PY_SRC_FILES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/zamalang/*.py")
add_custom_target(ZAMALANGBindingsPythonSources ALL
DEPENDS
${PY_SRC_FILES}
)
add_dependencies(ZamalangBindingsPython ZAMALANGBindingsPythonSources)
foreach(PY_SRC_FILE ${PY_SRC_FILES})
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${PY_DEST_DIR}")
add_custom_command(
TARGET ZAMALANGBindingsPythonSources PRE_BUILD
COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}"
DEPENDS "${PY_SRC_FILE}"
BYPRODUCTS "${PY_DEST_FILE}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink
"${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}"
)
endforeach()
# Note that we copy from the source tree just like for headers because
# it will not be polluted with py_cache runtime artifacts (from testing and
# such).
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/zamalang
DESTINATION python
COMPONENT ZAMALANGBindingsPythonSources
FILES_MATCHING PATTERN "*.py"
)
if (NOT LLVM_ENABLE_IDE)
add_llvm_install_targets(
install-ZAMALANGBindingsPythonSources
DEPENDS ZAMALANGBindingsPythonSources
COMPONENT ZAMALANGBindingsPythonSources)
endif()
################################################################################
# Generated sources.
################################################################################
add_subdirectory(zamalang/dialects)

View File

@@ -1,110 +0,0 @@
#include "CompilerAPIModule.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerEngine.h"
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <stdexcept>
#include <string>
using namespace zamalang;
using mlir::zamalang::CompilerEngine;
using zamalang::python::ExecutionArgument;
/// Populate the compiler API python module.
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
m.doc() = "Zamalang compiler python API";
m.def("round_trip", [](std::string mlir_input) {
mlir::MLIRContext context;
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
auto module_ref = mlir::parseSourceString(mlir_input, &context);
if (!module_ref) {
throw std::logic_error("mlir parsing failed");
}
std::string result;
llvm::raw_string_ostream os(result);
module_ref->print(os);
return os.str();
});
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
m, "ExecutionArgument")
.def("create",
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
&ExecutionArgument::create))
.def("is_tensor", &ExecutionArgument::isTensor)
.def("is_int", &ExecutionArgument::isInt);
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
.def(pybind11::init())
.def(
"run",
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
auto maybeArgument = engine.buildArgument();
if (auto err = maybeArgument.takeError()) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed building arguments, see previous logs for more info");
}
// Set the integer/tensor arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args.size(); i++) {
if (args[i].isInt()) { // integer argument
if (auto err =
arguments->setArg(i, args[i].getIntegerArgument())) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed pushing integer argument, see "
"previous logs for more info");
}
} else { // tensor argument
assert(args[i].isTensor() && "should be tensor argument");
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
args[i].getTensorSize())) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed pushing tensor argument, see "
"previous logs for more info");
}
}
}
// Invoke the lambda
if (auto err = engine.invoke(*arguments)) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed running, see previous logs for more info");
}
uint64_t result = 0;
if (auto err = arguments->getResult(0, result)) {
llvm::errs() << "Execution failed: " << err << "\n";
throw std::runtime_error(
"failed getting result, see previous logs for more info");
}
return result;
})
.def("compile_fhe",
[](CompilerEngine &engine, std::string mlir_input) {
auto error = engine.compile(mlir_input);
if (error) {
llvm::errs() << "Compilation failed: " << error << "\n";
throw std::runtime_error(
"failed compiling, see previous logs for more info");
}
})
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
}

View File

@@ -1,31 +0,0 @@
include(AddMLIRPython)
################################################################################
# Generate dialect-specific bindings.
################################################################################
add_mlir_dialect_python_bindings(ZAMALANGBindingsPythonHLFHEOps
TD_FILE HLFHEOps.td
DIALECT_NAME HLFHE)
add_dependencies(ZAMALANGBindingsPythonSources ZAMALANGBindingsPythonHLFHEOps)
################################################################################
# Installation.
################################################################################
install(
DIRECTORY ${PROJECT_BINARY_DIR}/python/zamalang/dialects
DESTINATION python/zamalang
COMPONENT ZAMALANGBindingsPythonDialects
FILES_MATCHING PATTERN "_*_gen.py"
PATTERN "__pycache__" EXCLUDE
PATTERN "__init__.py" EXCLUDE
)
if (NOT LLVM_ENABLE_IDE)
add_llvm_install_targets(
install-ZAMALANGBindingsPythonDialects
DEPENDS ZAMALANGBindingsPythonSources
COMPONENT ZAMALANGBindingsPythonDialects)
endif()

View File

@@ -11,14 +11,29 @@ def main():
zamalang.register_dialects(ctx)
module = Module.create()
eint16 = hlfhe.EncryptedIntegerType.get(ctx, 16)
eint6 = hlfhe.EncryptedIntegerType.get(ctx, 6)
with InsertionPoint(module.body):
func_types = [RankedTensorType.get((10, 10), eint16) for _ in range(2)]
func_types = [MemRefType.get((10, 10), eint6) for _ in range(2)]
@builtin.FuncOp.from_py_func(*func_types)
def fhe_circuit(*arg):
def main(*arg):
return arg[0]
print(module)
m = """
func @main(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
%0 = constant 1 : i3
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>
}"""
## Working when HFLFHE and MLIR aren't linked
zamalang.compiler.round_trip("module{}")
zamalang.compiler.round_trip(str(module))
## END OF WORKING
## Doesn't work yet for both modules
engine = zamalang.CompilerEngine()
engine.compile_fhe(m)
# engine.compile_fhe(str(module))
print(engine.run(2))
if __name__ == "__main__":