feat(compiler): Python bindings (#53)

- feat(compiler): python bindings
- build: update docker image for python bindings
- pin pybind11 to 2.6.2, 2.7 is not having correct include_dirs set (still
a question why?)
- using generated parser/printer
This commit is contained in:
Ayoub Benaissa
2021-07-28 15:58:51 +01:00
committed by GitHub
parent 812268000c
commit ab53ef71c6
20 changed files with 337 additions and 14 deletions

View File

@@ -24,6 +24,6 @@ jobs:
- name: Build and test compiler
uses: addnab/docker-run-action@v3
with:
image: qbozama/mlir:cc9283
image: qbozama/mlir:latest
options: -v ${{ github.workspace }}:/workspace
run: cd /workspace/compiler && mkdir build && cmake -B build . -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ zamacompiler && make test

2
.gitignore vendored
View File

@@ -39,3 +39,5 @@
# Jetbrains tools
.idea/
# Python cache
__pycache__/

View File

@@ -2,15 +2,18 @@ FROM ubuntu:latest
RUN apt-get update --fix-missing
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ build-essential python3 python3-pip python3-setuptools ninja-build git
RUN pip install numpy pybind11==2.6.2 PyYAML
RUN git clone --depth 1 https://github.com/llvm/llvm-project.git
ENV LLVM_PROJECT=$PWD/llvm-project
RUN cd ${LLVM_PROJECT} && git log -1
RUN mkdir ${LLVM_PROJECT}/build
RUN cd ${LLVM_PROJECT}/build && cmake -GNinja ../llvm \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="host" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
RUN cd ${LLVM_PROJECT}/build && cmake --build . --target check-mlir
ENV PATH=${LLVM_PROJECT}/build/bin:${PATH}
@@ -19,4 +22,5 @@ COPY --from=0 /llvm-project/ /llvm-project/
ENV LLVM_PROJECT=/llvm-project
ENV PATH=${LLVM_PROJECT}/build/bin:${PATH}
RUN apt-get update
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y cmake g++ build-essential python3 zlib1g-dev
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y cmake g++ build-essential python3 zlib1g-dev python3-pip python3-setuptools
RUN pip install numpy pybind11==2.6.2 PyYAML

View File

@@ -27,6 +27,38 @@ include_directories(${PROJECT_BINARY_DIR}/include)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
#-------------------------------------------------------------------------------
# Python Configuration
#-------------------------------------------------------------------------------
option(ZAMALANG_BINDINGS_PYTHON_ENABLED "Enables ZamaLang Python bindings." ON)
if(ZAMALANG_BINDINGS_PYTHON_ENABLED)
message(STATUS "ZamaLang Python bindings are enabled.")
include(MLIRDetectPythonEnv)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
message(STATUS "Found Python include dirs: ${Python3_INCLUDE_DIRS}")
message(STATUS "Found Python libraries: ${Python3_LIBRARIES}")
message(STATUS "Found Python executable: ${Python3_EXECUTABLE}")
mlir_detect_pybind11_install()
find_package(pybind11 2.6 CONFIG REQUIRED)
message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}")
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "
"extension = '${PYTHON_MODULE_EXTENSION}'")
else()
message(STATUS "ZamaLang Python bindings are disabled.")
endif()
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(src)
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
add_subdirectory(python)
endif()

View File

@@ -0,0 +1,24 @@
#ifndef ZAMALANG_C_DIALECT_HLFHE_H
#define ZAMALANG_C_DIALECT_HLFHE_H
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe);
/// Creates an encrypted integer type of `width` bits
MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGet(MlirContext context,
unsigned width);
/// If the type is an EncryptedInteger
MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType);
#ifdef __cplusplus
}
#endif
#endif // ZAMALANG_C_DIALECT_HLFHE_H

View File

@@ -0,0 +1 @@
add_subdirectory(Dialect)

View File

@@ -0,0 +1,13 @@
set(LLVM_OPTIONAL_SOURCES HLFHE.cpp)
add_mlir_library(ZAMALANGCAPIHLFHE
HLFHE.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir-c
LINK_LIBS PUBLIC
MLIRCAPIIR
HLFHEDialect
)

View File

