mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): Python bindings (#53)
- feat(compiler): python bindings - build: update docker image for python bindings - pin pybind11 to 2.6.2, 2.7 is not having correct include_dirs set (still a question why?) - using generated parser/printer
This commit is contained in:
2
.github/workflows/conformance.yml
vendored
2
.github/workflows/conformance.yml
vendored
@@ -24,6 +24,6 @@ jobs:
|
||||
- name: Build and test compiler
|
||||
uses: addnab/docker-run-action@v3
|
||||
with:
|
||||
image: qbozama/mlir:cc9283
|
||||
image: qbozama/mlir:latest
|
||||
options: -v ${{ github.workspace }}:/workspace
|
||||
run: cd /workspace/compiler && mkdir build && cmake -B build . -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ zamacompiler && make test
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -39,3 +39,5 @@
|
||||
# Jetbrains tools
|
||||
.idea/
|
||||
|
||||
# Python cache
|
||||
__pycache__/
|
||||
|
||||
@@ -2,15 +2,18 @@ FROM ubuntu:latest
|
||||
|
||||
RUN apt-get update --fix-missing
|
||||
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ build-essential python3 python3-pip python3-setuptools ninja-build git
|
||||
RUN pip install numpy pybind11==2.6.2 PyYAML
|
||||
RUN git clone --depth 1 https://github.com/llvm/llvm-project.git
|
||||
ENV LLVM_PROJECT=$PWD/llvm-project
|
||||
RUN cd ${LLVM_PROJECT} && git log -1
|
||||
RUN mkdir ${LLVM_PROJECT}/build
|
||||
RUN cd ${LLVM_PROJECT}/build && cmake -GNinja ../llvm \
|
||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||
-DLLVM_BUILD_EXAMPLES=OFF \
|
||||
-DLLVM_TARGETS_TO_BUILD="host" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
|
||||
RUN cd ${LLVM_PROJECT}/build && cmake --build . --target check-mlir
|
||||
ENV PATH=${LLVM_PROJECT}/build/bin:${PATH}
|
||||
|
||||
@@ -19,4 +22,5 @@ COPY --from=0 /llvm-project/ /llvm-project/
|
||||
ENV LLVM_PROJECT=/llvm-project
|
||||
ENV PATH=${LLVM_PROJECT}/build/bin:${PATH}
|
||||
RUN apt-get update
|
||||
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y cmake g++ build-essential python3 zlib1g-dev
|
||||
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y cmake g++ build-essential python3 zlib1g-dev python3-pip python3-setuptools
|
||||
RUN pip install numpy pybind11==2.6.2 PyYAML
|
||||
@@ -27,6 +27,38 @@ include_directories(${PROJECT_BINARY_DIR}/include)
|
||||
link_directories(${LLVM_BUILD_LIBRARY_DIR})
|
||||
add_definitions(${LLVM_DEFINITIONS})
|
||||
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Python Configuration
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
option(ZAMALANG_BINDINGS_PYTHON_ENABLED "Enables ZamaLang Python bindings." ON)
|
||||
|
||||
if(ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
message(STATUS "ZamaLang Python bindings are enabled.")
|
||||
|
||||
include(MLIRDetectPythonEnv)
|
||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||
message(STATUS "Found Python include dirs: ${Python3_INCLUDE_DIRS}")
|
||||
message(STATUS "Found Python libraries: ${Python3_LIBRARIES}")
|
||||
message(STATUS "Found Python executable: ${Python3_EXECUTABLE}")
|
||||
|
||||
mlir_detect_pybind11_install()
|
||||
find_package(pybind11 2.6 CONFIG REQUIRED)
|
||||
message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
|
||||
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
|
||||
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
||||
"extension = '${PYTHON_MODULE_EXTENSION}'")
|
||||
else()
|
||||
message(STATUS "ZamaLang Python bindings are disabled.")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(src)
|
||||
|
||||
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
add_subdirectory(python)
|
||||
endif()
|
||||
|
||||
24
compiler/include/zamalang-c/Dialect/HLFHE.h
Normal file
24
compiler/include/zamalang-c/Dialect/HLFHE.h
Normal file
@@ -0,0 +1,24 @@
|
||||
#ifndef ZAMALANG_C_DIALECT_HLFHE_H
|
||||
#define ZAMALANG_C_DIALECT_HLFHE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe);
|
||||
|
||||
/// Creates an encrypted integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGet(MlirContext context,
|
||||
unsigned width);
|
||||
|
||||
/// If the type is an EncryptedInteger
|
||||
MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // ZAMALANG_C_DIALECT_HLFHE_H
|
||||
1
compiler/lib/CAPI/CMakeLists.txt
Normal file
1
compiler/lib/CAPI/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(Dialect)
|
||||
13
compiler/lib/CAPI/Dialect/CMakeLists.txt
Normal file
13
compiler/lib/CAPI/Dialect/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
set(LLVM_OPTIONAL_SOURCES HLFHE.cpp)
|
||||
|
||||
add_mlir_library(ZAMALANGCAPIHLFHE
|
||||
|
||||
HLFHE.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
HLFHEDialect
|
||||
)
|
||||
27
compiler/lib/CAPI/Dialect/HLFHE.cpp
Normal file
27
compiler/lib/CAPI/Dialect/HLFHE.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "zamalang-c/Dialect/HLFHE.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
|
||||
using namespace mlir::zamalang::HLFHE;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe, HLFHEDialect)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) {
|
||||
return unwrap(type).isa<EncryptedIntegerType>();
|
||||
}
|
||||
|
||||
MlirType hlfheEncryptedIntegerTypeGet(MlirContext ctx, unsigned width) {
|
||||
return wrap(EncryptedIntegerType::get(unwrap(ctx), width));
|
||||
}
|
||||
@@ -1 +1,5 @@
|
||||
add_subdirectory(Dialect)
|
||||
# CAPI needed only for python bindings
|
||||
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
|
||||
add_subdirectory(CAPI)
|
||||
endif()
|
||||
|
||||
@@ -22,20 +22,24 @@ void HLFHEDialect::initialize() {
|
||||
}
|
||||
|
||||
::mlir::Type HLFHEDialect::parseType(::mlir::DialectAsmParser &parser) const {
|
||||
if (parser.parseKeyword("eint").failed())
|
||||
return ::mlir::Type();
|
||||
mlir::Type type;
|
||||
|
||||
return EncryptedIntegerType::parse(this->getContext(), parser);
|
||||
if (parser.parseOptionalKeyword("eint").succeeded()) {
|
||||
generatedTypeParser(this->getContext(), parser, "eint", type);
|
||||
return type;
|
||||
}
|
||||
|
||||
// TODO
|
||||
// Don't have a parser for a custom type
|
||||
// We shouldn't call the default parser
|
||||
// but what should we do instead?
|
||||
parser.parseType(type);
|
||||
return type;
|
||||
}
|
||||
|
||||
void HLFHEDialect::printType(::mlir::Type type,
|
||||
::mlir::DialectAsmPrinter &printer) const {
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType eint =
|
||||
type.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
if (eint != nullptr) {
|
||||
eint.print(printer);
|
||||
return;
|
||||
}
|
||||
// TODO - What should be done here?
|
||||
printer << "unknwontype";
|
||||
if (generatedTypePrinter(type, printer).failed())
|
||||
// Calling default printer if failed to print HLFHE type
|
||||
printer.printType(type);
|
||||
}
|
||||
|
||||
69
compiler/python/CMakeLists.txt
Normal file
69
compiler/python/CMakeLists.txt
Normal file
@@ -0,0 +1,69 @@
|
||||
include(AddMLIRPython)
|
||||
add_custom_target(ZamalangBindingsPython)
|
||||
|
||||
################################################################################
|
||||
# Build native Python extension
|
||||
################################################################################
|
||||
|
||||
add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
|
||||
INSTALL_DIR
|
||||
python
|
||||
SOURCES
|
||||
ZamalangModule.cpp
|
||||
HLFHEModule.cpp
|
||||
LINK_LIBS
|
||||
ZAMALANGCAPIHLFHE
|
||||
)
|
||||
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)
|
||||
|
||||
################################################################################
|
||||
# Copy python source tree.
|
||||
################################################################################
|
||||
|
||||
file(GLOB_RECURSE PY_SRC_FILES
|
||||
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/zamalang/*.py")
|
||||
|
||||
add_custom_target(ZAMALANGBindingsPythonSources ALL
|
||||
DEPENDS
|
||||
${PY_SRC_FILES}
|
||||
)
|
||||
|
||||
add_dependencies(ZamalangBindingsPython ZAMALANGBindingsPythonSources)
|
||||
|
||||
foreach(PY_SRC_FILE ${PY_SRC_FILES})
|
||||
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
|
||||
get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY)
|
||||
file(MAKE_DIRECTORY "${PY_DEST_DIR}")
|
||||
add_custom_command(
|
||||
TARGET ZAMALANGBindingsPythonSources PRE_BUILD
|
||||
COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}"
|
||||
DEPENDS "${PY_SRC_FILE}"
|
||||
BYPRODUCTS "${PY_DEST_FILE}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E create_symlink
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}"
|
||||
)
|
||||
endforeach()
|
||||
|
||||
# Note that we copy from the source tree just like for headers because
|
||||
# it will not be polluted with py_cache runtime artifacts (from testing and
|
||||
# such).
|
||||
install(
|
||||
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/zamalang
|
||||
DESTINATION python
|
||||
COMPONENT ZAMALANGBindingsPythonSources
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
if (NOT LLVM_ENABLE_IDE)
|
||||
add_llvm_install_targets(
|
||||
install-ZAMALANGBindingsPythonSources
|
||||
DEPENDS ZAMALANGBindingsPythonSources
|
||||
COMPONENT ZAMALANGBindingsPythonSources)
|
||||
endif()
|
||||
|
||||
################################################################################
|
||||
# Generated sources.
|
||||
################################################################################
|
||||
|
||||
add_subdirectory(zamalang/dialects)
|
||||
14
compiler/python/DialectModules.h
Normal file
14
compiler/python/DialectModules.h
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef ZAMALANG_PYTHON_DIALECTMODULES_H
|
||||
#define ZAMALANG_PYTHON_DIALECTMODULES_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace zamalang {
|
||||
namespace python {
|
||||
|
||||
void populateDialectHLFHESubmodule(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace zamalang
|
||||
|
||||
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H
|
||||
27
compiler/python/HLFHEModule.cpp
Normal file
27
compiler/python/HLFHEModule.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "DialectModules.h"
|
||||
|
||||
#include "zamalang-c/Dialect/HLFHE.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
using namespace zamalang;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
/// Populate the hlfhe python module.
|
||||
void zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) {
|
||||
m.doc() = "HLFHE dialect Python native extension";
|
||||
|
||||
mlir_type_subclass(m, "EncryptedIntegerType",
|
||||
hlfheTypeIsAnEncryptedIntegerType)
|
||||
.def_classmethod(
|
||||
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
|
||||
return cls(hlfheEncryptedIntegerTypeGet(ctx, width));
|
||||
});
|
||||
}
|
||||
35
compiler/python/ZamalangModule.cpp
Normal file
35
compiler/python/ZamalangModule.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
#include "DialectModules.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "zamalang-c/Dialect/HLFHE.h"
|
||||
|
||||
#include "llvm-c/ErrorHandling.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_zamalang, m) {
|
||||
m.doc() = "Zamalang Python Native Extension";
|
||||
llvm::sys::PrintStackTraceOnErrorSignal(/*argv=*/"");
|
||||
LLVMEnablePrettyStackTrace();
|
||||
|
||||
m.def(
|
||||
"register_dialects",
|
||||
[](py::object capsule) {
|
||||
// Get the MlirContext capsule from PyMlirContext capsule.
|
||||
auto wrappedCapsule = capsule.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
|
||||
MlirContext context = mlirPythonCapsuleToContext(wrappedCapsule.ptr());
|
||||
|
||||
// Collect Zamalang dialects to register.
|
||||
MlirDialectHandle hlfhe = mlirGetDialectHandle__hlfhe__();
|
||||
mlirDialectHandleRegisterDialect(hlfhe, context);
|
||||
mlirDialectHandleLoadDialect(hlfhe, context);
|
||||
},
|
||||
"Register Zamalang dialects on a PyMlirContext.");
|
||||
|
||||
py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API");
|
||||
zamalang::python::populateDialectHLFHESubmodule(hlfhe);
|
||||
}
|
||||
1
compiler/python/zamalang/__init__.py
Normal file
1
compiler/python/zamalang/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from _zamalang import *
|
||||
31
compiler/python/zamalang/dialects/CMakeLists.txt
Normal file
31
compiler/python/zamalang/dialects/CMakeLists.txt
Normal file
@@ -0,0 +1,31 @@
|
||||
include(AddMLIRPython)
|
||||
|
||||
################################################################################
|
||||
# Generate dialect-specific bindings.
|
||||
################################################################################
|
||||
|
||||
add_mlir_dialect_python_bindings(ZAMALANGBindingsPythonHLFHEOps
|
||||
TD_FILE HLFHEOps.td
|
||||
DIALECT_NAME HLFHE)
|
||||
add_dependencies(ZAMALANGBindingsPythonSources ZAMALANGBindingsPythonHLFHEOps)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Installation.
|
||||
################################################################################
|
||||
|
||||
install(
|
||||
DIRECTORY ${PROJECT_BINARY_DIR}/python/zamalang/dialects
|
||||
DESTINATION python/zamalang
|
||||
COMPONENT ZAMALANGBindingsPythonDialects
|
||||
FILES_MATCHING PATTERN "_*_gen.py"
|
||||
PATTERN "__pycache__" EXCLUDE
|
||||
PATTERN "__init__.py" EXCLUDE
|
||||
)
|
||||
|
||||
if (NOT LLVM_ENABLE_IDE)
|
||||
add_llvm_install_targets(
|
||||
install-ZAMALANGBindingsPythonDialects
|
||||
DEPENDS ZAMALANGBindingsPythonSources
|
||||
COMPONENT ZAMALANGBindingsPythonDialects)
|
||||
endif()
|
||||
7
compiler/python/zamalang/dialects/HLFHEOps.td
Normal file
7
compiler/python/zamalang/dialects/HLFHEOps.td
Normal file
@@ -0,0 +1,7 @@
|
||||
#ifndef PYTHON_BINDINGS_HLFHE_OPS
|
||||
#define PYTHON_BINDINGS_HLFHE_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "zamalang/Dialect/HLFHE/IR/HLFHEOps.td"
|
||||
|
||||
#endif
|
||||
1
compiler/python/zamalang/dialects/_ods_common.py
Normal file
1
compiler/python/zamalang/dialects/_ods_common.py
Normal file
@@ -0,0 +1 @@
|
||||
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context
|
||||
2
compiler/python/zamalang/dialects/hlfhe.py
Normal file
2
compiler/python/zamalang/dialects/hlfhe.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from ._HLFHE_ops_gen import *
|
||||
from _zamalang._hlfhe import *
|
||||
25
compiler/test_python.py
Executable file
25
compiler/test_python.py
Executable file
@@ -0,0 +1,25 @@
|
||||
import zamalang
|
||||
import zamalang.dialects.hlfhe as hlfhe
|
||||
import mlir.dialects.builtin as builtin
|
||||
import mlir.dialects.std as std
|
||||
from mlir.ir import *
|
||||
|
||||
|
||||
def main():
|
||||
with Context() as ctx, Location.unknown():
|
||||
# register zamalang's dialects
|
||||
zamalang.register_dialects(ctx)
|
||||
|
||||
module = Module.create()
|
||||
eint16 = hlfhe.EncryptedIntegerType.get(ctx, 16)
|
||||
with InsertionPoint(module.body):
|
||||
func_types = [RankedTensorType.get((10, 10), eint16) for _ in range(2)]
|
||||
@builtin.FuncOp.from_py_func(*func_types)
|
||||
def fhe_circuit(*arg):
|
||||
return arg[0]
|
||||
|
||||
print(module)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user