mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(python): rework the bindings with latest MLIR version
- Go through CAPI for python bindings - Consuming LLVM errors in CAPI: fixes previous issue which made this impossible in the python bindings
This commit is contained in:
@@ -7,20 +7,33 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_CXX_STANDARD 14)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
find_package(MLIR REQUIRED CONFIG)
|
||||
message(STATUS "Using MLIR cmake file from: ${MLIR_DIR}")
|
||||
|
||||
find_package(LLVM REQUIRED CONFIG)
|
||||
message(STATUS "Using LLVM cmake file from: ${LLVM_DIR}")
|
||||
# If we are trying to build the compiler with LLVM/MLIR as libraries
|
||||
if( NOT DEFINED LLVM_EXTERNAL_ZAMALANG_SOURCE_DIR )
|
||||
message(FATAL_ERROR "Concrete compiler requires a unified build with LLVM/MLIR")
|
||||
endif()
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
|
||||
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
|
||||
include(TableGen)
|
||||
include(AddLLVM)
|
||||
include(AddMLIR)
|
||||
# CMake library generation settings.
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Default to building a static mondo-lib")
|
||||
set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON CACHE BOOL
|
||||
"Python soname linked libraries are bad")
|
||||
set(CMAKE_VISIBILITY_INLINES_HIDDEN ON CACHE BOOL "Hide inlines")
|
||||
|
||||
# The -fvisibility=hidden option only works for static builds.
|
||||
if (BUILD_SHARED_LIBS AND (CMAKE_CXX_VISIBILITY_PRESET STREQUAL "hidden"))
|
||||
message(FATAL_ERROR "CMAKE_CXX_VISIBILITY_PRESET=hidden is incompatible \
|
||||
with BUILD_SHARED_LIBS.")
|
||||
endif()
|
||||
|
||||
set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir ) # --src-root
|
||||
set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include ) # --includedir
|
||||
set(MLIR_TABLEGEN_OUTPUT_DIR ${LLVM_BINARY_DIR}/tools/mlir/include)
|
||||
set(MLIR_TABLEGEN_EXE $<TARGET_FILE:mlir-tblgen>)
|
||||
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
|
||||
include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${MLIR_MAIN_SRC_DIR}/cmake/modules")
|
||||
|
||||
include_directories(${LLVM_INCLUDE_DIRS})
|
||||
include_directories(${MLIR_INCLUDE_DIRS})
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
include_directories(${PROJECT_BINARY_DIR}/include)
|
||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||
@@ -54,6 +67,8 @@ if(ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
|
||||
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
||||
"extension = '${PYTHON_MODULE_EXTENSION}'")
|
||||
|
||||
set(ZAMALANG_PYTHON_PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/python_packages)
|
||||
else()
|
||||
message(STATUS "ZamaLang Python bindings are disabled.")
|
||||
endif()
|
||||
@@ -69,7 +84,3 @@ add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(tests)
|
||||
|
||||
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
add_subdirectory(python)
|
||||
endif()
|
||||
|
||||
@@ -6,8 +6,7 @@ build:
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||
-DLLVM_DIR=${LLVM_PROJECT}/build/lib/cmake/llvm \
|
||||
-DMLIR_DIR=${LLVM_PROJECT}/build/lib/cmake/mlir \
|
||||
-DZAMALANG_BINDINGS_PYTHON_ENABLED=ON \
|
||||
-DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \
|
||||
-DLLVM_EXTERNAL_PROJECTS=zamalang \
|
||||
-DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=.
|
||||
@@ -19,10 +18,10 @@ zamacompiler: build
|
||||
cmake --build build --target zamacompiler
|
||||
|
||||
python-bindings: build
|
||||
cmake --build build --target ZamalangBindingsPython
|
||||
cmake --build build --target ZamalangMLIRPythonModules ZamalangPythonModules
|
||||
|
||||
test-check: zamacompiler
|
||||
${LLVM_PROJECT}/build/bin/llvm-lit -v tests/
|
||||
test-check: zamacompiler file-check not
|
||||
./build/bin/llvm-lit -v tests/
|
||||
|
||||
test-end-to-end-jit: build-end-to-end-jit
|
||||
./build/bin/end_to_end_jit_test
|
||||
@@ -30,4 +29,13 @@ test-end-to-end-jit: build-end-to-end-jit
|
||||
test: test-check test-end-to-end-jit
|
||||
|
||||
test-python: python-bindings
|
||||
PYTHONPATH=${PYTHONPATH}:./build/tools/zamalang/python:./build/python LD_PRELOAD=./build/lib/libZamalangRuntime.so pytest -v tests/python
|
||||
PYTHONPATH=${PYTHONPATH}:./build/tools/zamalang/python_packages/zamalang_core:./build/tools/zamalang/python_packages/zamalang_core/mlir/_mlir_libs/ LD_PRELOAD=./build/lib/libZamalangRuntime.so pytest -vs tests/python
|
||||
|
||||
# LLVM/MLIR dependencies
|
||||
|
||||
all-deps: file-check not
|
||||
|
||||
file-check:
|
||||
cmake --build build/ --target FileCheck
|
||||
not:
|
||||
cmake --build build/ --target not
|
||||
|
||||
36
compiler/include/zamalang-c/Support/CompilerEngine.h
Normal file
36
compiler/include/zamalang-c/Support/CompilerEngine.h
Normal file
@@ -0,0 +1,36 @@
|
||||
#ifndef ZAMALANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct compilerEngine {
|
||||
mlir::zamalang::CompilerEngine *ptr;
|
||||
};
|
||||
typedef struct compilerEngine compilerEngine;
|
||||
|
||||
struct executionArguments {
|
||||
mlir::zamalang::ExecutionArgument *data;
|
||||
size_t size;
|
||||
};
|
||||
typedef struct executionArguments exectuionArguments;
|
||||
|
||||
// Compile an MLIR module
|
||||
MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine,
|
||||
const char *module);
|
||||
|
||||
// Run the compiled module
|
||||
MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e,
|
||||
executionArguments args);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // ZAMALANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
@@ -1,10 +1,10 @@
|
||||
#ifndef ZAMALANG_PYTHON_COMPILER_API_MODULE_H
|
||||
#define ZAMALANG_PYTHON_COMPILER_API_MODULE_H
|
||||
#ifndef ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
|
||||
#define ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace python {
|
||||
|
||||
// Frontend object to abstract the different types of possible arguments,
|
||||
// namely, integers, and tensors.
|
||||
@@ -31,8 +31,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
ExecutionArgument(int arg)
|
||||
: isTensorArg(false), intArg(arg) {}
|
||||
ExecutionArgument(int arg) : isTensorArg(false), intArg(arg) {}
|
||||
|
||||
ExecutionArgument(std::vector<uint8_t> tensor)
|
||||
: isTensorArg(true), tensorArg(tensor) {}
|
||||
@@ -42,9 +41,7 @@ private:
|
||||
bool isTensorArg;
|
||||
};
|
||||
|
||||
void populateCompilerAPISubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H
|
||||
#endif
|
||||
3
compiler/lib/Bindings/CMakeLists.txt
Normal file
3
compiler/lib/Bindings/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
add_subdirectory(Python)
|
||||
endif()
|
||||
94
compiler/lib/Bindings/Python/CMakeLists.txt
Normal file
94
compiler/lib/Bindings/Python/CMakeLists.txt
Normal file
@@ -0,0 +1,94 @@
|
||||
include(AddMLIRPython)
|
||||
|
||||
################################################################################
|
||||
# Decalare native Python extension
|
||||
################################################################################
|
||||
|
||||
declare_mlir_python_sources(ZamalangBindingsPythonExtension)
|
||||
|
||||
declare_mlir_python_extension(ZamalangBindingsPythonExtension.Core
|
||||
MODULE_NAME _zamalang
|
||||
ADD_TO_PARENT ZamalangBindingsPythonExtension
|
||||
SOURCES
|
||||
ZamalangModule.cpp
|
||||
HLFHEModule.cpp
|
||||
CompilerAPIModule.cpp
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
ZAMALANGCAPIHLFHE
|
||||
ZAMALANGCAPISupport
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Declare python sources
|
||||
################################################################################
|
||||
|
||||
declare_mlir_python_sources(ZamalangBindingsPythonSources
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
SOURCES
|
||||
zamalang/__init__.py
|
||||
zamalang/compiler.py
|
||||
zamalang/dialects/_ods_common.py)
|
||||
|
||||
################################################################################
|
||||
# Declare dialect-specific bindings.
|
||||
################################################################################
|
||||
|
||||
declare_mlir_python_sources(ZamalangBindingsPythonSources.Dialects
|
||||
ADD_TO_PARENT ZamalangBindingsPythonSources)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT ZamalangBindingsPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
ZAMALANGBindingsPythonHLFHEOps
|
||||
TD_FILE zamalang/dialects/HLFHEOps.td
|
||||
SOURCES
|
||||
zamalang/dialects/hlfhe.py
|
||||
DIALECT_NAME HLFHE)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Build composite binaries
|
||||
################################################################################
|
||||
|
||||
# Bundle our own, self-contained CAPI library with all of our deps.
|
||||
add_mlir_python_common_capi_library(ZamalangBindingsPythonCAPI
|
||||
INSTALL_COMPONENT ZamalangBindingsPythonModules
|
||||
INSTALL_DESTINATION python_packages/zamalang_core/mlir/_mlir_libs
|
||||
# NOTE: When the MLIR API is relocated under zamalang, this would change to
|
||||
# .../zamalang/_mlir_libs
|
||||
OUTPUT_DIRECTORY "${ZAMALANG_PYTHON_PACKAGES_DIR}/zamalang_core/mlir/_mlir_libs"
|
||||
RELATIVE_INSTALL_ROOT "../../../.."
|
||||
DECLARED_SOURCES
|
||||
# TODO: This can be chopped down significantly for size.
|
||||
MLIRPythonSources
|
||||
MLIRPythonExtension.AllPassesRegistration
|
||||
ZamalangBindingsPythonSources
|
||||
ZamalangBindingsPythonExtension
|
||||
)
|
||||
|
||||
# Bundle the MLIR python sources into our package.
|
||||
# The MLIR API is position independent, so we explicitly output it to the mlir/
|
||||
# folder as a temporary measure. It will eventually migrate under the zamalang/
|
||||
# folder and be accessible under the unified "import zamalang..." namespace.
|
||||
add_mlir_python_modules(ZamalangMLIRPythonModules
|
||||
ROOT_PREFIX "${ZAMALANG_PYTHON_PACKAGES_DIR}/zamalang_core/mlir"
|
||||
INSTALL_PREFIX "python_packages/zamalang_core/mlir"
|
||||
DECLARED_SOURCES
|
||||
MLIRPythonSources
|
||||
MLIRPythonExtension.AllPassesRegistration
|
||||
# We need the circt extensions co-located with the MLIR extensions. When
|
||||
# the namespace is unified, this moves to the below.
|
||||
ZamalangBindingsPythonExtension
|
||||
COMMON_CAPI_LINK_LIBS
|
||||
ZamalangBindingsPythonCAPI
|
||||
)
|
||||
|
||||
# Bundle the ZAMALANG python sources into our package.
|
||||
add_mlir_python_modules(ZamalangPythonModules
|
||||
ROOT_PREFIX "${ZAMALANG_PYTHON_PACKAGES_DIR}/zamalang_core"
|
||||
INSTALL_PREFIX "python_packages/zamalang_core"
|
||||
DECLARED_SOURCES
|
||||
ZamalangBindingsPythonSources
|
||||
COMMON_CAPI_LINK_LIBS
|
||||
ZamalangBindingsPythonCAPI
|
||||
)
|
||||
64
compiler/lib/Bindings/Python/CompilerAPIModule.cpp
Normal file
64
compiler/lib/Bindings/Python/CompilerAPIModule.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "zamalang-c/Support/CompilerEngine.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
using mlir::zamalang::ExecutionArgument;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
m.doc() = "Zamalang compiler python API";
|
||||
|
||||
m.def("round_trip", [](std::string mlir_input) {
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
auto module_ref = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!module_ref) {
|
||||
throw std::logic_error("mlir parsing failed");
|
||||
}
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
});
|
||||
|
||||
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
|
||||
m, "ExecutionArgument")
|
||||
.def("create",
|
||||
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
|
||||
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
|
||||
&ExecutionArgument::create))
|
||||
.def("is_tensor", &ExecutionArgument::isTensor)
|
||||
.def("is_int", &ExecutionArgument::isInt);
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def("run",
|
||||
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
|
||||
// wrap and call CAPI
|
||||
compilerEngine e{&engine};
|
||||
exectuionArguments a{args.data(), args.size()};
|
||||
return compilerEngineRun(e, a);
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
// wrap and call CAPI
|
||||
compilerEngine e{&engine};
|
||||
compilerEngineCompile(e, mlir_input.c_str());
|
||||
})
|
||||
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
|
||||
}
|
||||
16
compiler/lib/Bindings/Python/CompilerAPIModule.h
Normal file
16
compiler/lib/Bindings/Python/CompilerAPIModule.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef ZAMALANG_PYTHON_COMPILER_API_MODULE_H
|
||||
#define ZAMALANG_PYTHON_COMPILER_API_MODULE_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace python {
|
||||
|
||||
void populateCompilerAPISubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace python {
|
||||
|
||||
@@ -10,5 +11,6 @@ void populateDialectHLFHESubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H
|
||||
@@ -11,11 +11,11 @@
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
using namespace zamalang;
|
||||
using namespace mlir::zamalang;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
/// Populate the hlfhe python module.
|
||||
void zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) {
|
||||
void mlir::zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) {
|
||||
m.doc() = "HLFHE dialect Python native extension";
|
||||
|
||||
mlir_type_subclass(m, "EncryptedIntegerType",
|
||||
@@ -32,8 +32,8 @@ PYBIND11_MODULE(_zamalang, m) {
|
||||
"Register Zamalang dialects on a PyMlirContext.");
|
||||
|
||||
py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API");
|
||||
zamalang::python::populateDialectHLFHESubmodule(hlfhe);
|
||||
mlir::zamalang::python::populateDialectHLFHESubmodule(hlfhe);
|
||||
|
||||
py::module api = m.def_submodule("_compiler", "Compiler API");
|
||||
zamalang::python::populateCompilerAPISubmodule(api);
|
||||
mlir::zamalang::python::populateCompilerAPISubmodule(api);
|
||||
}
|
||||
@@ -1 +1,6 @@
|
||||
add_subdirectory(Dialect)
|
||||
# CAPI is mainly used by python and need to throw exceptions
|
||||
# for proper handling of errors on the python-side
|
||||
add_compile_options(-fexceptions)
|
||||
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Support)
|
||||
|
||||
13
compiler/lib/CAPI/Support/CMakeLists.txt
Normal file
13
compiler/lib/CAPI/Support/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp)
|
||||
|
||||
add_mlir_library(ZAMALANGCAPISupport
|
||||
|
||||
CompilerEngine.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
ZamalangSupport
|
||||
)
|
||||
62
compiler/lib/CAPI/Support/CompilerEngine.cpp
Normal file
62
compiler/lib/CAPI/Support/CompilerEngine.cpp
Normal file
@@ -0,0 +1,62 @@
|
||||
#include "zamalang-c/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
using mlir::zamalang::ExecutionArgument;
|
||||
|
||||
void compilerEngineCompile(compilerEngine engine, const char *module) {
|
||||
auto error = engine.ptr->compile(module);
|
||||
if (error) {
|
||||
llvm::errs() << "Compilation failed: " << error << "\n";
|
||||
llvm::consumeError(std::move(error));
|
||||
throw std::runtime_error(
|
||||
"failed compiling, see previous logs for more info");
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) {
|
||||
auto args_size = args.size;
|
||||
auto maybeArgument = engine.ptr->buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error(
|
||||
"failed building arguments, see previous logs for more info");
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args_size; i++) {
|
||||
if (args.data[i].isInt()) { // integer argument
|
||||
if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error("failed pushing integer argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
} else { // tensor argument
|
||||
assert(args.data[i].isTensor() && "should be tensor argument");
|
||||
if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(),
|
||||
args.data[i].getTensorSize())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error("failed pushing tensor argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = engine.ptr->invoke(*arguments)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error("failed running, see previous logs for more info");
|
||||
}
|
||||
uint64_t result = 0;
|
||||
if (auto err = arguments->getResult(0, result)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error(
|
||||
"failed getting result, see previous logs for more info");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -2,6 +2,7 @@ add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Support)
|
||||
add_subdirectory(Runtime)
|
||||
add_subdirectory(Bindings)
|
||||
|
||||
# CAPI needed only for python bindings
|
||||
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
include(AddMLIRPython)
|
||||
add_custom_target(ZamalangBindingsPython)
|
||||
|
||||
################################################################################
|
||||
# Build native Python extension
|
||||
################################################################################
|
||||
|
||||
add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
|
||||
INSTALL_DIR
|
||||
python
|
||||
SOURCES
|
||||
ZamalangModule.cpp
|
||||
HLFHEModule.cpp
|
||||
CompilerAPIModule.cpp
|
||||
LINK_LIBS
|
||||
ZAMALANGCAPIHLFHE
|
||||
ZamalangSupport
|
||||
LowLFHEDialect
|
||||
MidLFHEDialect
|
||||
HLFHEDialect
|
||||
ZamalangRuntime
|
||||
)
|
||||
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)
|
||||
|
||||
################################################################################
|
||||
# Copy python source tree.
|
||||
################################################################################
|
||||
|
||||
file(GLOB_RECURSE PY_SRC_FILES
|
||||
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/zamalang/*.py")
|
||||
|
||||
add_custom_target(ZAMALANGBindingsPythonSources ALL
|
||||
DEPENDS
|
||||
${PY_SRC_FILES}
|
||||
)
|
||||
|
||||
add_dependencies(ZamalangBindingsPython ZAMALANGBindingsPythonSources)
|
||||
|
||||
foreach(PY_SRC_FILE ${PY_SRC_FILES})
|
||||
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
|
||||
get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY)
|
||||
file(MAKE_DIRECTORY "${PY_DEST_DIR}")
|
||||
add_custom_command(
|
||||
TARGET ZAMALANGBindingsPythonSources PRE_BUILD
|
||||
COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}"
|
||||
DEPENDS "${PY_SRC_FILE}"
|
||||
BYPRODUCTS "${PY_DEST_FILE}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E create_symlink
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}"
|
||||
)
|
||||
endforeach()
|
||||
|
||||
# Note that we copy from the source tree just like for headers because
|
||||
# it will not be polluted with py_cache runtime artifacts (from testing and
|
||||
# such).
|
||||
install(
|
||||
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/zamalang
|
||||
DESTINATION python
|
||||
COMPONENT ZAMALANGBindingsPythonSources
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
if (NOT LLVM_ENABLE_IDE)
|
||||
add_llvm_install_targets(
|
||||
install-ZAMALANGBindingsPythonSources
|
||||
DEPENDS ZAMALANGBindingsPythonSources
|
||||
COMPONENT ZAMALANGBindingsPythonSources)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Generated sources.
|
||||
################################################################################
|
||||
|
||||
add_subdirectory(zamalang/dialects)
|
||||
@@ -1,110 +0,0 @@
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Parser.h>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using namespace zamalang;
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
using zamalang::python::ExecutionArgument;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
m.doc() = "Zamalang compiler python API";
|
||||
|
||||
m.def("round_trip", [](std::string mlir_input) {
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
auto module_ref = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!module_ref) {
|
||||
throw std::logic_error("mlir parsing failed");
|
||||
}
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
});
|
||||
|
||||
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
|
||||
m, "ExecutionArgument")
|
||||
.def("create",
|
||||
pybind11::overload_cast<uint64_t>(&ExecutionArgument::create))
|
||||
.def("create", pybind11::overload_cast<std::vector<uint8_t>>(
|
||||
&ExecutionArgument::create))
|
||||
.def("is_tensor", &ExecutionArgument::isTensor)
|
||||
.def("is_int", &ExecutionArgument::isInt);
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def(
|
||||
"run",
|
||||
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed building arguments, see previous logs for more info");
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args.size(); i++) {
|
||||
if (args[i].isInt()) { // integer argument
|
||||
if (auto err =
|
||||
arguments->setArg(i, args[i].getIntegerArgument())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed pushing integer argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
} else { // tensor argument
|
||||
assert(args[i].isTensor() && "should be tensor argument");
|
||||
if (auto err = arguments->setArg(i, args[i].getTensorArgument(),
|
||||
args[i].getTensorSize())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed pushing tensor argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = engine.invoke(*arguments)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed running, see previous logs for more info");
|
||||
}
|
||||
uint64_t result = 0;
|
||||
if (auto err = arguments->getResult(0, result)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed getting result, see previous logs for more info");
|
||||
}
|
||||
return result;
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
auto error = engine.compile(mlir_input);
|
||||
if (error) {
|
||||
llvm::errs() << "Compilation failed: " << error << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed compiling, see previous logs for more info");
|
||||
}
|
||||
})
|
||||
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
include(AddMLIRPython)
|
||||
|
||||
################################################################################
|
||||
# Generate dialect-specific bindings.
|
||||
################################################################################
|
||||
|
||||
add_mlir_dialect_python_bindings(ZAMALANGBindingsPythonHLFHEOps
|
||||
TD_FILE HLFHEOps.td
|
||||
DIALECT_NAME HLFHE)
|
||||
add_dependencies(ZAMALANGBindingsPythonSources ZAMALANGBindingsPythonHLFHEOps)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Installation.
|
||||
################################################################################
|
||||
|
||||
install(
|
||||
DIRECTORY ${PROJECT_BINARY_DIR}/python/zamalang/dialects
|
||||
DESTINATION python/zamalang
|
||||
COMPONENT ZAMALANGBindingsPythonDialects
|
||||
FILES_MATCHING PATTERN "_*_gen.py"
|
||||
PATTERN "__pycache__" EXCLUDE
|
||||
PATTERN "__init__.py" EXCLUDE
|
||||
)
|
||||
|
||||
if (NOT LLVM_ENABLE_IDE)
|
||||
add_llvm_install_targets(
|
||||
install-ZAMALANGBindingsPythonDialects
|
||||
DEPENDS ZAMALANGBindingsPythonSources
|
||||
COMPONENT ZAMALANGBindingsPythonDialects)
|
||||
endif()
|
||||
@@ -11,14 +11,29 @@ def main():
|
||||
zamalang.register_dialects(ctx)
|
||||
|
||||
module = Module.create()
|
||||
eint16 = hlfhe.EncryptedIntegerType.get(ctx, 16)
|
||||
eint6 = hlfhe.EncryptedIntegerType.get(ctx, 6)
|
||||
with InsertionPoint(module.body):
|
||||
func_types = [RankedTensorType.get((10, 10), eint16) for _ in range(2)]
|
||||
func_types = [MemRefType.get((10, 10), eint6) for _ in range(2)]
|
||||
@builtin.FuncOp.from_py_func(*func_types)
|
||||
def fhe_circuit(*arg):
|
||||
def main(*arg):
|
||||
return arg[0]
|
||||
|
||||
print(module)
|
||||
m = """
|
||||
func @main(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
%0 = constant 1 : i3
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i3) -> (!HLFHE.eint<2>)
|
||||
return %1: !HLFHE.eint<2>
|
||||
}"""
|
||||
## Working when HFLFHE and MLIR aren't linked
|
||||
zamalang.compiler.round_trip("module{}")
|
||||
zamalang.compiler.round_trip(str(module))
|
||||
## END OF WORKING
|
||||
## Doesn't work yet for both modules
|
||||
engine = zamalang.CompilerEngine()
|
||||
engine.compile_fhe(m)
|
||||
# engine.compile_fhe(str(module))
|
||||
print(engine.run(2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user