@@ -0,0 +1,27 @@
#include "zamalang-c/Dialect/HLFHE.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
using namespace mlir::zamalang::HLFHE;
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe, HLFHEDialect)
//===----------------------------------------------------------------------===//
// Type API.
//===----------------------------------------------------------------------===//
bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) {
return unwrap(type).isa<EncryptedIntegerType>();
}
MlirType hlfheEncryptedIntegerTypeGet(MlirContext ctx, unsigned width) {
return wrap(EncryptedIntegerType::get(unwrap(ctx), width));
}

View File

@@ -1 +1,5 @@
add_subdirectory(Dialect)
# CAPI needed only for python bindings
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)
add_subdirectory(CAPI)
endif()

View File

@@ -22,20 +22,24 @@ void HLFHEDialect::initialize() {
}
::mlir::Type HLFHEDialect::parseType(::mlir::DialectAsmParser &parser) const {
if (parser.parseKeyword("eint").failed())
return ::mlir::Type();
mlir::Type type;
return EncryptedIntegerType::parse(this->getContext(), parser);
if (parser.parseOptionalKeyword("eint").succeeded()) {
generatedTypeParser(this->getContext(), parser, "eint", type);
return type;
}
// TODO
// Don't have a parser for a custom type
// We shouldn't call the default parser
// but what should we do instead?
parser.parseType(type);
return type;
}
void HLFHEDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &printer) const {
mlir::zamalang::HLFHE::EncryptedIntegerType eint =
type.dyn_cast_or_null<mlir::zamalang::HLFHE::EncryptedIntegerType>();
if (eint != nullptr) {
eint.print(printer);
return;
}
// TODO - What should be done here?
printer << "unknwontype";
if (generatedTypePrinter(type, printer).failed())
// Calling default printer if failed to print HLFHE type
printer.printType(type);
}

View File

@@ -0,0 +1,69 @@
include(AddMLIRPython)
add_custom_target(ZamalangBindingsPython)
################################################################################
# Build native Python extension
################################################################################
add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang
INSTALL_DIR
python
SOURCES
ZamalangModule.cpp
HLFHEModule.cpp
LINK_LIBS
ZAMALANGCAPIHLFHE
)
add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension)
################################################################################
# Copy python source tree.
################################################################################
file(GLOB_RECURSE PY_SRC_FILES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/zamalang/*.py")
add_custom_target(ZAMALANGBindingsPythonSources ALL
DEPENDS
${PY_SRC_FILES}
)
add_dependencies(ZamalangBindingsPython ZAMALANGBindingsPythonSources)
foreach(PY_SRC_FILE ${PY_SRC_FILES})
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY)
file(MAKE_DIRECTORY "${PY_DEST_DIR}")
add_custom_command(
TARGET ZAMALANGBindingsPythonSources PRE_BUILD
COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}"
DEPENDS "${PY_SRC_FILE}"
BYPRODUCTS "${PY_DEST_FILE}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink
"${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}"
)
endforeach()
# Note that we copy from the source tree just like for headers because
# it will not be polluted with py_cache runtime artifacts (from testing and
# such).
install(
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/zamalang
DESTINATION python
COMPONENT ZAMALANGBindingsPythonSources
FILES_MATCHING PATTERN "*.py"
)
if (NOT LLVM_ENABLE_IDE)
add_llvm_install_targets(
install-ZAMALANGBindingsPythonSources
DEPENDS ZAMALANGBindingsPythonSources
COMPONENT ZAMALANGBindingsPythonSources)
endif()
################################################################################
# Generated sources.
################################################################################
add_subdirectory(zamalang/dialects)

View File

@@ -0,0 +1,14 @@
#ifndef ZAMALANG_PYTHON_DIALECTMODULES_H
#define ZAMALANG_PYTHON_DIALECTMODULES_H
#include <pybind11/pybind11.h>
namespace zamalang {
namespace python {
void populateDialectHLFHESubmodule(pybind11::module &m);
} // namespace python
} // namespace zamalang
#endif // ZAMALANG_PYTHON_DIALECTMODULES_H

View File

@@ -0,0 +1,27 @@
#include "DialectModules.h"
#include "zamalang-c/Dialect/HLFHE.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
using namespace zamalang;
using namespace mlir::python::adaptors;
/// Populate the hlfhe python module.
void zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) {
m.doc() = "HLFHE dialect Python native extension";
mlir_type_subclass(m, "EncryptedIntegerType",
hlfheTypeIsAnEncryptedIntegerType)
.def_classmethod(
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
return cls(hlfheEncryptedIntegerTypeGet(ctx, width));
});
}

View File

@@ -0,0 +1,35 @@
#include "DialectModules.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "zamalang-c/Dialect/HLFHE.h"
#include "llvm-c/ErrorHandling.h"
#include "llvm/Support/Signals.h"
#include <pybind11/pybind11.h>
namespace py = pybind11;
PYBIND11_MODULE(_zamalang, m) {
m.doc() = "Zamalang Python Native Extension";
llvm::sys::PrintStackTraceOnErrorSignal(/*argv=*/"");
LLVMEnablePrettyStackTrace();
m.def(
"register_dialects",
[](py::object capsule) {
// Get the MlirContext capsule from PyMlirContext capsule.
auto wrappedCapsule = capsule.attr(MLIR_PYTHON_CAPI_PTR_ATTR);
MlirContext context = mlirPythonCapsuleToContext(wrappedCapsule.ptr());
// Collect Zamalang dialects to register.
MlirDialectHandle hlfhe = mlirGetDialectHandle__hlfhe__();
mlirDialectHandleRegisterDialect(hlfhe, context);
mlirDialectHandleLoadDialect(hlfhe, context);
},
"Register Zamalang dialects on a PyMlirContext.");
py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API");
zamalang::python::populateDialectHLFHESubmodule(hlfhe);
}

