diff --git a/.github/workflows/conformance.yml b/.github/workflows/conformance.yml index 48e782bcf..d2b38592e 100644 --- a/.github/workflows/conformance.yml +++ b/.github/workflows/conformance.yml @@ -24,6 +24,6 @@ jobs: - name: Build and test compiler uses: addnab/docker-run-action@v3 with: - image: qbozama/mlir:cc9283 + image: qbozama/mlir:latest options: -v ${{ github.workspace }}:/workspace run: cd /workspace/compiler && mkdir build && cmake -B build . -DLLVM_DIR=$LLVM_PROJECT/build/lib/cmake/llvm -DMLIR_DIR=$LLVM_PROJECT/build/lib/cmake/mlir && make -C build/ zamacompiler && make test diff --git a/.gitignore b/.gitignore index 9e080aa6d..5df00f4a2 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,5 @@ # Jetbrains tools .idea/ +# Python cache +__pycache__/ diff --git a/builders/Dockerfile.mlir-env b/builders/Dockerfile.mlir-env index 3cd930f6d..2ee62c896 100644 --- a/builders/Dockerfile.mlir-env +++ b/builders/Dockerfile.mlir-env @@ -2,15 +2,18 @@ FROM ubuntu:latest RUN apt-get update --fix-missing RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y curl cmake g++ build-essential python3 python3-pip python3-setuptools ninja-build git +RUN pip install numpy pybind11==2.6.2 PyYAML RUN git clone --depth 1 https://github.com/llvm/llvm-project.git ENV LLVM_PROJECT=$PWD/llvm-project +RUN cd ${LLVM_PROJECT} && git log -1 RUN mkdir ${LLVM_PROJECT}/build RUN cd ${LLVM_PROJECT}/build && cmake -GNinja ../llvm \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_BUILD_EXAMPLES=OFF \ -DLLVM_TARGETS_TO_BUILD="host" \ -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON RUN cd ${LLVM_PROJECT}/build && cmake --build . --target check-mlir ENV PATH=${LLVM_PROJECT}/build/bin:${PATH} @@ -19,4 +22,5 @@ COPY --from=0 /llvm-project/ /llvm-project/ ENV LLVM_PROJECT=/llvm-project ENV PATH=${LLVM_PROJECT}/build/bin:${PATH} RUN apt-get update -RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y cmake g++ build-essential python3 zlib1g-dev \ No newline at end of file +RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y cmake g++ build-essential python3 zlib1g-dev python3-pip python3-setuptools +RUN pip install numpy pybind11==2.6.2 PyYAML \ No newline at end of file diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index 08d668e31..e818b19b0 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -27,6 +27,38 @@ include_directories(${PROJECT_BINARY_DIR}/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) + +#------------------------------------------------------------------------------- +# Python Configuration +#------------------------------------------------------------------------------- + +option(ZAMALANG_BINDINGS_PYTHON_ENABLED "Enables ZamaLang Python bindings." ON) + +if(ZAMALANG_BINDINGS_PYTHON_ENABLED) + message(STATUS "ZamaLang Python bindings are enabled.") + + include(MLIRDetectPythonEnv) + find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + message(STATUS "Found Python include dirs: ${Python3_INCLUDE_DIRS}") + message(STATUS "Found Python libraries: ${Python3_LIBRARIES}") + message(STATUS "Found Python executable: ${Python3_EXECUTABLE}") + + mlir_detect_pybind11_install() + find_package(pybind11 2.6 CONFIG REQUIRED) + message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIR}") + message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', " + "suffix = '${PYTHON_MODULE_SUFFIX}', " + "extension = '${PYTHON_MODULE_EXTENSION}'") +else() + message(STATUS "ZamaLang Python bindings are disabled.") +endif() + + + add_subdirectory(include) add_subdirectory(lib) add_subdirectory(src) + +if (ZAMALANG_BINDINGS_PYTHON_ENABLED) + add_subdirectory(python) +endif() diff --git a/compiler/include/zamalang-c/Dialect/HLFHE.h b/compiler/include/zamalang-c/Dialect/HLFHE.h new file mode 100644 index 000000000..512a6f763 --- /dev/null +++ b/compiler/include/zamalang-c/Dialect/HLFHE.h @@ -0,0 +1,24 @@ +#ifndef ZAMALANG_C_DIALECT_HLFHE_H +#define ZAMALANG_C_DIALECT_HLFHE_H + +#include "mlir-c/IR.h" +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe); + +/// Creates an encrypted integer type of `width` bits +MLIR_CAPI_EXPORTED MlirType hlfheEncryptedIntegerTypeGet(MlirContext context, + unsigned width); + +/// If the type is an EncryptedInteger +MLIR_CAPI_EXPORTED bool hlfheTypeIsAnEncryptedIntegerType(MlirType); + +#ifdef __cplusplus +} +#endif + +#endif // ZAMALANG_C_DIALECT_HLFHE_H diff --git a/compiler/lib/CAPI/CMakeLists.txt b/compiler/lib/CAPI/CMakeLists.txt new file mode 100644 index 000000000..e6f347c8c --- /dev/null +++ b/compiler/lib/CAPI/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) \ No newline at end of file diff --git a/compiler/lib/CAPI/Dialect/CMakeLists.txt b/compiler/lib/CAPI/Dialect/CMakeLists.txt new file mode 100644 index 000000000..4a69936d3 --- /dev/null +++ b/compiler/lib/CAPI/Dialect/CMakeLists.txt @@ -0,0 +1,13 @@ +set(LLVM_OPTIONAL_SOURCES HLFHE.cpp) + +add_mlir_library(ZAMALANGCAPIHLFHE + + HLFHE.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir-c + + LINK_LIBS PUBLIC + MLIRCAPIIR + HLFHEDialect + ) diff --git a/compiler/lib/CAPI/Dialect/HLFHE.cpp b/compiler/lib/CAPI/Dialect/HLFHE.cpp new file mode 100644 index 000000000..3557646ee --- /dev/null +++ b/compiler/lib/CAPI/Dialect/HLFHE.cpp @@ -0,0 +1,27 @@ +#include "zamalang-c/Dialect/HLFHE.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Support.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" + +using namespace mlir::zamalang::HLFHE; + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(HLFHE, hlfhe, HLFHEDialect) + +//===----------------------------------------------------------------------===// +// Type API. +//===----------------------------------------------------------------------===// + +bool hlfheTypeIsAnEncryptedIntegerType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType hlfheEncryptedIntegerTypeGet(MlirContext ctx, unsigned width) { + return wrap(EncryptedIntegerType::get(unwrap(ctx), width)); +} diff --git a/compiler/lib/CMakeLists.txt b/compiler/lib/CMakeLists.txt index 0ca0f41c5..0a6112b37 100644 --- a/compiler/lib/CMakeLists.txt +++ b/compiler/lib/CMakeLists.txt @@ -1 +1,5 @@ add_subdirectory(Dialect) +# CAPI needed only for python bindings +if (ZAMALANG_BINDINGS_PYTHON_ENABLED) + add_subdirectory(CAPI) +endif() diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp index d0bf22b05..5453a5302 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEDialect.cpp @@ -22,20 +22,24 @@ void HLFHEDialect::initialize() { } ::mlir::Type HLFHEDialect::parseType(::mlir::DialectAsmParser &parser) const { - if (parser.parseKeyword("eint").failed()) - return ::mlir::Type(); + mlir::Type type; - return EncryptedIntegerType::parse(this->getContext(), parser); + if (parser.parseOptionalKeyword("eint").succeeded()) { + generatedTypeParser(this->getContext(), parser, "eint", type); + return type; + } + + // TODO + // Don't have a parser for a custom type + // We shouldn't call the default parser + // but what should we do instead? + parser.parseType(type); + return type; } void HLFHEDialect::printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const { - mlir::zamalang::HLFHE::EncryptedIntegerType eint = - type.dyn_cast_or_null(); - if (eint != nullptr) { - eint.print(printer); - return; - } - // TODO - What should be done here? - printer << "unknwontype"; + if (generatedTypePrinter(type, printer).failed()) + // Calling default printer if failed to print HLFHE type + printer.printType(type); } diff --git a/compiler/python/CMakeLists.txt b/compiler/python/CMakeLists.txt new file mode 100644 index 000000000..eaf260c5e --- /dev/null +++ b/compiler/python/CMakeLists.txt @@ -0,0 +1,69 @@ +include(AddMLIRPython) +add_custom_target(ZamalangBindingsPython) + +################################################################################ +# Build native Python extension +################################################################################ + +add_mlir_python_extension(ZamalangBindingsPythonExtension _zamalang + INSTALL_DIR + python + SOURCES + ZamalangModule.cpp + HLFHEModule.cpp + LINK_LIBS + ZAMALANGCAPIHLFHE +) +add_dependencies(ZamalangBindingsPython ZamalangBindingsPythonExtension) + +################################################################################ +# Copy python source tree. +################################################################################ + +file(GLOB_RECURSE PY_SRC_FILES + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/zamalang/*.py") + +add_custom_target(ZAMALANGBindingsPythonSources ALL + DEPENDS + ${PY_SRC_FILES} +) + +add_dependencies(ZamalangBindingsPython ZAMALANGBindingsPythonSources) + +foreach(PY_SRC_FILE ${PY_SRC_FILES}) + set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}") + get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${PY_DEST_DIR}") + add_custom_command( + TARGET ZAMALANGBindingsPythonSources PRE_BUILD + COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}" + DEPENDS "${PY_SRC_FILE}" + BYPRODUCTS "${PY_DEST_FILE}" + COMMAND "${CMAKE_COMMAND}" -E create_symlink + "${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}" + ) +endforeach() + +# Note that we copy from the source tree just like for headers because +# it will not be polluted with py_cache runtime artifacts (from testing and +# such). +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/zamalang + DESTINATION python + COMPONENT ZAMALANGBindingsPythonSources + FILES_MATCHING PATTERN "*.py" +) + +if (NOT LLVM_ENABLE_IDE) + add_llvm_install_targets( + install-ZAMALANGBindingsPythonSources + DEPENDS ZAMALANGBindingsPythonSources + COMPONENT ZAMALANGBindingsPythonSources) +endif() + +################################################################################ +# Generated sources. +################################################################################ + +add_subdirectory(zamalang/dialects) diff --git a/compiler/python/DialectModules.h b/compiler/python/DialectModules.h new file mode 100644 index 000000000..70414683f --- /dev/null +++ b/compiler/python/DialectModules.h @@ -0,0 +1,14 @@ +#ifndef ZAMALANG_PYTHON_DIALECTMODULES_H +#define ZAMALANG_PYTHON_DIALECTMODULES_H + +#include + +namespace zamalang { +namespace python { + +void populateDialectHLFHESubmodule(pybind11::module &m); + +} // namespace python +} // namespace zamalang + +#endif // ZAMALANG_PYTHON_DIALECTMODULES_H \ No newline at end of file diff --git a/compiler/python/HLFHEModule.cpp b/compiler/python/HLFHEModule.cpp new file mode 100644 index 000000000..2a1e0fa52 --- /dev/null +++ b/compiler/python/HLFHEModule.cpp @@ -0,0 +1,27 @@ +#include "DialectModules.h" + +#include "zamalang-c/Dialect/HLFHE.h" + +#include "mlir-c/BuiltinAttributes.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +using namespace zamalang; +using namespace mlir::python::adaptors; + +/// Populate the hlfhe python module. +void zamalang::python::populateDialectHLFHESubmodule(pybind11::module &m) { + m.doc() = "HLFHE dialect Python native extension"; + + mlir_type_subclass(m, "EncryptedIntegerType", + hlfheTypeIsAnEncryptedIntegerType) + .def_classmethod( + "get", [](pybind11::object cls, MlirContext ctx, unsigned width) { + return cls(hlfheEncryptedIntegerTypeGet(ctx, width)); + }); +} \ No newline at end of file diff --git a/compiler/python/ZamalangModule.cpp b/compiler/python/ZamalangModule.cpp new file mode 100644 index 000000000..e2755ef35 --- /dev/null +++ b/compiler/python/ZamalangModule.cpp @@ -0,0 +1,35 @@ +#include "DialectModules.h" + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Registration.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "zamalang-c/Dialect/HLFHE.h" + +#include "llvm-c/ErrorHandling.h" +#include "llvm/Support/Signals.h" + +#include +namespace py = pybind11; + +PYBIND11_MODULE(_zamalang, m) { + m.doc() = "Zamalang Python Native Extension"; + llvm::sys::PrintStackTraceOnErrorSignal(/*argv=*/""); + LLVMEnablePrettyStackTrace(); + + m.def( + "register_dialects", + [](py::object capsule) { + // Get the MlirContext capsule from PyMlirContext capsule. + auto wrappedCapsule = capsule.attr(MLIR_PYTHON_CAPI_PTR_ATTR); + MlirContext context = mlirPythonCapsuleToContext(wrappedCapsule.ptr()); + + // Collect Zamalang dialects to register. + MlirDialectHandle hlfhe = mlirGetDialectHandle__hlfhe__(); + mlirDialectHandleRegisterDialect(hlfhe, context); + mlirDialectHandleLoadDialect(hlfhe, context); + }, + "Register Zamalang dialects on a PyMlirContext."); + + py::module hlfhe = m.def_submodule("_hlfhe", "HLFHE API"); + zamalang::python::populateDialectHLFHESubmodule(hlfhe); +} \ No newline at end of file diff --git a/compiler/python/zamalang/__init__.py b/compiler/python/zamalang/__init__.py new file mode 100644 index 000000000..4e4982057 --- /dev/null +++ b/compiler/python/zamalang/__init__.py @@ -0,0 +1 @@ +from _zamalang import * diff --git a/compiler/python/zamalang/dialects/CMakeLists.txt b/compiler/python/zamalang/dialects/CMakeLists.txt new file mode 100644 index 000000000..e700178bd --- /dev/null +++ b/compiler/python/zamalang/dialects/CMakeLists.txt @@ -0,0 +1,31 @@ +include(AddMLIRPython) + +################################################################################ +# Generate dialect-specific bindings. +################################################################################ + +add_mlir_dialect_python_bindings(ZAMALANGBindingsPythonHLFHEOps + TD_FILE HLFHEOps.td + DIALECT_NAME HLFHE) +add_dependencies(ZAMALANGBindingsPythonSources ZAMALANGBindingsPythonHLFHEOps) + + +################################################################################ +# Installation. +################################################################################ + +install( + DIRECTORY ${PROJECT_BINARY_DIR}/python/zamalang/dialects + DESTINATION python/zamalang + COMPONENT ZAMALANGBindingsPythonDialects + FILES_MATCHING PATTERN "_*_gen.py" + PATTERN "__pycache__" EXCLUDE + PATTERN "__init__.py" EXCLUDE +) + +if (NOT LLVM_ENABLE_IDE) + add_llvm_install_targets( + install-ZAMALANGBindingsPythonDialects + DEPENDS ZAMALANGBindingsPythonSources + COMPONENT ZAMALANGBindingsPythonDialects) +endif() diff --git a/compiler/python/zamalang/dialects/HLFHEOps.td b/compiler/python/zamalang/dialects/HLFHEOps.td new file mode 100644 index 000000000..c4ec57dfe --- /dev/null +++ b/compiler/python/zamalang/dialects/HLFHEOps.td @@ -0,0 +1,7 @@ +#ifndef PYTHON_BINDINGS_HLFHE_OPS +#define PYTHON_BINDINGS_HLFHE_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "zamalang/Dialect/HLFHE/IR/HLFHEOps.td" + +#endif diff --git a/compiler/python/zamalang/dialects/_ods_common.py b/compiler/python/zamalang/dialects/_ods_common.py new file mode 100644 index 000000000..58451fd03 --- /dev/null +++ b/compiler/python/zamalang/dialects/_ods_common.py @@ -0,0 +1 @@ +from mlir.dialects._ods_common import _cext, segmented_accessor, equally_sized_accessor, extend_opview_class, get_default_loc_context diff --git a/compiler/python/zamalang/dialects/hlfhe.py b/compiler/python/zamalang/dialects/hlfhe.py new file mode 100644 index 000000000..c52f2c3ad --- /dev/null +++ b/compiler/python/zamalang/dialects/hlfhe.py @@ -0,0 +1,2 @@ +from ._HLFHE_ops_gen import * +from _zamalang._hlfhe import * diff --git a/compiler/test_python.py b/compiler/test_python.py new file mode 100755 index 000000000..94a576c76 --- /dev/null +++ b/compiler/test_python.py @@ -0,0 +1,25 @@ +import zamalang +import zamalang.dialects.hlfhe as hlfhe +import mlir.dialects.builtin as builtin +import mlir.dialects.std as std +from mlir.ir import * + + +def main(): + with Context() as ctx, Location.unknown(): + # register zamalang's dialects + zamalang.register_dialects(ctx) + + module = Module.create() + eint16 = hlfhe.EncryptedIntegerType.get(ctx, 16) + with InsertionPoint(module.body): + func_types = [RankedTensorType.get((10, 10), eint16) for _ in range(2)] + @builtin.FuncOp.from_py_func(*func_types) + def fhe_circuit(*arg): + return arg[0] + + print(module) + + +if __name__ == "__main__": + main() \ No newline at end of file