View File

@@ -0,0 +1 @@
from _zamalang import *

View File

@@ -0,0 +1,31 @@
include(AddMLIRPython)
################################################################################
# Generate dialect-specific bindings.
################################################################################
add_mlir_dialect_python_bindings(ZAMALANGBindingsPythonHLFHEOps
TD_FILE HLFHEOps.td
DIALECT_NAME HLFHE)
add_dependencies(ZAMALANGBindingsPythonSources ZAMALANGBindingsPythonHLFHEOps)
################################################################################
# Installation.
################################################################################
install(
DIRECTORY ${PROJECT_BINARY_DIR}/python/zamalang/dialects
DESTINATION python/zamalang
COMPONENT ZAMALANGBindingsPythonDialects
FILES_MATCHING PATTERN "_*_gen.py"
PATTERN "__pycache__" EXCLUDE
PATTERN "__init__.py" EXCLUDE
)
if (NOT LLVM_ENABLE_IDE)
add_llvm_install_targets(
install-ZAMALANGBindingsPythonDialects
DEPENDS ZAMALANGBindingsPythonSources
COMPONENT ZAMALANGBindingsPythonDialects)
endif()

View File

@@ -0,0 +1,7 @@
#ifndef PYTHON_BINDINGS_HLFHE_OPS
#define PYTHON_BINDINGS_HLFHE_OPS
include "mlir/Bindings/Python/Attributes.td"
include "zamalang/Dialect/HLFHE/IR/HLFHEOps.td"
#endif

View File

@@ -0,0 +1 @@
from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context

View File

@@ -0,0 +1,2 @@
from ._HLFHE_ops_gen import *
from _zamalang._hlfhe import *

25
compiler/test_python.py Executable file
View File

@@ -0,0 +1,25 @@
import zamalang
import zamalang.dialects.hlfhe as hlfhe
import mlir.dialects.builtin as builtin
import mlir.dialects.std as std
from mlir.ir import *
def main():
with Context() as ctx, Location.unknown():
# register zamalang's dialects
zamalang.register_dialects(ctx)
module = Module.create()
eint16 = hlfhe.EncryptedIntegerType.get(ctx, 16)
with InsertionPoint(module.body):
func_types = [RankedTensorType.get((10, 10), eint16) for _ in range(2)]
@builtin.FuncOp.from_py_func(*func_types)
def fhe_circuit(*arg):
return arg[0]
print(module)
if __name__ == "__main__":
main()