From 7b29600721c7f1c807acfc1822248c399aaa9e81 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 21 Oct 2021 08:52:48 +0100 Subject: [PATCH 01/19] refactor: don't use designated initializers --- compiler/lib/Support/ClientParameters.cpp | 78 +++++++++++------------ compiler/lib/Support/Pipeline.cpp | 3 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/compiler/lib/Support/ClientParameters.cpp b/compiler/lib/Support/ClientParameters.cpp index 368050a0a..b94e105a1 100644 --- a/compiler/lib/Support/ClientParameters.cpp +++ b/compiler/lib/Support/ClientParameters.cpp @@ -28,24 +28,24 @@ llvm::Expected gateFromMLIRType(std::string secretKeyID, width = type.getIntOrFloatBitWidth(); } return CircuitGate{ - .encryption = llvm::None, - .shape = - { - .width = width, - .size = 0, - }, + /*.encryption = */ llvm::None, + /*.shape = */ + { + /*.width = */ width, + /*.size = */ 0, + }, }; } if (type.isa()) { // TODO - Get the width from the LWECiphertextType instead of global // precision (could be possible after merge lowlfhe-ciphertext-parameter) return CircuitGate{ - .encryption = llvm::Optional({ - .secretKeyID = secretKeyID, - .variance = variance, - .encoding = {.precision = precision}, + /*.encryption = */ llvm::Optional({ + /*.secretKeyID = */ secretKeyID, + /*.variance = */ variance, + /*.encoding = */ {/*.precision = */ precision}, }), - .shape = {.width = precision, .size = 0}, + /*.shape = */ {/*.width = */ precision, /*.size = */ 0}, }; } auto tensor = type.dyn_cast_or_null(); @@ -70,37 +70,37 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, v0Curve->getVariance(1, 1 << v0Param.polynomialSize, 64); Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64); // Static client parameters from global parameters for v0 - ClientParameters c{ - .secretKeys{ - {"small", {.size = v0Param.nSmall}}, - {"big", {.size = v0Param.getNBigGlweSize()}}, - }, - .bootstrapKeys{ + ClientParameters c = {}; + c.secretKeys = { + {"small", {/*.size = */ v0Param.nSmall}}, + {"big", {/*.size = */ v0Param.getNBigGlweSize()}}, + }; + c.bootstrapKeys = { + { + "bsk_v0", { - "bsk_v0", - { - .inputSecretKeyID = "small", - .outputSecretKeyID = "big", - .level = v0Param.brLevel, - .baseLog = v0Param.brLogBase, - .k = v0Param.k, - .variance = encryptionVariance, - }, - }, - }, - .keyswitchKeys{ - { - "ksk_v0", - { - .inputSecretKeyID = "big", - .outputSecretKeyID = "small", - .level = v0Param.ksLevel, - .baseLog = v0Param.ksLogBase, - .variance = keyswitchVariance, - }, + /*.inputSecretKeyID = */ "small", + /*.outputSecretKeyID = */ "big", + /*.level = */ v0Param.brLevel, + /*.baseLog = */ v0Param.brLogBase, + /*.k = */ v0Param.k, + /*.variance = */ encryptionVariance, }, }, }; + c.keyswitchKeys = { + { + "ksk_v0", + { + /*.inputSecretKeyID = */ "big", + /*.outputSecretKeyID = */ "small", + /*.level = */ v0Param.ksLevel, + /*.baseLog = */ v0Param.ksLogBase, + /*.variance = */ keyswitchVariance, + }, + }, + }; + // Find the input function auto rangeOps = module.getOps(); auto funcOp = llvm::find_if( @@ -113,7 +113,7 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name, // For the v0 the precision is global auto precision = fheContext.constraint.p; - Encoding encoding = {.precision = fheContext.constraint.p}; + Encoding encoding = {/*.precision = */ fheContext.constraint.p}; // Create input and output circuit gate parameters auto funcType = (*funcOp).getType(); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 166ba8879..69b0e39b3 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -76,7 +76,8 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) { if (oMax2norm.hasValue() && oMaxWidth.hasValue()) { ret = llvm::Optional( - {.norm2 = ceilLog2(oMax2norm.getValue()), .p = oMaxWidth.getValue()}); + {/*.norm2 = */ ceilLog2(oMax2norm.getValue()), + /*.p = */ oMaxWidth.getValue()}); } return ret; From 2e3560654ff5ca71488bfc2e2c729a35edd90d94 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 21 Oct 2021 08:58:51 +0100 Subject: [PATCH 02/19] chore: bump LLVM to 55e76c70 --- llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm-project b/llvm-project index f1e9ecea4..55e76c70a 160000 --- a/llvm-project +++ b/llvm-project @@ -1 +1 @@ -Subproject commit f1e9ecea442a2f839e5ac85f840b720db1ee7914 +Subproject commit 55e76c70a4f7fd5e13cf6c317a183bc3e6c59a03 From 5a2e9460fb92b35be3ec76c16fbb66f633468c12 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 11 Oct 2021 15:23:29 +0100 Subject: [PATCH 03/19] build: setup build tools for python package - Docker image to build wheels for linux_x86_64 CPython 3.[8,9,10] with GLIBC >= 2.24 - Specify which Python to use in Makefile - Fix cmake build to handle when libpython isn't available (cmake>3.18) --- README.md | 48 +++++++++++++++- .../Dockerfile.release_manylinux_2_24_x86_64 | 24 ++++++++ compiler/CMakeLists.txt | 13 ++++- compiler/Makefile | 22 +++++++- compiler/lib/Bindings/Python/CMakeLists.txt | 1 + .../lib/Bindings/Python/zamalang/__init__.py | 2 +- .../lib/Bindings/Python/zamalang/compiler.py | 6 +- .../Python/zamalang/dialects/__init__.py | 0 .../Python/zamalang/dialects/hlfhe.py | 2 +- compiler/setup.py | 55 +++++++++++++++++++ 10 files changed, 164 insertions(+), 9 deletions(-) create mode 100644 builders/Dockerfile.release_manylinux_2_24_x86_64 create mode 100644 compiler/lib/Bindings/Python/zamalang/dialects/__init__.py create mode 100644 compiler/setup.py diff --git a/README.md b/README.md index af3f44889..d0d8547e5 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,49 @@ # Homomorphizer -The homomorphizer is a compiler that takes a high level computation model and produces a programs that evaluate the model in an homomorphic way. \ No newline at end of file +The homomorphizer is a compiler that takes a high level computation model and produces a programs that evaluate the model in an homomorphic way. + + + +## Build the Python Package + +Currently supported platforms: +- Linux x86_64 for python 3.8, 3.9, and 3.10 + +### Linux + +We use the [manylinux](https://github.com/pypa/manylinux) docker images for building python packages for Linux. Those packages should work on distributions that have GLIBC >= 2.24. + +You can use Make to build the python wheels using these docker images: + +```bash +$ cd compiler +$ make package_py38 # package_py39 package_py310 +``` + +This will build the image for the appropriate python version then copy the wheels out under `/wheels` + +### Build wheels in your environment + +#### Temporary MLIR issue + +Due to an issue with MLIR, you will need to manually add `__init__.py` files to the `mlir` python package after the build. + +```bash +$ make python-bindings +$ touch build/tools/zamalang/python_packages/zamalang_core/mlir/__init__.py +$ touch build/tools/zamalang/python_packages/zamalang_core/mlir/dialects/__init__.py +``` + +#### Build wheel + +Building the wheels is actually simple. + +```bash +$ pip wheel --no-deps -w ../wheels . +``` + +Depending on the platform you are using (specially Linux), you might need to use `auditwheel` to specify the platform this wheel is targeting. For example, in our build of the package for Linux x86_64 and GLIBC 2.24, we also run: + +```bash +$ auditwheel repair ../wheels/*.whl --plat manylinux_2_24_x86_64 -w ../wheels +``` diff --git a/builders/Dockerfile.release_manylinux_2_24_x86_64 b/builders/Dockerfile.release_manylinux_2_24_x86_64 new file mode 100644 index 000000000..efd00a0fd --- /dev/null +++ b/builders/Dockerfile.release_manylinux_2_24_x86_64 @@ -0,0 +1,24 @@ +FROM quay.io/pypa/manylinux_2_24_x86_64 + +RUN apt-get update +RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y build-essential ninja-build +# Set the python path. Options: [cp38-cp38, cp39-cp39, cp310-cp310] +ARG python_tag=cp38-cp38 +# Install python deps +RUN /opt/python/${python_tag}/bin/pip install numpy pybind11==2.6.2 PyYAML +# Setup LLVM +COPY /llvm-project /llvm-project +# Setup Concrete +COPY --from=ghcr.io/zama-ai/concrete-api-env:latest /target/release /concrete/target/release +ENV CONCRETE_PROJECT=/concrete +# Setup and build compiler +COPY /compiler /compiler +WORKDIR /compiler +RUN make Python3_EXECUTABLE=/opt/python/${python_tag}/bin/python build +RUN make python-bindings +# Fix MLIR package +RUN touch build/tools/zamalang/python_packages/zamalang_core/mlir/__init__.py +RUN touch build/tools/zamalang/python_packages/zamalang_core/mlir/dialects/__init__.py +# Build wheel +RUN /opt/python/${python_tag}/bin/pip wheel --no-deps -w /wheels . +RUN auditwheel repair /wheels/*.whl --plat manylinux_2_24_x86_64 -w /wheels \ No newline at end of file diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index 7c8adbdda..1de40aa21 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -56,7 +56,18 @@ if(ZAMALANG_BINDINGS_PYTHON_ENABLED) message(STATUS "ZamaLang Python bindings are enabled.") include(MLIRDetectPythonEnv) - find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + # After CMake 3.18, we are able to limit the scope of the search to just + # Development.Module. Searching for Development will fail in situations where + # the Python libraries are not available. When possible, limit to just + # Development.Module. + # See https://pybind11.readthedocs.io/en/stable/compiling.html#findpython-mode + if(CMAKE_VERSION VERSION_LESS "3.18.0") + set(_python_development_component Development) + else() + set(_python_development_component Development.Module) + endif() + find_package(Python3 COMPONENTS Interpreter ${_python_development_component} REQUIRED) + unset(_python_development_component) message(STATUS "Found Python include dirs: ${Python3_INCLUDE_DIRS}") message(STATUS "Found Python libraries: ${Python3_LIBRARIES}") message(STATUS "Found Python executable: ${Python3_EXECUTABLE}") diff --git a/compiler/Makefile b/compiler/Makefile index dda452d31..325753c13 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -1,4 +1,5 @@ BUILD_DIR=./build +Python3_EXECUTABLE= build: @@ -12,7 +13,8 @@ build: -DZAMALANG_BINDINGS_PYTHON_ENABLED=ON \ -DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \ -DLLVM_EXTERNAL_PROJECTS=zamalang \ - -DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. + -DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. \ + -DPython3_EXECUTABLE=${Python3_EXECUTABLE} build-end-to-end-jit: build cmake --build $(BUILD_DIR) --target end_to_end_jit_test @@ -30,7 +32,7 @@ test-end-to-end-jit: build-end-to-end-jit $(BUILD_DIR)/bin/end_to_end_jit_test test-python: python-bindings - PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core/mlir/_mlir_libs/ LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python + PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python test: test-check test-end-to-end-jit test-python @@ -42,3 +44,19 @@ file-check: cmake --build $(BUILD_DIR) --target FileCheck not: cmake --build $(BUILD_DIR) --target not + +# Python packages + +define build_image_and_copy_wheels + docker image build -t concretefhe-compiler-manylinux:$(1) --build-arg python_tag=$(1) -f ../builders/Dockerfile.release_manylinux_2_24_x86_64 .. + docker container run --rm -v ${PWD}/../wheels:/wheels_volume concretefhe-compiler-manylinux:$(1) cp -r /wheels/. /wheels_volume/. +endef + +package_py38: + $(call build_image_and_copy_wheels,cp38-cp38) + +package_py39: + $(call build_image_and_copy_wheels,cp39-cp39) + +package_py310: + $(call build_image_and_copy_wheels,cp310-cp310) diff --git a/compiler/lib/Bindings/Python/CMakeLists.txt b/compiler/lib/Bindings/Python/CMakeLists.txt index c80c9c889..4dcf5f79b 100644 --- a/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compiler/lib/Bindings/Python/CMakeLists.txt @@ -27,6 +27,7 @@ declare_mlir_python_sources(ZamalangBindingsPythonSources SOURCES zamalang/__init__.py zamalang/compiler.py + zamalang/dialects/__init__.py zamalang/dialects/_ods_common.py) ################################################################################ diff --git a/compiler/lib/Bindings/Python/zamalang/__init__.py b/compiler/lib/Bindings/Python/zamalang/__init__.py index 647b40273..05a481546 100644 --- a/compiler/lib/Bindings/Python/zamalang/__init__.py +++ b/compiler/lib/Bindings/Python/zamalang/__init__.py @@ -1,3 +1,3 @@ """Zamalang python module""" -from _zamalang import * +from mlir._mlir_libs._zamalang import * from .compiler import CompilerEngine diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py index 185e4a169..76e372463 100644 --- a/compiler/lib/Bindings/Python/zamalang/compiler.py +++ b/compiler/lib/Bindings/Python/zamalang/compiler.py @@ -1,8 +1,8 @@ """Compiler submodule""" from typing import List, Union -from _zamalang._compiler import CompilerEngine as _CompilerEngine -from _zamalang._compiler import ExecutionArgument as _ExecutionArgument -from _zamalang._compiler import round_trip as _round_trip +from mlir._mlir_libs._zamalang._compiler import CompilerEngine as _CompilerEngine +from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument +from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip def round_trip(mlir_str: str) -> str: diff --git a/compiler/lib/Bindings/Python/zamalang/dialects/__init__.py b/compiler/lib/Bindings/Python/zamalang/dialects/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/compiler/lib/Bindings/Python/zamalang/dialects/hlfhe.py b/compiler/lib/Bindings/Python/zamalang/dialects/hlfhe.py index 3d996e829..b87ed8953 100644 --- a/compiler/lib/Bindings/Python/zamalang/dialects/hlfhe.py +++ b/compiler/lib/Bindings/Python/zamalang/dialects/hlfhe.py @@ -1,3 +1,3 @@ """HLFHE dialect module""" from ._HLFHE_ops_gen import * -from _zamalang._hlfhe import * +from mlir._mlir_libs._zamalang._hlfhe import * diff --git a/compiler/setup.py b/compiler/setup.py new file mode 100644 index 000000000..48eccdc9a --- /dev/null +++ b/compiler/setup.py @@ -0,0 +1,55 @@ +import os +import subprocess +import setuptools + +from setuptools import Extension +from setuptools.command.build_ext import build_ext + + +def read(fname): + return open(os.path.join(os.path.dirname(__file__), fname)).read() + + +class MakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class MakeBuild(build_ext): + def run(self): + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + subprocess.check_call(["make", "python-bindings"]) + + +setuptools.setup( + name="concretefhe-compiler", + version="0.1.0", + author="Zama Team", + author_email="hello@zama.ai", + description="Concrete Compiler", + license="", + keywords="homomorphic encryption compiler", + long_description=read("README.md"), + long_description_content_type="text/markdown", + url="https://github.com/zama-ai/homomorphizer", + packages=setuptools.find_packages( + where="build/tools/zamalang/python_packages/zamalang_core", + include=["zamalang", "zamalang.*", "mlir", "mlir.*"], + ), + package_dir={"": "build/tools/zamalang/python_packages/zamalang_core"}, + include_package_data=True, + package_data={"": ["*.so"]}, + classifiers=[ + "Programming Language :: C++", + "Programming Language :: Python :: 3", + "Topic :: Software Development :: Compilers", + "Topic :: Security :: Cryptography", + ], + ext_modules=[MakeExtension("python-bindings")], + cmdclass=dict(build_ext=MakeBuild), + zip_safe=False, +) From ab7a208112a3815a1be8412a3e821b6ea1b41dc1 Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 22 Oct 2021 10:48:57 +0100 Subject: [PATCH 04/19] fix: store OpPassManager& before call to addPass weird bug with c++ 6.3 on the manylinux_2_24 image (Debian9) generating erroneous asm instructions for call to nest on PassManager --- compiler/lib/Support/Pipeline.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 69b0e39b3..64880551f 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -24,7 +24,8 @@ static void addPotentiallyNestedPass(mlir::PassManager &pm, if (!pass->getOpName() || *pass->getOpName() == "builtin.module") { pm.addPass(std::move(pass)); } else { - pm.nest(*pass->getOpName()).addPass(std::move(pass)); + mlir::OpPassManager &p = pm.nest(*pass->getOpName()); + p.addPass(std::move(pass)); } } From 527887bbf9452803d0b17a2af68f1efa41a9d397 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 19 Oct 2021 21:48:55 +0200 Subject: [PATCH 05/19] fix(compiler): Makefile: Do not let target 'build' depend on directory The target `build` creates a build directory with the same name and initializes through an invocation of CMake. Regardless of the success or failure of the CMake invocation, all subsequent invocations of the target do not invoke CMake anymore, as the target's prerequisites are satisfied through the existence of the build directory created upon the first invocation. This patch changes the dependencies to the build directory with an intermediate target that depends on a stamp file that is only created when the first CMake invocation in the build directory succeeds. --- compiler/Makefile | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/compiler/Makefile b/compiler/Makefile index 325753c13..5ba032d56 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -2,7 +2,7 @@ BUILD_DIR=./build Python3_EXECUTABLE= -build: +$(BUILD_DIR)/configured.stamp: cmake -B $(BUILD_DIR) -GNinja ../llvm-project/llvm/ \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_BUILD_EXAMPLES=OFF \ @@ -15,14 +15,17 @@ build: -DLLVM_EXTERNAL_PROJECTS=zamalang \ -DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. \ -DPython3_EXECUTABLE=${Python3_EXECUTABLE} + touch $@ -build-end-to-end-jit: build +build-initialized: $(BUILD_DIR)/configured.stamp + +build-end-to-end-jit: build-initialized cmake --build $(BUILD_DIR) --target end_to_end_jit_test -zamacompiler: build +zamacompiler: build-initialized cmake --build $(BUILD_DIR) --target zamacompiler -python-bindings: build +python-bindings: build-initialized cmake --build $(BUILD_DIR) --target ZamalangMLIRPythonModules ZamalangPythonModules test-check: zamacompiler file-check not From e7b258263914e9a876809ab8299a07955089e892 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 19 Oct 2021 21:56:26 +0200 Subject: [PATCH 06/19] fix(compiler): Makefile: Invoke CMake for each Python target separately The Makefile target `python-bindings` invokes CMake with multiple targets specified after the `--target` commandline option. However, as per the CMake manpage, only one target may be specified at once. This changes the single invocation of CMake to separate invocations for each target. Tested with CMake version 3.13.4. --- compiler/Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compiler/Makefile b/compiler/Makefile index 5ba032d56..7055fd1dc 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -26,7 +26,8 @@ zamacompiler: build-initialized cmake --build $(BUILD_DIR) --target zamacompiler python-bindings: build-initialized - cmake --build $(BUILD_DIR) --target ZamalangMLIRPythonModules ZamalangPythonModules + cmake --build $(BUILD_DIR) --target ZamalangMLIRPythonModules + cmake --build $(BUILD_DIR) --target ZamalangPythonModules test-check: zamacompiler file-check not $(BUILD_DIR)/bin/llvm-lit -v tests/ From 2c63018ed2a898f540b7e97b30667523bc5eed27 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 26 Oct 2021 16:49:09 +0200 Subject: [PATCH 07/19] fix(compiler): Makefile: Make targets without file dependencies PHONY Most of the targets in `Makefile` do not deped on files produced by other targets and use target names solely for dependency management. Make all such targets PHONY in order to avoid that they are skipped accidentially when a file with the same name is present. --- compiler/Makefile | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/compiler/Makefile b/compiler/Makefile index 7055fd1dc..24b61831d 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -64,3 +64,18 @@ package_py39: package_py310: $(call build_image_and_copy_wheels,cp310-cp310) + +.PHONY: build-initialized \ + build-end-to-end-jit \ + zamacompiler \ + python-bindings \ + test-check \ + test-end-to-end-jit \ + test-python \ + test \ + add-deps \ + file-check \ + not \ + package_py38 \ + package_py39 \ + package_py310 From dc2d6a362e1300aef4b586eeb20e2b3068046db2 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 28 Oct 2021 09:09:07 +0100 Subject: [PATCH 08/19] chore: stop using "build" target in CI --- .github/workflows/concrete-lib-compatibility.yml | 2 +- builders/Dockerfile.release_manylinux_2_24_x86_64 | 3 +-- builders/Dockerfile.zamalang-env | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/concrete-lib-compatibility.yml b/.github/workflows/concrete-lib-compatibility.yml index 1082fb054..bf8d114e0 100644 --- a/.github/workflows/concrete-lib-compatibility.yml +++ b/.github/workflows/concrete-lib-compatibility.yml @@ -53,7 +53,7 @@ jobs: cd /compiler pip install pytest export CONCRETE_PROJECT=/concrete - make -B BUILD_DIR=/build build + make -B BUILD_DIR=/build build-initialized make BUILD_DIR=/build test - name: Send Slack Notification diff --git a/builders/Dockerfile.release_manylinux_2_24_x86_64 b/builders/Dockerfile.release_manylinux_2_24_x86_64 index efd00a0fd..ed539a31f 100644 --- a/builders/Dockerfile.release_manylinux_2_24_x86_64 +++ b/builders/Dockerfile.release_manylinux_2_24_x86_64 @@ -14,8 +14,7 @@ ENV CONCRETE_PROJECT=/concrete # Setup and build compiler COPY /compiler /compiler WORKDIR /compiler -RUN make Python3_EXECUTABLE=/opt/python/${python_tag}/bin/python build -RUN make python-bindings +RUN make Python3_EXECUTABLE=/opt/python/${python_tag}/bin/python python-bindings # Fix MLIR package RUN touch build/tools/zamalang/python_packages/zamalang_core/mlir/__init__.py RUN touch build/tools/zamalang/python_packages/zamalang_core/mlir/dialects/__init__.py diff --git a/builders/Dockerfile.zamalang-env b/builders/Dockerfile.zamalang-env index 471e62bc5..3ab5f138f 100644 --- a/builders/Dockerfile.zamalang-env +++ b/builders/Dockerfile.zamalang-env @@ -12,7 +12,6 @@ COPY /llvm-project /llvm-project COPY /compiler /compiler WORKDIR /compiler RUN mkdir -p /build -RUN make BUILD_DIR=/build -B build RUN make BUILD_DIR=/build zamacompiler python-bindings ENV PYTHONPATH "$PYTHONPATH:/build/tools/zamalang/python_packages/zamalang_core:/build/tools/zamalang/python_packages/zamalang_core/mlir/_mlir_libs/" ENV PATH "$PATH:/build/bin" From 941465060e0ff471bf4541ab18b7fd26b91edd52 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 28 Oct 2021 12:13:19 +0100 Subject: [PATCH 09/19] build: setup target and tools for building tarballs --- README.md | 7 ++++++ .../Dockerfile.release_tarball_linux_x86_64 | 23 +++++++++++++++++++ compiler/Makefile | 12 +++++++--- 3 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 builders/Dockerfile.release_tarball_linux_x86_64 diff --git a/README.md b/README.md index d0d8547e5..26eb050da 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,14 @@ The homomorphizer is a compiler that takes a high level computation model and produces a programs that evaluate the model in an homomorphic way. +## Build tarball +The final tarball contains intallation instructions. We only support Linux x86_64 for the moment. You can find the output tarball under `/tarballs`. + +```bash +$ cd compiler +$ make release_tarballs +``` ## Build the Python Package diff --git a/builders/Dockerfile.release_tarball_linux_x86_64 b/builders/Dockerfile.release_tarball_linux_x86_64 new file mode 100644 index 000000000..ae35dd736 --- /dev/null +++ b/builders/Dockerfile.release_tarball_linux_x86_64 @@ -0,0 +1,23 @@ +FROM quay.io/pypa/manylinux_2_24_x86_64 + +RUN apt-get update +RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y build-essential ninja-build +# Setup LLVM +COPY /llvm-project /llvm-project +# Setup Concrete +COPY --from=ghcr.io/zama-ai/concrete-api-env:latest /target/release /concrete/target/release +ENV CONCRETE_PROJECT=/concrete +# Setup and build compiler +COPY /compiler /compiler +WORKDIR /compiler +RUN make BINDINGS_PYTHON_ENABLED=OFF zamacompiler +# Build tarball +RUN mkdir -p /tarballs/zamacompiler/lib /tarballs/zamacompiler/bin && \ + cp /compiler/build/bin/zamacompiler /tarballs/zamacompiler/bin && \ + cp /compiler/build/lib/libZamalangRuntime.so /tarballs/zamacompiler/lib +RUN echo "# Installation\n"\ + "You can install the compiler by either:\n"\ + "1. Extracting the tarball as is somewhere of your choosing, and add /path/to/tarball/zamacompiler/bin to your \$PATH\n"\ + "2. Extracting the tarball and putting the bin/zamacompiler into a path already in your \$PATH, and lib/libZamalangRuntime.so into one of your lib folders (e.g /usr/lib)"\ + >> /tarballs/zamacompiler/Installation.md +RUN cd /tarballs && tar -czvf zamacompiler.tar.gz zamacompiler \ No newline at end of file diff --git a/compiler/Makefile b/compiler/Makefile index 24b61831d..3c6a22498 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -1,5 +1,6 @@ BUILD_DIR=./build Python3_EXECUTABLE= +BINDINGS_PYTHON_ENABLED=ON $(BUILD_DIR)/configured.stamp: @@ -9,8 +10,8 @@ $(BUILD_DIR)/configured.stamp: -DLLVM_TARGETS_TO_BUILD="host" \ -DCMAKE_BUILD_TYPE=Release \ -DLLVM_ENABLE_ASSERTIONS=ON \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DZAMALANG_BINDINGS_PYTHON_ENABLED=ON \ + -DMLIR_ENABLE_BINDINGS_PYTHON=$(BINDINGS_PYTHON_ENABLED) \ + -DZAMALANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \ -DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \ -DLLVM_EXTERNAL_PROJECTS=zamalang \ -DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. \ @@ -65,6 +66,10 @@ package_py39: package_py310: $(call build_image_and_copy_wheels,cp310-cp310) +release_tarballs: + docker image build -t concretefhe-compiler-manylinux:linux_x86_64_tarball -f ../builders/Dockerfile.release_tarball_linux_x86_64 .. + docker container run --rm -v ${PWD}/../tarballs:/tarballs_volume concretefhe-compiler-manylinux:linux_x86_64_tarball cp -r /tarballs/. /tarballs_volume/. + .PHONY: build-initialized \ build-end-to-end-jit \ zamacompiler \ @@ -78,4 +83,5 @@ package_py310: not \ package_py38 \ package_py39 \ - package_py310 + package_py310 \ + release_tarballs From 0423a05db8428c1aef45f131bd089345679c9f6f Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 15 Oct 2021 14:35:08 +0200 Subject: [PATCH 10/19] feat(compiler): Add support for tensor.extract operations in MANP pass Add support for `tensor.extract` operations in the MANP pass. This currently only supports extract operations on tensors of encrypted integers, which are passed as function arguments, e.g.: func @extract_ith(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{ %c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> return %c : !HLFHE.eint<5> } --- compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 57 ++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 17e106fa6..7161d22b0 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -77,6 +79,19 @@ protected: llvm::Optional manp; }; +// Checks if `lhs` is equal to `rhs`, where both values are assumed to +// be positive. The bit width of the smaller `APInt` is extended +// before comparison via `APInt::operator==`. +static bool APIntWidthExtendCompare(const llvm::APInt &lhs, + const llvm::APInt &rhs) { + if (lhs.getBitWidth() < rhs.getBitWidth()) + return lhs.zext(rhs.getBitWidth()) == rhs; + else if (lhs.getBitWidth() > rhs.getBitWidth()) + return lhs == rhs.zext(lhs.getBitWidth()); + else + return lhs == rhs; +} + // Checks if `lhs` is less than `rhs`, where both values are assumed // to be positive. The bit width of the smaller `APInt` is extended // before comparison via `APInt::ult`. @@ -305,6 +320,37 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUAdd(a, b); } +// Calculates the squared Minimal Arithmetic Noise Padding of a dot +// operation that is equivalent to an `tensor.extract` +// operation. Currently, this only supports extractions of elements +// from tensors passed as function arguments, for which the MANP is +// assumed to be 1. +static llvm::APInt getSqMANP( + mlir::tensor::ExtractOp op, + llvm::ArrayRef *> operandMANPs) { + mlir::zamalang::HLFHE::EncryptedIntegerType elTy = + op.getOperand(0) + .getType() + .dyn_cast() + .getElementType() + .dyn_cast_or_null(); + + assert(elTy && "Can only calculate MANP for tensor.extract operations on " + "HLFHE.eint tensors"); + + assert(operandMANPs.size() >= 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "MANP value for tensor is unknown"); + + llvm::APInt one{1, 1, false}; + + assert(APIntWidthExtendCompare( + operandMANPs[0]->getValue().getMANP().getValue(), one) && + "MANP value for tensor is not 1 as expected"); + + return one; +} + // Calculates the squared Minimal Arithmetic Noise Padding of a dot operation // that is equivalent to an `HLFHE.sub_int_eint` operation. static llvm::APInt getSqMANP( @@ -404,6 +450,17 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; + } else if (auto tensorExtractOp = + llvm::dyn_cast(op)) { + // Only handle extract operations that produce an encrypted integer + if (tensorExtractOp->getResultTypes() + .front() + .dyn_cast_or_null< + mlir::zamalang::HLFHE::EncryptedIntegerType>()) { + norm2SqEquiv = getSqMANP(tensorExtractOp, operands); + } else { + isDummy = true; + } } else if (llvm::isa(op)) { isDummy = true; } else if (llvm::isa( From d4b4839d6e631f25511785e5f8227d690f5634b8 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 15 Oct 2021 15:13:35 +0200 Subject: [PATCH 11/19] fix(compiler): Take into account function parameters in MaxMANPPass When determining the maximum MANP and precision, the `MaxMANPPass` only takes into account results generated by an operation, but ignores function parameters. However, encrypted function parameters are assumed to have a MANP value of 1 and can have an arbitrary precision. This patch takes into account function arguments by using their default MANP values and the extracting their precision. --- compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 83 +++++++++++++++++--- 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 7161d22b0..ba1b44e4d 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -24,6 +24,51 @@ namespace mlir { namespace zamalang { namespace { + +// Returns `true` if the given value is a scalar or tensor argument of +// a function, for which a MANP of 1 can be assumed. +static bool isEncryptedFunctionParameter(mlir::Value value) { + if (!value.isa()) + return false; + + mlir::Block *block = value.cast().getOwner(); + + if (!block || !block->getParentOp() || + !llvm::isa(block->getParentOp())) { + return false; + } + + return (value.getType().isa() || + (value.getType().isa() && + value.getType() + .cast() + .getElementType() + .isa())); +} + +// Returns the bit width of `value` if `value` is an encrypted integer +// or the bit width of the elements if `value` is a tensor of +// encrypted integers. +static unsigned int getEintPrecision(mlir::Value value) { + if (auto ty = value.getType() + .dyn_cast_or_null< + mlir::zamalang::HLFHE::EncryptedIntegerType>()) { + return ty.getWidth(); + } else if (auto tensorTy = + value.getType().dyn_cast_or_null()) { + if (auto ty = tensorTy.getElementType() + .dyn_cast_or_null< + mlir::zamalang::HLFHE::EncryptedIntegerType>()) + return ty.getWidth(); + } + + assert(false && + "Value is neither an encrypted integer nor a tensor of encrypted " + "integers"); + + return 0; +} + // The `MANPLatticeValue` represents the squared Minimal Arithmetic // Noise Padding for an operation using the squared 2-norm of an // equivalent dot operation. This can either be an actual value if the @@ -42,13 +87,7 @@ struct MANPLatticeValue { // // TODO: Provide a mechanism to propagate Minimal Arithmetic Noise // Padding across function calls. - if (value.isa() && - (value.getType().isa() || - (value.getType().isa() && - value.getType() - .cast() - .getElementType() - .isa()))) { + if (isEncryptedFunctionParameter(value)) { return MANPLatticeValue(llvm::APInt{1, 1, false}); } else { // All other operations have an unknown Minimal Arithmetic Noise @@ -541,13 +580,35 @@ struct MaxMANPPass : public MaxMANPBase { protected: void processOperation(mlir::Operation *op) { + static const llvm::APInt one{1, 1, false}; + bool upd = false; + + // Process all function arguments and use the default value of 1 + // for MANP and the declarend precision + if (mlir::FuncOp func = llvm::dyn_cast_or_null(op)) { + for (mlir::BlockArgument blockArg : func.getBody().getArguments()) { + if (isEncryptedFunctionParameter(blockArg)) { + unsigned int width = getEintPrecision(blockArg); + + if (this->maxEintWidth < width) { + this->maxEintWidth = width; + } + + if (APIntWidthExtendULT(this->maxMANP, one)) { + this->maxMANP = one; + upd = true; + } + } + } + } + + // Process all results using MANP attribute from MANP pas for (mlir::OpResult res : op->getResults()) { mlir::zamalang::HLFHE::EncryptedIntegerType eTy = res.getType() .dyn_cast_or_null(); if (eTy) { - bool upd = false; if (this->maxEintWidth < eTy.getWidth()) { this->maxEintWidth = eTy.getWidth(); upd = true; @@ -564,11 +625,11 @@ protected: this->maxMANP = MANP.getValue(); upd = true; } - - if (upd) - this->updateMax(this->maxMANP, this->maxEintWidth); } } + + if (upd) + this->updateMax(this->maxMANP, this->maxEintWidth); } std::function updateMax; From b12be451433a1687d2343275e7bb88cc2fc15931 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 18 Oct 2021 11:04:48 +0200 Subject: [PATCH 12/19] feat(compiler): Add method getResultVectorSize to JITLambda::Argument Add method `JITLambda::Argument::getResultVectorSize` that returns the number of elements of the result if the result is a vector. --- compiler/include/zamalang/Support/Jit.h | 4 ++++ compiler/lib/Support/Jit.cpp | 14 ++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index c7358ae77..34212a8e7 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -57,6 +57,10 @@ public: // Fill the result. llvm::Error getResult(size_t pos, uint64_t *res, size_t size); + // Returns the number of elements of the result vector at position + // `pos` or an error if the result is a scalar value + llvm::Expected getResultVectorSize(size_t pos); + private: llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 95be53411..85ab88aaf 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -344,6 +344,20 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { return llvm::Error::success(); } +// Returns the number of elements of the result vector at position +// `pos` or an error if the result is a scalar value +llvm::Expected JITLambda::Argument::getResultVectorSize(size_t pos) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + + if (info.shape.size == 0) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Result at pos %zu is not a tensor", pos); + } + + return info.shape.size; +} + llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, size_t size) { auto gate = outputGates[pos]; From e76aee7e10ac7ef183d64e4f024a1791d6342ae5 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 18 Oct 2021 11:20:29 +0200 Subject: [PATCH 13/19] feat(compiler): Add class StreamStringError with a stream interface for llvm::Error Composing error messages for `llvm::Error` is either done by using `llvm::createStringError()` with an appropriate format string and arguments or by writing to a `std::string`-backed `llvm::raw_string_ostream` and passing the result to `llvm::make_error()` verbatim. The new class `StreamStringError` encapsulates the latter solution into a class with an appropriate stream operator and implicit cast operators to `llvm::Error` and `llvm::Expected`. Example usage: llvm::Error foo(int i, size_t s, ...) { ... if(...) { return StreamStringError() << "Some error message with an integer: " << i << " and a size_t: " << s; } ... } --- compiler/include/zamalang/Support/Error.h | 53 +++++++++++++++++++++++ compiler/lib/Support/CMakeLists.txt | 1 + compiler/lib/Support/Error.cpp | 12 +++++ 3 files changed, 66 insertions(+) create mode 100644 compiler/include/zamalang/Support/Error.h create mode 100644 compiler/lib/Support/Error.cpp diff --git a/compiler/include/zamalang/Support/Error.h b/compiler/include/zamalang/Support/Error.h new file mode 100644 index 000000000..633b0e716 --- /dev/null +++ b/compiler/include/zamalang/Support/Error.h @@ -0,0 +1,53 @@ +#ifndef ZAMALANG_SUPPORT_STRING_ERROR_H +#define ZAMALANG_SUPPORT_STRING_ERROR_H + +#include + +namespace mlir { +namespace zamalang { + +// Internal error class that allows for composing `llvm::Error`s +// similar to `llvm::createStringError()`, but using stream-like +// composition with `operator<<`. +// +// Example: +// +// llvm::Error foo(int i, size_t s, ...) { +// ... +// if(...) { +// return StreamStringError() +// << "Some error message with an integer: " +// << i << " and a size_t: " << s; +// } +// ... +// } +class StreamStringError { +public: + StreamStringError(const llvm::StringRef &s) : buffer(s.str()), os(buffer){}; + StreamStringError() : buffer(""), os(buffer){}; + + template StreamStringError &operator<<(const T &v) { + this->os << v; + return *this; + } + + operator llvm::Error() { + return llvm::make_error(os.str(), + llvm::inconvertibleErrorCode()); + } + + template operator llvm::Expected() { + return this->operator llvm::Error(); + } + +protected: + std::string buffer; + llvm::raw_string_ostream os; +}; + +StreamStringError &operator<<(StreamStringError &se, llvm::Error &err); + +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 82c3077ca..9694989d3 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(ZamalangSupport + Error.cpp Pipeline.cpp Jit.cpp CompilerEngine.cpp diff --git a/compiler/lib/Support/Error.cpp b/compiler/lib/Support/Error.cpp new file mode 100644 index 000000000..32cfa6399 --- /dev/null +++ b/compiler/lib/Support/Error.cpp @@ -0,0 +1,12 @@ +#include + +namespace mlir { +namespace zamalang { +// Specialized `operator<<` for `llvm::Error` that marks the error +// as checked through `std::move` and `llvm::toString` +StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) { + se << llvm::toString(std::move(err)); + return se; +} +} // namespace zamalang +} // namespace mlir From 3ae924e17471a9575da6e3af2194e661c948dca3 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 25 Oct 2021 11:02:48 +0200 Subject: [PATCH 14/19] enhance(compiler): Disable RTTI for unit tests LLVM and MLIR are compiled without runtime type information (RTTI). Use the same restrictions for unit tests to avoid linker errors related to typeinfo when building the test executable. --- compiler/tests/unittest/CMakeLists.txt | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/compiler/tests/unittest/CMakeLists.txt b/compiler/tests/unittest/CMakeLists.txt index 70efefca6..dd30a72d9 100644 --- a/compiler/tests/unittest/CMakeLists.txt +++ b/compiler/tests/unittest/CMakeLists.txt @@ -2,11 +2,14 @@ enable_testing() include_directories(${PROJECT_SOURCE_DIR}/include) - add_executable( end_to_end_jit_test end_to_end_jit_test.cc -) + ) + +set_source_files_properties( + end_to_end_jit_test.cc PROPERTIES COMPILE_FLAGS "-fno-rtti") + target_link_libraries( end_to_end_jit_test gtest_main From d738104c4b5f787861939f2854056ef016068a4c Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 28 Oct 2021 15:55:54 +0100 Subject: [PATCH 15/19] fix: use std::string for JIT entrypoint funcname As we store the funcname for the entrypoint for later use, a pointer might point to some random memory when used as there is no special management for that name at the higher levels. --- compiler/include/zamalang/Support/Jit.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 34212a8e7..835749b0f 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -104,7 +104,7 @@ public: private: mlir::LLVM::LLVMFunctionType type; - llvm::StringRef name; + std::string name; std::unique_ptr engine; }; From 1187cfbd623f453bb9236362f3151dbc7e1d13ea Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Mon, 18 Oct 2021 15:38:12 +0200 Subject: [PATCH 16/19] refactor(compiler): Refactor CompilerEngine and related classes This commit contains several incremental improvements towards a clear interface for lambdas: - Unification of static and JIT compilation by using the static compilation path of `CompilerEngine` within a new subclass `JitCompilerEngine`. - Clear ownership for compilation artefacts through `CompilationContext`, making it impossible to destroy objects used directly or indirectly before destruction of their users. - Clear interface for lambdas generated by the compiler through `JitCompilerEngine::Lambda` with a templated call operator, encapsulating otherwise manual orchestration of `CompilerEngine`, `JITLambda`, and `CompilerEngine::Argument`. - Improved error handling through `llvm::Expected` and proper error checking following the conventions for `llvm::Expected` and `llvm::Error`. Co-authored-by: youben11 --- .../zamalang-c/Support/CompilerEngine.h | 19 +- .../include/zamalang/Support/CompilerEngine.h | 147 +++- compiler/include/zamalang/Support/Jit.h | 5 - .../zamalang/Support/JitCompilerEngine.h | 296 +++++++ .../include/zamalang/Support/LambdaArgument.h | 157 ++++ .../lib/Bindings/Python/CompilerAPIModule.cpp | 50 +- .../lib/Bindings/Python/zamalang/compiler.py | 29 +- compiler/lib/CAPI/Support/CompilerEngine.cpp | 107 ++- compiler/lib/Support/CMakeLists.txt | 2 + compiler/lib/Support/CompilerEngine.cpp | 474 ++++++++--- compiler/lib/Support/Jit.cpp | 51 +- compiler/lib/Support/JitCompilerEngine.cpp | 105 +++ compiler/lib/Support/LambdaArgument.cpp | 7 + compiler/src/main.cpp | 408 +++------- .../Conversion/HLFHEToMidLFHE/add_eint.mlir | 2 +- .../HLFHEToMidLFHE/add_eint_int.mlir | 2 +- .../HLFHEToMidLFHE/apply_univariate.mlir | 2 +- .../HLFHEToMidLFHE/apply_univariate_cst.mlir | 2 +- .../HLFHEToMidLFHE/linalg_generic.mlir | 2 +- .../HLFHEToMidLFHE/mul_eint_int.mlir | 2 +- .../HLFHEToMidLFHE/sub_int_eint.mlir | 2 +- .../LowLFHEToConcreteCAPI/bootstrap.mlir | 2 +- .../glwe_from_table.mlir | 2 +- .../LowLFHEToConcreteCAPI/keyswitch_lwe.mlir | 2 +- .../Conversion/MidLFHEToLowLFHE/add_glwe.mlir | 2 +- .../MidLFHEToLowLFHE/add_glwe_int.mlir | 2 +- .../MidLFHEToLowLFHE/apply_lookup_table.mlir | 2 +- .../apply_lookup_table_cst.mlir | 2 +- .../MidLFHEToLowLFHE/mul_glwe_int.mlir | 2 +- .../MidLFHEToLowLFHE/sub_int_glwe.mlir | 2 +- .../tests/Dialect/HLFHE/Analysis/MANP.mlir | 2 +- compiler/tests/Dialect/HLFHE/dot.invalid.mlir | 2 +- .../Dialect/HLFHE/eint_error_p_too_big.mlir | 2 +- .../Dialect/HLFHE/eint_error_p_too_small.mlir | 2 +- .../Dialect/HLFHE/op_add_eint_err_inputs.mlir | 2 +- .../Dialect/HLFHE/op_add_eint_err_result.mlir | 2 +- .../HLFHE/op_add_eint_int_err_inputs.mlir | 2 +- .../HLFHE/op_add_eint_int_err_result.mlir | 2 +- .../op_apply_lookup_table_bad_dimension.mlir | 2 +- .../HLFHE/op_mul_eint_int_err_inputs.mlir | 2 +- .../HLFHE/op_mul_eint_int_err_result.mlir | 2 +- .../HLFHE/op_sub_int_eint_err_inputs.mlir | 2 +- .../HLFHE/op_sub_int_eint_err_result.mlir | 2 +- compiler/tests/Dialect/HLFHE/ops.mlir | 2 +- .../Dialect/HLFHE/tensor-ops-to-linalg.mlir | 2 +- compiler/tests/Dialect/HLFHE/types.mlir | 2 +- compiler/tests/Dialect/LowLFHE/ops.mlir | 2 +- compiler/tests/Dialect/LowLFHE/types.mlir | 2 +- .../Dialect/MidLFHE/op_add_glwe.invalid.mlir | 2 +- .../tests/Dialect/MidLFHE/op_add_glwe.mlir | 2 +- .../MidLFHE/op_add_glwe_int.invalid.mlir | 2 +- .../Dialect/MidLFHE/op_add_glwe_int.mlir | 2 +- .../op_apply_lookup_table.invalid.mlir | 2 +- .../MidLFHE/op_apply_lookup_table.mlir | 2 +- .../MidLFHE/op_mul_glwe_int.invalid.mlir | 2 +- .../Dialect/MidLFHE/op_mul_glwe_int.mlir | 2 +- .../MidLFHE/op_sub_int_glwe.invalid.mlir | 2 +- .../Dialect/MidLFHE/op_sub_int_glwe.mlir | 2 +- .../tests/Dialect/MidLFHE/types_glwe.mlir | 2 +- compiler/tests/python/test_compiler_engine.py | 2 +- .../tests/unittest/end_to_end_jit_test.cc | 738 +++++++++--------- 61 files changed, 1690 insertions(+), 997 deletions(-) create mode 100644 compiler/include/zamalang/Support/JitCompilerEngine.h create mode 100644 compiler/include/zamalang/Support/LambdaArgument.h create mode 100644 compiler/lib/Support/JitCompilerEngine.cpp create mode 100644 compiler/lib/Support/LambdaArgument.cpp diff --git a/compiler/include/zamalang-c/Support/CompilerEngine.h b/compiler/include/zamalang-c/Support/CompilerEngine.h index d7e4dbd8e..834b30c50 100644 --- a/compiler/include/zamalang-c/Support/CompilerEngine.h +++ b/compiler/include/zamalang-c/Support/CompilerEngine.h @@ -5,15 +5,17 @@ #include "mlir-c/Registration.h" #include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/ExecutionArgument.h" +#include "zamalang/Support/Jit.h" +#include "zamalang/Support/JitCompilerEngine.h" #ifdef __cplusplus extern "C" { #endif -struct compilerEngine { - mlir::zamalang::CompilerEngine *ptr; +struct lambda { + mlir::zamalang::JitCompilerEngine::Lambda *ptr; }; -typedef struct compilerEngine compilerEngine; +typedef struct lambda lambda; struct executionArguments { mlir::zamalang::ExecutionArgument *data; @@ -21,13 +23,12 @@ struct executionArguments { }; typedef struct executionArguments exectuionArguments; -// Compile an MLIR module -MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine, - const char *module); +MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda +buildLambda(const char *module, const char *funcName); -// Run the compiled module -MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e, - executionArguments args); +MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args); + +MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); #ifdef __cplusplus } diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index f7dbc981f..e9c496b1a 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -1,49 +1,138 @@ #ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H #define ZAMALANG_SUPPORT_COMPILER_ENGINE_H -#include "Jit.h" +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir { namespace zamalang { -/// CompilerEngine is an tools that provides tools to implements the compilation -/// flow and manage the compilation flow state. +// Compilation context that acts as the root owner of LLVM and MLIR +// data structures directly and indirectly referenced by artefacts +// produced by the `CompilerEngine`. +class CompilationContext { +public: + CompilationContext(); + ~CompilationContext(); + + mlir::MLIRContext *getMLIRContext(); + llvm::LLVMContext *getLLVMContext(); + + static std::shared_ptr createShared(); + +protected: + mlir::MLIRContext *mlirContext; + llvm::LLVMContext *llvmContext; +}; + class CompilerEngine { public: - CompilerEngine() { - context = new mlir::MLIRContext(); - loadDialects(); - } - ~CompilerEngine() { - if (context != nullptr) - delete context; - } + // Result of an invocation of the `CompilerEngine` with optional + // fields for the results produced by different stages. + class CompilationResult { + public: + CompilationResult(std::shared_ptr compilationContext = + CompilationContext::createShared()) + : compilationContext(compilationContext) {} - // Compile an mlir programs from it's textual representation. - llvm::Error compile( - std::string mlirStr, - llvm::Optional overrideConstraints = {}); + llvm::Optional mlirModuleRef; + llvm::Optional clientParameters; + std::unique_ptr keySet; + std::unique_ptr llvmModule; + llvm::Optional fheContext; - // Build the jit lambda argument. - llvm::Expected> buildArgument(); + protected: + std::shared_ptr compilationContext; + }; - // Call the compiled function with and argument object. - llvm::Error invoke(JITLambda::Argument &arg); + // Specification of the exit stage of the compilation pipeline + enum class Target { + // Only read sources and produce corresponding MLIR module + ROUND_TRIP, - // Call the compiled function with a list of integer arguments. - llvm::Expected run(std::vector args); + // Read sources and exit before any lowering + HLFHE, - // Get a printable representation of the compiled module - std::string getCompiledModule(); + // Read sources and attempt to run the Minimal Arithmetic Noise + // Padding pass + HLFHE_MANP, + + // Read sources and lower all HLFHE operations to MidLFHE + // operations + MIDLFHE, + + // Read sources and lower all HLFHE and MidLFHE operations to LowLFHE + // operations + LOWLFHE, + + // Read sources and lower all HLFHE, MidLFHE and LowLFHE + // operations to canonical MLIR dialects. Cryptographic operations + // are lowered to invocations of the concrete library. + STD, + + // Read sources and lower all HLFHE, MidLFHE and LowLFHE + // operations to operations from the LLVM dialect. Cryptographic + // operations are lowered to invocations of the concrete library. + LLVM, + + // Same as `LLVM`, but lowers to actual LLVM IR instead of the + // LLVM dialect + LLVM_IR, + + // Same as `LLVM_IR`, but invokes the LLVM optimization pipeline + // to produce optimized LLVM IR + OPTIMIZED_LLVM_IR + }; + + CompilerEngine(std::shared_ptr compilationContext) + : overrideMaxEintPrecision(), overrideMaxMANP(), + clientParametersFuncName(), verifyDiagnostics(false), + generateKeySet(false), generateClientParameters(false), + parametrizeMidLFHE(true), compilationContext(compilationContext) {} + + llvm::Expected compile(llvm::StringRef s, Target target); + + llvm::Expected + compile(std::unique_ptr buffer, Target target); + + llvm::Expected compile(llvm::SourceMgr &sm, Target target); + + void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c); + void setMaxEintPrecision(size_t v); + void setMaxMANP(size_t v); + void setVerifyDiagnostics(bool v); + void setGenerateKeySet(bool v); + void setGenerateClientParameters(bool v); + void setParametrizeMidLFHE(bool v); + void setClientParametersFuncName(const llvm::StringRef &name); + +protected: + llvm::Optional overrideMaxEintPrecision; + llvm::Optional overrideMaxMANP; + llvm::Optional clientParametersFuncName; + bool verifyDiagnostics; + bool generateKeySet; + bool generateClientParameters; + bool parametrizeMidLFHE; + + std::shared_ptr compilationContext; + + // Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`) + // or indicating that no FHE dialect is used (`NONE`). + enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE }; + static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module); private: - // Load the necessary dialects into the engine's context - void loadDialects(); - - mlir::OwningModuleRef module_ref; - mlir::MLIRContext *context; - std::unique_ptr keySet; + llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res); + llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride); }; + } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 835749b0f..7c22a117c 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -9,11 +9,6 @@ namespace mlir { namespace zamalang { -mlir::LogicalResult -runJit(mlir::ModuleOp module, llvm::StringRef func, - llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, - std::function optPipeline, - llvm::raw_ostream &os); /// JITLambda is a tool to JIT compile an mlir module and to invoke a function /// of the module. diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h new file mode 100644 index 000000000..8957c0bb3 --- /dev/null +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -0,0 +1,296 @@ +#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H +#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H + +#include +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +namespace { +// Generic function template as well as specializations of +// `typedResult` must be declared at namespace scope due to return +// type template specialization + +// Helper function for `JitCompilerEngine::Lambda::operator()` +// implementing type-dependent preparation of the result. +template +llvm::Expected typedResult(JITLambda::Argument &arguments); + +// Specialization of `typedResult()` for scalar results, forwarding +// scalar value to caller +template <> +inline llvm::Expected typedResult(JITLambda::Argument &arguments) { + uint64_t res = 0; + + if (auto err = arguments.getResult(0, res)) + return StreamStringError() << "Cannot retrieve result:" << err; + + return res; +} + +// Specialization of `typedResult()` for vector results, initializing +// an `std::vector` of the right size with the results and forwarding +// it to the caller with move semantics. +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + llvm::Expected n = arguments.getResultVectorSize(0); + + if (auto err = n.takeError()) + return std::move(err); + + std::vector res(*n); + + if (auto err = arguments.getResult(0, res.data(), res.size())) + return StreamStringError() << "Cannot retrieve result:" << err; + + return std::move(res); +} + +// Adaptor class that adds arguments specified as instances of +// `LambdaArgument` to `JitLambda::Argument`. +class JITLambdaArgumentAdaptor { +public: + // Checks if the argument `arg` is an plaintext / encrypted integer + // argument or a plaintext / encrypted tensor argument with a + // backing integer type `IntT` and adds the argument to `jla` at + // position `pos`. + // + // Returns `true` if `arg` has one of the types above and its value + // was successfully added to `jla`, `false` if none of the types + // matches or an error if a type matched, but adding the argument to + // `jla` failed. + template + static inline llvm::Expected + tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { + if (auto ila = arg.dyn_cast>()) { + if (llvm::Error err = jla.setArg(pos, ila->getValue())) + return std::move(err); + else + return true; + } else if (auto tla = arg.dyn_cast< + TensorLambdaArgument>>()) { + llvm::Expected numElements = tla->getNumElements(); + + if (!numElements) + return std::move(numElements.takeError()); + + if (llvm::Error err = jla.setArg(pos, tla->getValue(), *numElements)) + return std::move(err); + else + return true; + } + + return false; + } + + // Recursive case for `tryAddArg(...)` + template + static inline llvm::Expected + tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { + llvm::Expected successOrError = tryAddArg(jla, pos, arg); + + if (!successOrError) + return std::move(successOrError.takeError()); + + if (successOrError.get() == false) + return tryAddArg(jla, pos, arg); + else + return true; + } + + // Attempts to add a single argument `arg` to `jla` at position + // `pos`. Returns an error if either the argument type is + // unsupported or if the argument types is supported, but adding it + // to `jla` failed. + static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos, + const LambdaArgument &arg) { + llvm::Expected successOrError = + JITLambdaArgumentAdaptor::tryAddArg(jla, pos, arg); + + if (!successOrError) + return std::move(successOrError.takeError()); + + if (successOrError.get() == false) + return StreamStringError("Unknown argument type"); + else + return llvm::Error::success(); + } +}; +} // namespace + +// A compiler engine that JIT-compiles a source and produces a lambda +// object directly invocable through its call operator. +class JitCompilerEngine : public CompilerEngine { +public: + // Wrapper class around `JITLambda` and `JITLambda::Argument` that + // allows for direct invocation of a compiled function through + // `operator ()`. + class Lambda { + public: + Lambda(Lambda &&other) + : innerLambda(std::move(other.innerLambda)), + keySet(std::move(other.keySet)), + compilationContext(other.compilationContext) {} + + Lambda(std::shared_ptr compilationContext, + std::unique_ptr lambda, std::unique_ptr keySet) + : innerLambda(std::move(lambda)), keySet(std::move(keySet)), + compilationContext(compilationContext) {} + + // Returns the number of arguments required for an invocation of + // the lambda + size_t getNumArguments() { return this->keySet->numInputs(); } + + // Returns the number of results an invocation of the lambda + // produces + size_t getNumResults() { return this->keySet->numOutputs(); } + + // Invocation with an dynamic list of arguments of different + // types, specified as `LambdaArgument`s + template + llvm::Expected + operator()(llvm::ArrayRef lambdaArgs) { + // Create the arguments of the JIT lambda + llvm::Expected> argsOrErr = + mlir::zamalang::JITLambda::Argument::create(*this->keySet.get()); + + if (llvm::Error err = argsOrErr.takeError()) + return StreamStringError("Could not create lambda arguments"); + + // Set the arguments + std::unique_ptr arguments = + std::move(argsOrErr.get()); + + for (size_t i = 0; i < lambdaArgs.size(); i++) { + if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument( + *arguments, i, *lambdaArgs[i])) { + return std::move(err); + } + } + + // Invoke the lambda + if (auto err = this->innerLambda->invoke(*arguments)) + return StreamStringError() << "Cannot invoke lambda:" << err; + + return std::move(typedResult(*arguments)); + } + + // Invocation with an array of arguments of the same type + template + llvm::Expected operator()(const llvm::ArrayRef args) { + // Create the arguments of the JIT lambda + llvm::Expected> argsOrErr = + mlir::zamalang::JITLambda::Argument::create(*this->keySet.get()); + + if (llvm::Error err = argsOrErr.takeError()) + return StreamStringError("Could not create lambda arguments"); + + // Set the arguments + std::unique_ptr arguments = + std::move(argsOrErr.get()); + + for (size_t i = 0; i < args.size(); i++) { + if (auto err = arguments->setArg(i, args[i])) { + return StreamStringError() + << "Cannot push argument " << i << ": " << err; + } + } + + // Invoke the lambda + if (auto err = this->innerLambda->invoke(*arguments)) + return StreamStringError() << "Cannot invoke lambda:" << err; + + return std::move(typedResult(*arguments)); + } + + // Invocation with arguments of different types + template + llvm::Expected operator()(const Ts... ts) { + // Create the arguments of the JIT lambda + llvm::Expected> argsOrErr = + mlir::zamalang::JITLambda::Argument::create(*this->keySet.get()); + + if (llvm::Error err = argsOrErr.takeError()) + return StreamStringError("Could not create lambda arguments"); + + // Set the arguments + std::unique_ptr arguments = + std::move(argsOrErr.get()); + + if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...)) + return std::move(err); + + // Invoke the lambda + if (auto err = this->innerLambda->invoke(*arguments)) + return StreamStringError() << "Cannot invoke lambda:" << err; + + return std::move(typedResult(*arguments)); + } + + protected: + template + inline llvm::Error addArgs(JITLambda::Argument *jitArgs) { + // base case -- nothing to do + return llvm::Error::success(); + } + + // Recursive case for scalars: extract first scalar argument from + // parameter pack and forward rest + template + inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg, + Ts... remainder) { + if (auto err = jitArgs->setArg(pos, arg)) { + return StreamStringError() + << "Cannot push scalar argument " << pos << ": " << err; + } + + return this->addArgs(jitArgs, remainder...); + } + + // Recursive case for tensors: extract pointer and size from + // parameter pack and forward rest + template + inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg, + size_t size, Ts... remainder) { + if (auto err = jitArgs->setArg(pos, arg, size)) { + return StreamStringError() + << "Cannot push tensor argument " << pos << ": " << err; + } + + return this->addArgs(jitArgs, remainder...); + } + + std::unique_ptr innerLambda; + std::unique_ptr keySet; + std::shared_ptr compilationContext; + }; + + JitCompilerEngine(std::shared_ptr compilationContext = + CompilationContext::createShared(), + unsigned int optimizationLevel = 3); + + llvm::Expected buildLambda(llvm::StringRef src, + llvm::StringRef funcName = "main"); + + llvm::Expected buildLambda(std::unique_ptr buffer, + llvm::StringRef funcName = "main"); + + llvm::Expected buildLambda(llvm::SourceMgr &sm, + llvm::StringRef funcName = "main"); + +protected: + llvm::Expected findLLVMFuncOp(mlir::ModuleOp module, + llvm::StringRef name); + unsigned int optimizationLevel; +}; + +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/include/zamalang/Support/LambdaArgument.h b/compiler/include/zamalang/Support/LambdaArgument.h new file mode 100644 index 000000000..9d5378377 --- /dev/null +++ b/compiler/include/zamalang/Support/LambdaArgument.h @@ -0,0 +1,157 @@ +#ifndef ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H +#define ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H + +#include +#include + +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +// Abstract base class for lambda arguments +class LambdaArgument + : public llvm::RTTIExtends { +public: + LambdaArgument(LambdaArgument &) = delete; + + template bool isa() const { return llvm::isa(*this); } + + // Cast functions on constant instances + template const T &cast() const { return llvm::cast(*this); } + template const T *dyn_cast() const { + return llvm::dyn_cast(this); + } + + // Cast functions for mutable instances + template T &cast() { return llvm::cast(*this); } + template T *dyn_cast() { return llvm::dyn_cast(this); } + + static char ID; + +protected: + LambdaArgument(){}; +}; + +// Class for integer arguments. `BackingIntType` is used as the data +// type to hold the argument's value. The precision is the actual +// precision of the value, which might be different from the precision +// of the backing integer type. +template +class IntLambdaArgument + : public llvm::RTTIExtends, + LambdaArgument> { +public: + typedef BackingIntType value_type; + + IntLambdaArgument(BackingIntType value, + unsigned int precision = 8 * sizeof(BackingIntType)) + : precision(precision) { + if (precision < 8 * sizeof(BackingIntType)) { + this->value = value & (1 << (this->precision - 1)); + } else { + this->value = value; + } + } + + unsigned int getPrecision() const { return this->precision; } + BackingIntType getValue() const { return this->value; } + + static char ID; + +protected: + unsigned int precision; + BackingIntType value; +}; + +template +char IntLambdaArgument::ID = 0; + +namespace { +// Calculates `accu *= factor` or returns an error if the result +// would overflow +template +llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) { + static_assert(std::numeric_limits::is_integer && + std::numeric_limits::is_integer && + !std::numeric_limits::is_signed && + !std::numeric_limits::is_signed, + "Only unsigned integers are supported"); + + const AccuT left = std::numeric_limits::max() / accu; + + if (left > factor) { + accu *= factor; + return llvm::Error::success(); + } + + return StreamStringError("Multiplying value ") + << accu << " with " << factor << " would cause an overflow"; +} +} // namespace + +// Class for Tensor arguments. This can either be plaintext tensors +// (for `ScalarArgumentT = IntLambaArgument`) or tensors +// representing encrypted integers (for `ScalarArgumentT = +// EIntLambaArgument`). +template +class TensorLambdaArgument + : public llvm::RTTIExtends, + LambdaArgument> { +public: + typedef ScalarArgumentT scalar_type; + + // Construct tensor argument from the one-dimensional array `value`, + // but interpreting the array's values as a linearized + // multi-dimensional tensor with the sizes of the dimensions + // specified in `dimensions`. + TensorLambdaArgument( + llvm::MutableArrayRef value, + llvm::ArrayRef dimensions) + : value(value), dimensions(dimensions.vec()) {} + + // Construct a one-dimensional tensor argument from the + // array `value`. + TensorLambdaArgument( + llvm::MutableArrayRef value) + : TensorLambdaArgument(value, {(unsigned int)value.size()}) {} + + const std::vector &getDimensions() const { + return this->dimensions; + } + + // Returns the total number of elements in the tensor. If the number + // of elements cannot be represented as a `size_t`, the method + // returns an error. + llvm::Expected getNumElements() const { + size_t accu = 1; + + for (unsigned int dimSize : dimensions) + if (llvm::Error err = safeUnsignedMul(accu, dimSize)) + return std::move(err); + + return accu; + } + + // Returns a bare pointer to the linearized values of the tensor. + typename ScalarArgumentT::value_type *getValue() const { + return this->value.data(); + } + + static char ID; + +protected: + llvm::MutableArrayRef value; + std::vector dimensions; +}; + +template +char TensorLambdaArgument::ID = 0; + +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index a99f1c1a6..65e100e68 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1,8 +1,9 @@ #include "CompilerAPIModule.h" #include "zamalang-c/Support/CompilerEngine.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc" -#include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/ExecutionArgument.h" +#include "zamalang/Support/Jit.h" +#include "zamalang/Support/JitCompilerEngine.h" #include #include #include @@ -14,27 +15,15 @@ #include #include -using mlir::zamalang::CompilerEngine; using mlir::zamalang::ExecutionArgument; +using mlir::zamalang::JitCompilerEngine; /// Populate the compiler API python module. void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { m.doc() = "Zamalang compiler python API"; - m.def("round_trip", [](std::string mlir_input) { - mlir::MLIRContext context; - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - auto module_ref = mlir::parseSourceString(mlir_input, &context); - if (!module_ref) { - throw std::logic_error("mlir parsing failed"); - } - std::string result; - llvm::raw_string_ostream os(result); - module_ref->print(os); - return os.str(); - }); + m.def("round_trip", + [](std::string mlir_input) { return roundTrip(mlir_input.c_str()); }); pybind11::class_>( m, "ExecutionArgument") @@ -45,20 +34,19 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { .def("is_tensor", &ExecutionArgument::isTensor) .def("is_int", &ExecutionArgument::isInt); - pybind11::class_(m, "CompilerEngine") + pybind11::class_(m, "JitCompilerEngine") .def(pybind11::init()) - .def("run", - [](CompilerEngine &engine, std::vector args) { - // wrap and call CAPI - compilerEngine e{&engine}; - exectuionArguments a{args.data(), args.size()}; - return compilerEngineRun(e, a); - }) - .def("compile_fhe", - [](CompilerEngine &engine, std::string mlir_input) { - // wrap and call CAPI - compilerEngine e{&engine}; - compilerEngineCompile(e, mlir_input.c_str()); - }) - .def("get_compiled_module", &CompilerEngine::getCompiledModule); + .def_static("build_lambda", + [](std::string mlir_input, std::string func_name) { + return buildLambda(mlir_input.c_str(), func_name.c_str()); + }); + + pybind11::class_(m, "Lambda") + .def("invoke", [](JitCompilerEngine::Lambda &py_lambda, + std::vector args) { + // wrap and call CAPI + lambda c_lambda{&py_lambda}; + exectuionArguments a{args.data(), args.size()}; + return invokeLambda(c_lambda, a); + }); } diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py index 76e372463..130f4275e 100644 --- a/compiler/lib/Bindings/Python/zamalang/compiler.py +++ b/compiler/lib/Bindings/Python/zamalang/compiler.py @@ -1,10 +1,9 @@ """Compiler submodule""" from typing import List, Union -from mlir._mlir_libs._zamalang._compiler import CompilerEngine as _CompilerEngine +from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip - def round_trip(mlir_str: str) -> str: """Parse the MLIR input, then return it back. @@ -49,25 +48,24 @@ def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgume class CompilerEngine: def __init__(self, mlir_str: str = None): - self._engine = _CompilerEngine() + self._engine = _JitCompilerEngine() + self._lambda = None if mlir_str is not None: self.compile_fhe(mlir_str) - def compile_fhe(self, mlir_str: str) -> "CompilerEngine": - """Compile the MLIR input and build a CompilerEngine. + def compile_fhe(self, mlir_str: str, func_name: str = "main"): + """Compile the MLIR input. Args: mlir_str (str): MLIR to compile. + func_name (str): name of the function to set as entrypoint. Raises: TypeError: if the argument is not an str. - - Returns: - CompilerEngine: engine used for execution. """ if not isinstance(mlir_str, str): raise TypeError("input must be an `str`") - return self._engine.compile_fhe(mlir_str) + self._lambda = self._engine.build_lambda(mlir_str, func_name) def run(self, *args: List[Union[int, List[int]]]) -> int: """Run the compiled code. @@ -77,17 +75,12 @@ class CompilerEngine: Raises: TypeError: if execution arguments can't be constructed + RuntimeError: if the engine has not compiled any code yet Returns: int: result of execution. """ + if self._lambda is None: + raise RuntimeError("need to compile an MLIR code first") execution_arguments = [create_execution_argument(arg) for arg in args] - return self._engine.run(execution_arguments) - - def get_compiled_module(self) -> str: - """Compiled module in printable form. - - Returns: - str: Compiled module in printable form. - """ - return self._engine.get_compiled_module() + return self._lambda.invoke(execution_arguments) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index e48f7ddf2..a8f9dd71e 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -1,62 +1,83 @@ #include "zamalang-c/Support/CompilerEngine.h" #include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/ExecutionArgument.h" +#include "zamalang/Support/Jit.h" +#include "zamalang/Support/JitCompilerEngine.h" +#include "zamalang/Support/logging.h" -using mlir::zamalang::CompilerEngine; +// using mlir::zamalang::CompilerEngine; using mlir::zamalang::ExecutionArgument; +using mlir::zamalang::JitCompilerEngine; -void compilerEngineCompile(compilerEngine engine, const char *module) { - auto error = engine.ptr->compile(module); - if (error) { - llvm::errs() << "Compilation failed: " << error << "\n"; - llvm::consumeError(std::move(error)); +mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, + const char *funcName) { + mlir::zamalang::JitCompilerEngine engine; + llvm::Expected lambdaOrErr = + engine.buildLambda(module, funcName); + if (!lambdaOrErr) { + mlir::zamalang::log_error() + << "Compilation failed: " + << llvm::toString(std::move(lambdaOrErr.takeError())) << "\n"; throw std::runtime_error( "failed compiling, see previous logs for more info"); } + return std::move(*lambdaOrErr); } -uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) { - auto args_size = args.size; - auto maybeArgument = engine.ptr->buildArgument(); - if (auto err = maybeArgument.takeError()) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error( - "failed building arguments, see previous logs for more info"); +uint64_t invokeLambda(lambda l, executionArguments args) { + mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr = + (mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr; + + if (args.size != lambda_ptr->getNumArguments()) { + throw std::invalid_argument("wrong number of arguments"); } // Set the integer/tensor arguments - auto arguments = std::move(maybeArgument.get()); - for (auto i = 0; i < args_size; i++) { + std::vector lambdaArgumentsRef; + for (auto i = 0; i < args.size; i++) { if (args.data[i].isInt()) { // integer argument - if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error("failed pushing integer argument, see " - "previous logs for more info"); - } + lambdaArgumentsRef.push_back(new mlir::zamalang::IntLambdaArgument<>( + args.data[i].getIntegerArgument())); } else { // tensor argument - assert(args.data[i].isTensor() && "should be tensor argument"); - if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(), - args.data[i].getTensorSize())) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error("failed pushing tensor argument, see " - "previous logs for more info"); - } + llvm::MutableArrayRef tensor(args.data[i].getTensorArgument(), + args.data[i].getTensorSize()); + lambdaArgumentsRef.push_back( + new mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument>(tensor)); } } - // Invoke the lambda - if (auto err = engine.ptr->invoke(*arguments)) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error("failed running, see previous logs for more info"); - } - uint64_t result = 0; - if (auto err = arguments->getResult(0, result)) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); + // Run lambda + llvm::Expected resOrError = (*lambda_ptr)( + llvm::ArrayRef(lambdaArgumentsRef)); + // Free heap + for (size_t i = 0; i < lambdaArgumentsRef.size(); i++) + delete lambdaArgumentsRef[i]; + + if (!resOrError) { + mlir::zamalang::log_error() + << "Lambda invokation failed: " + << llvm::toString(std::move(resOrError.takeError())) << "\n"; throw std::runtime_error( - "failed getting result, see previous logs for more info"); + "failed invoking lambda, see previous logs for more info"); } - return result; -} \ No newline at end of file + return *resOrError; +} + +std::string roundTrip(const char *module) { + std::shared_ptr ccx = + mlir::zamalang::CompilationContext::createShared(); + mlir::zamalang::JitCompilerEngine ce{ccx}; + + llvm::Expected retOrErr = + ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP); + if (!retOrErr) { + mlir::zamalang::log_error() + << llvm::toString(std::move(retOrErr.takeError())) << "\n"; + throw std::runtime_error( + "mlir parsing failed, see previous logs for more info"); + } + + std::string result; + llvm::raw_string_ostream os(result); + retOrErr->mlirModuleRef->get().print(os); + return os.str(); +} diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 9694989d3..a1fde7a86 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -3,6 +3,8 @@ add_mlir_library(ZamalangSupport Pipeline.cpp Jit.cpp CompilerEngine.cpp + JitCompilerEngine.cpp + LambdaArgument.cpp V0Parameters.cpp V0Curves.cpp ClientParameters.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 5a41c7e46..50d98b721 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -1,3 +1,5 @@ +#include +#include #include #include #include @@ -9,155 +11,419 @@ #include #include #include +#include +#include #include namespace mlir { namespace zamalang { -void CompilerEngine::loadDialects() { - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); +// Creates a new compilation context that can be shared across +// compilation engines and results +std::shared_ptr CompilationContext::createShared() { + return std::make_shared(); } -std::string CompilerEngine::getCompiledModule() { - std::string compiledModule; - llvm::raw_string_ostream os(compiledModule); - module_ref->print(os); - return os.str(); +CompilationContext::CompilationContext() + : mlirContext(nullptr), llvmContext(nullptr) {} + +CompilationContext::~CompilationContext() { + delete this->mlirContext; + delete this->llvmContext; } -llvm::Error CompilerEngine::compile( - std::string mlirStr, - llvm::Optional overrideConstraints) { - module_ref = mlir::parseSourceString(mlirStr, context); - if (!module_ref) { - return llvm::make_error("mlir parsing failed", - llvm::inconvertibleErrorCode()); +// Returns the MLIR context for a compilation context. Creates and +// initializes a new MLIR context if necessary. +mlir::MLIRContext *CompilationContext::getMLIRContext() { + if (this->mlirContext == nullptr) { + this->mlirContext = new mlir::MLIRContext(); + + this->mlirContext->getOrLoadDialect(); + this->mlirContext + ->getOrLoadDialect(); + this->mlirContext + ->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); } - mlir::ModuleOp module = module_ref.get(); + return this->mlirContext; +} - llvm::Optional fheConstraintsOpt = - overrideConstraints; +// Returns the LLVM context for a compilation context. Creates and +// initializes a new LLVM context if necessary. +llvm::LLVMContext *CompilationContext::getLLVMContext() { + if (this->llvmContext == nullptr) + this->llvmContext = new llvm::LLVMContext(); - if (!fheConstraintsOpt.hasValue()) { + return this->llvmContext; +} + +// Sets the FHE constraints for the compilation. Overrides any +// automatically detected configuration and prevents the autodetection +// pass from running. +void CompilerEngine::setFHEConstraints( + const mlir::zamalang::V0FHEConstraint &c) { + this->overrideMaxEintPrecision = c.p; + this->overrideMaxMANP = c.norm2; +} + +void CompilerEngine::setVerifyDiagnostics(bool v) { + this->verifyDiagnostics = v; +} + +void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; } + +void CompilerEngine::setGenerateClientParameters(bool v) { + this->generateClientParameters = v; +} + +void CompilerEngine::setMaxEintPrecision(size_t v) { + this->overrideMaxEintPrecision = v; +} + +void CompilerEngine::setParametrizeMidLFHE(bool v) { + this->parametrizeMidLFHE = v; +} + +void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; } + +void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) { + this->clientParametersFuncName = name.str(); +} + +// Helper function detecting the FHE dialect with the highest level of +// abstraction used in `module`. If no FHE dialect is used, the +// function returns `CompilerEngine::FHEDialect::NONE`. +CompilerEngine::FHEDialect +CompilerEngine::detectHighestFHEDialect(mlir::ModuleOp module) { + CompilerEngine::FHEDialect highestDialect = CompilerEngine::FHEDialect::NONE; + + mlir::TypeID hlfheID = + mlir::TypeID::get(); + mlir::TypeID midlfheID = + mlir::TypeID::get(); + mlir::TypeID lowlfheID = + mlir::TypeID::get(); + + // Helper lambda updating the currently highest dialect if necessary + // by dialect type ID + auto updateDialectFromDialectID = [&](mlir::TypeID dialectID) { + if (dialectID == hlfheID) { + highestDialect = CompilerEngine::FHEDialect::HLFHE; + return true; + } else if (dialectID == lowlfheID && + highestDialect == CompilerEngine::FHEDialect::NONE) { + highestDialect = CompilerEngine::FHEDialect::LOWLFHE; + } else if (dialectID == midlfheID && + (highestDialect == CompilerEngine::FHEDialect::NONE || + highestDialect == CompilerEngine::FHEDialect::LOWLFHE)) { + highestDialect = CompilerEngine::FHEDialect::MIDLFHE; + } + + return false; + }; + + // Helper lambda updating the currently highest dialect if necessary + // by value type + std::function updateDialectFromType = + [&](mlir::Type ty) -> bool { + if (updateDialectFromDialectID(ty.getDialect().getTypeID())) + return true; + + if (mlir::TensorType tensorTy = ty.dyn_cast_or_null()) + return updateDialectFromType(tensorTy.getElementType()); + + return false; + }; + + module.walk([&](mlir::Operation *op) { + // Check operation itself + if (updateDialectFromDialectID(op->getDialect()->getTypeID())) + return mlir::WalkResult::interrupt(); + + // Check types of operands + for (mlir::Value operand : op->getOperands()) { + if (updateDialectFromType(operand.getType())) + return mlir::WalkResult::interrupt(); + } + + // Check types of results + for (mlir::Value res : op->getResults()) { + if (updateDialectFromType(res.getType())) { + return mlir::WalkResult::interrupt(); + } + } + + return mlir::WalkResult::advance(); + }); + + return highestDialect; +} + +// Sets the FHE parameters of `res` either through autodetection or +// fixed constraints provided in +// `CompilerEngine::overrideMaxEintPrecision` and +// `CompilerEngine::overrideMaxMANP`. +// +// Autodetected values can be partially or fully overridden through +// `CompilerEngine::overrideMaxEintPrecision` and +// `CompilerEngine::overrideMaxMANP`. +// +// If `noOverrideAutodetected` is true, autodetected values are not +// overriden and used directly for `res`. +// +// Return an error if autodetection fails. +llvm::Error +CompilerEngine::determineFHEParameters(CompilationResult &res, + bool noOverrideAutodetected) { + mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + mlir::ModuleOp module = res.mlirModuleRef->get(); + llvm::Optional fheConstraints; + + // Determine FHE constraints either through autodetection or through + // overridden values + if (this->overrideMaxEintPrecision.hasValue() && + this->overrideMaxMANP.hasValue() && !noOverrideAutodetected) { + fheConstraints.emplace(mlir::zamalang::V0FHEConstraint{ + this->overrideMaxMANP.getValue(), + this->overrideMaxEintPrecision.getValue()}); + + } else { llvm::Expected> fheConstraintsOrErr = - mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context, + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(mlirContext, module); if (auto err = fheConstraintsOrErr.takeError()) return std::move(err); if (!fheConstraintsOrErr.get().hasValue()) { - return llvm::make_error( - "Could not determine maximum required precision for encrypted " - "integers " - "and maximum value for the Minimal Arithmetic Noise Padding", - llvm::inconvertibleErrorCode()); + return StreamStringError("Could not determine maximum required precision " + "for encrypted integers and maximum value for " + "the Minimal Arithmetic Noise Padding"); } - fheConstraintsOpt = fheConstraintsOrErr.get(); + if (noOverrideAutodetected) + return llvm::Error::success(); + + fheConstraints = fheConstraintsOrErr.get(); + + // Override individual values if requested + if (this->overrideMaxEintPrecision.hasValue()) + fheConstraints->p = this->overrideMaxEintPrecision.getValue(); + + if (this->overrideMaxMANP.hasValue()) + fheConstraints->norm2 = this->overrideMaxMANP.getValue(); } - mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue(); - const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); + const mlir::zamalang::V0Parameter *fheParams = + getV0Parameter(fheConstraints.getValue()); - if (!parameter) { - std::string buffer; - llvm::raw_string_ostream strs(buffer); - strs << "Could not determine V0 parameters for 2-norm of " - << fheConstraints.norm2 << " and p of " << fheConstraints.p; - - return llvm::make_error(strs.str(), - llvm::inconvertibleErrorCode()); + if (!fheParams) { + return StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << fheConstraints->norm2 << " and p of " << fheConstraints->p; } - mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter}; + res.fheContext.emplace( + mlir::zamalang::V0FHEContext{*fheConstraints, *fheParams}); - // Lower to MLIR Std - if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext, - false) - .failed()) { - return llvm::make_error("failed to lower to MLIR Std", - llvm::inconvertibleErrorCode()); - } - // Create the client parameters - auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheContext, "main", module_ref.get()); - if (auto err = clientParameter.takeError()) { - return std::move(err); - } - auto maybeKeySet = - mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0); - if (auto err = maybeKeySet.takeError()) { - return std::move(err); - } - keySet = std::move(maybeKeySet.get()); - - // Lower to MLIR LLVM Dialect - if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false) - .failed()) { - return llvm::make_error( - "failed to lower to LLVM dialect", llvm::inconvertibleErrorCode()); - } return llvm::Error::success(); } -llvm::Expected> -CompilerEngine::buildArgument() { - if (keySet.get() == nullptr) { - return llvm::make_error( - "CompilerEngine::buildArgument: invalid engine state, the keySet has " - "not been generated", - llvm::inconvertibleErrorCode()); - } - return JITLambda::Argument::create(*keySet); -} +// Performs all lowering from HLFHE to the FHE dialect with the lwoest +// level of abstraction that requires FHE parameters. +// +// Returns an error if any of the lowerings fails. +llvm::Error CompilerEngine::lowerParamDependentHalf(Target target, + CompilationResult &res) { + mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + mlir::ModuleOp module = res.mlirModuleRef->get(); -llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) { - // Create the JIT lambda - auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); - auto module = module_ref.get(); - auto maybeLambda = - mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline); - if (auto err = maybeLambda.takeError()) { - return std::move(err); + // HLFHE -> MidLFHE + if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module, false) + .failed()) { + return StreamStringError("Lowering from HLFHE to MidLFHE failed"); } - // Invoke the lambda - if (auto err = maybeLambda.get()->invoke(arg)) { - return std::move(err); + + if (target == Target::MIDLFHE) + return llvm::Error::success(); + + // MidLFHE -> LowLFHE + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( + mlirContext, module, *res.fheContext, this->parametrizeMidLFHE) + .failed()) { + return StreamStringError("Lowering from MidLFHE to LowLFHE failed"); } + return llvm::Error::success(); } -llvm::Expected CompilerEngine::run(std::vector args) { - // Build the argument of the JIT lambda. - auto maybeArgument = buildArgument(); - if (auto err = maybeArgument.takeError()) { - return std::move(err); +// Compile the sources managed by the source manager `sm` to the +// target dialect `target`. If successful, the result can be retrieved +// using `getModule()` and `getLLVMModule()`, respectively depending +// on the target dialect. +llvm::Expected +CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { + CompilationResult res(this->compilationContext); + + mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + + mlir::SourceMgrDiagnosticVerifierHandler smHandler(sm, &mlirContext); + mlirContext.printOpOnDiagnostic(false); + + mlir::OwningModuleRef mlirModuleRef = + mlir::parseSourceFile(sm, &mlirContext); + + if (this->verifyDiagnostics) { + if (smHandler.verify().failed()) + return StreamStringError("Verification of diagnostics failed"); + else + return res; } - // Set the integer arguments - auto arguments = std::move(maybeArgument.get()); - for (auto i = 0; i < args.size(); i++) { - if (auto err = arguments->setArg(i, args[i])) { + + if (!mlirModuleRef) + return StreamStringError("Could not parse source"); + + res.mlirModuleRef = std::move(mlirModuleRef); + mlir::ModuleOp module = res.mlirModuleRef->get(); + + if (target == Target::HLFHE || target == Target::ROUND_TRIP) + return res; + + // Detect highest FHE dialect and check if FHE parameter + // autodetection / lowering of parameter-dependent dialects can be + // skipped + FHEDialect highestFHEDialect = this->detectHighestFHEDialect(module); + + if (highestFHEDialect == FHEDialect::HLFHE || + highestFHEDialect == FHEDialect::MIDLFHE || + this->generateClientParameters) { + bool noOverrideAutoDetected = (target == Target::HLFHE_MANP); + if (auto err = this->determineFHEParameters(res, noOverrideAutoDetected)) return std::move(err); + } + + // return early if only the MANP pass was requested + if (target == Target::HLFHE_MANP) + return res; + + if (highestFHEDialect == FHEDialect::HLFHE || + highestFHEDialect == FHEDialect::MIDLFHE) { + if (llvm::Error err = this->lowerParamDependentHalf(target, res)) + return std::move(err); + } + + if (target == Target::HLFHE_MANP || target == Target::MIDLFHE || + target == Target::LOWLFHE) + return res; + + // LowLFHE -> Canonical dialects + if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module) + .failed()) { + return StreamStringError( + "Lowering from LowLFHE to canonical MLIR dialects failed"); + } + + if (target == Target::STD) + return res; + + // Generate client parameters if requested + if (this->generateClientParameters) { + if (!this->clientParametersFuncName.hasValue()) { + return StreamStringError( + "Generation of client parameters requested, but no function name " + "specified"); } + + llvm::Expected clientParametersOrErr = + mlir::zamalang::createClientParametersForV0( + *res.fheContext, *this->clientParametersFuncName, module); + + if (llvm::Error err = clientParametersOrErr.takeError()) + return std::move(err); + + res.clientParameters = clientParametersOrErr.get(); } - // Invoke the lambda - if (auto err = invoke(*arguments)) { - return std::move(err); + + // Generate Key set if requested + if (this->generateKeySet) { + if (!res.clientParameters.hasValue()) { + return StreamStringError("Generation of keyset requested without request " + "for generation of client parameters"); + } + + llvm::Expected> keySetOrErr = + mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0); + + if (auto err = keySetOrErr.takeError()) + return std::move(err); + + res.keySet = std::move(*keySetOrErr); } - uint64_t res = 0; - if (auto err = arguments->getResult(0, res)) { - return std::move(err); + + // MLIR canonical dialects -> LLVM Dialect + if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module, + false) + .failed()) { + return StreamStringError("Failed to lower to LLVM dialect"); } + + if (target == Target::LLVM) + return res; + + // Lowering to actual LLVM IR (i.e., not the LLVM dialect) + llvm::LLVMContext &llvmContext = *this->compilationContext->getLLVMContext(); + + res.llvmModule = mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR( + mlirContext, llvmContext, module); + + if (!res.llvmModule) + return StreamStringError("Failed to convert from LLVM dialect to LLVM IR"); + + if (target == Target::LLVM_IR) + return res; + + if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *res.llvmModule) + .failed()) { + return StreamStringError("Failed to optimize LLVM IR"); + } + + if (target == Target::OPTIMIZED_LLVM_IR) + return res; + return res; +} // namespace zamalang + +// Compile the source `s` to the target dialect `target`. If successful, the +// result can be retrieved using `getModule()` and `getLLVMModule()`, +// respectively depending on the target dialect. +llvm::Expected +CompilerEngine::compile(llvm::StringRef s, Target target) { + std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); + llvm::Expected res = this->compile(std::move(mb), target); + + return std::move(res); } + +// Compile the contained in `buffer` to the target dialect +// `target`. If successful, the result can be retrieved using +// `getModule()` and `getLLVMModule()`, respectively depending on the +// target dialect. +llvm::Expected +CompilerEngine::compile(std::unique_ptr buffer, + Target target) { + llvm::SourceMgr sm; + + sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); + + llvm::Expected res = this->compile(sm, target); + + return std::move(res); +} + } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 85ab88aaf..0b100e03e 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -1,3 +1,4 @@ +#include "llvm/Support/Error.h" #include #include #include @@ -12,56 +13,6 @@ namespace mlir { namespace zamalang { -// JIT-compiles `module` invokes `func` with the arguments passed in -// `jitArguments` and `keySet` -mlir::LogicalResult -runJit(mlir::ModuleOp module, llvm::StringRef func, - llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, - std::function optPipeline, - llvm::raw_ostream &os) { - // Create the JIT lambda - auto maybeLambda = - mlir::zamalang::JITLambda::create(func, module, optPipeline); - if (!maybeLambda) { - return mlir::failure(); - } - auto lambda = std::move(maybeLambda.get()); - - // Create the arguments of the JIT lambda - auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet); - if (auto err = maybeArguments.takeError()) { - ::mlir::zamalang::log_error() - << "Cannot create lambda arguments: " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - - // Set the arguments - auto arguments = std::move(maybeArguments.get()); - for (size_t i = 0; i < funcArgs.size(); i++) { - if (auto err = arguments->setArg(i, funcArgs[i])) { - ::mlir::zamalang::log_error() - << "Cannot push argument " << i << ": " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - } - // Invoke the lambda - if (auto err = lambda->invoke(*arguments)) { - ::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - uint64_t res = 0; - if (auto err = arguments->getResult(0, res)) { - ::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - llvm::errs() << res << "\n"; - return mlir::success(); -} - llvm::Expected> JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, llvm::function_ref optPipeline) { diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp new file mode 100644 index 000000000..05359ac4f --- /dev/null +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -0,0 +1,105 @@ +#include "llvm/Support/Error.h" +#include +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +JitCompilerEngine::JitCompilerEngine( + std::shared_ptr compilationContext, + unsigned int optimizationLevel) + : CompilerEngine(compilationContext), optimizationLevel(optimizationLevel) { +} + +// Returns the `LLVMFuncOp` operation in the compiled module with the +// specified name. If no LLVMFuncOp with that name exists or if there +// was no prior call to `compile()` resulting in an MLIR module in the +// LLVM dialect, an error is returned. +llvm::Expected +JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) { + auto funcOps = module.getOps(); + auto funcOp = llvm::find_if( + funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; }); + + if (funcOp == funcOps.end()) { + return StreamStringError() + << "Module does not contain function named '" << name.str() << "'"; + } + + return *funcOp; +} + +// Build a lambda from the function with the name given in +// `funcName` from the sources in `buffer`. +llvm::Expected +JitCompilerEngine::buildLambda(std::unique_ptr buffer, + llvm::StringRef funcName) { + llvm::SourceMgr sm; + + sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); + + llvm::Expected res = + this->buildLambda(sm, funcName); + + return std::move(res); +} + +// Build a lambda from the function with the name given in `funcName` +// from the source string `s`. +llvm::Expected +JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) { + std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); + llvm::Expected res = + this->buildLambda(std::move(mb), funcName); + + return std::move(res); +} + +// Build a lambda from the function with the name given in +// `funcName` from the sources managed by the source manager `sm`. +llvm::Expected +JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { + MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + + this->setGenerateKeySet(true); + this->setGenerateClientParameters(true); + this->setClientParametersFuncName(funcName); + + // First, compile to LLVM Dialect + llvm::Expected compResOrErr = + this->compile(sm, Target::LLVM_IR); + + if (!compResOrErr) + return std::move(compResOrErr.takeError()); + + mlir::ModuleOp module = compResOrErr->mlirModuleRef->get(); + + // Locate function to JIT-compile + llvm::Expected funcOrError = + this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName); + + if (!funcOrError) + return std::move(funcOrError.takeError()); + + // Prepare LLVM infrastructure for JIT compilation + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerLLVMDialectTranslation(mlirContext); + + std::function optPipeline = + mlir::makeOptimizingTransformer(3, 0, nullptr); + + llvm::Expected> lambdaOrErr = + mlir::zamalang::JITLambda::create(funcName, module, optPipeline); + + if (!lambdaOrErr) + return std::move(lambdaOrErr.takeError()); + + return Lambda{this->compilationContext, std::move(lambdaOrErr.get()), + std::move(compResOrErr->keySet)}; +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Support/LambdaArgument.cpp b/compiler/lib/Support/LambdaArgument.cpp new file mode 100644 index 000000000..a693c0177 --- /dev/null +++ b/compiler/lib/Support/LambdaArgument.cpp @@ -0,0 +1,7 @@ +#include + +namespace mlir { +namespace zamalang { +char LambdaArgument::ID = 0; +} // namespace zamalang +} // namespace mlir diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index afc0abf6a..1e81f2803 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -22,15 +23,15 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" -#include "zamalang/Support/Jit.h" +#include "zamalang/Support/Error.h" +#include "zamalang/Support/JitCompilerEngine.h" #include "zamalang/Support/KeySet.h" #include "zamalang/Support/Pipeline.h" #include "zamalang/Support/logging.h" -enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM }; - enum Action { ROUND_TRIP, + DUMP_HLFHE, DUMP_HLFHE_MANP, DUMP_MIDLFHE, DUMP_LOWLFHE, @@ -80,26 +81,6 @@ llvm::cl::opt parametrizeMidLFHE( llvm::cl::desc("Perform MidLFHE global parametrization pass"), llvm::cl::init(true)); -static llvm::cl::opt entryDialect( - "e", "entry-dialect", llvm::cl::desc("Entry dialect"), - llvm::cl::init(EntryDialect::HLFHE), - llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, - llvm::cl::values( - clEnumValN(EntryDialect::HLFHE, "hlfhe", - "Input module is composed of HLFHE operations")), - llvm::cl::values( - clEnumValN(EntryDialect::MIDLFHE, "midlfhe", - "Input module is composed of MidLFHE operations")), - llvm::cl::values( - clEnumValN(EntryDialect::LOWLFHE, "lowlfhe", - "Input module is composed of LowLFHE operations")), - llvm::cl::values( - clEnumValN(EntryDialect::STD, "std", - "Input module is composed of operations from std")), - llvm::cl::values( - clEnumValN(EntryDialect::LLVM, "llvm", - "Input module is composed of operations from llvm"))); - static llvm::cl::opt action( "a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, @@ -109,6 +90,8 @@ static llvm::cl::opt action( llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp", "Dump HLFHE module after running the Minimal " "Arithmetic Noise Padding pass")), + llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe", + "Dump HLFHE module")), llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", "Lower to MidLFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe", @@ -158,50 +141,7 @@ llvm::cl::opt, false, OptionalSizeTParser> assumeMaxMANP( llvm::cl::desc( "Assume a maximum for the Minimum Arithmetic Noise Padding")); -}; // namespace cmdline - -std::function defaultOptPipeline = - mlir::makeOptimizingTransformer(3, 0, nullptr); - -std::unique_ptr -generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, - const std::string &jitFuncName) { - std::unique_ptr keySet; - - mlir::zamalang::log_verbose() - << "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2 - << ", p:" << fheContext.constraint.p << "}\n"; - mlir::zamalang::log_verbose() - << "### FHE parameters for the atomic pattern: {k: " - << fheContext.parameter.k - << ", polynomialSize: " << fheContext.parameter.polynomialSize - << ", nSmall: " << fheContext.parameter.nSmall - << ", brLevel: " << fheContext.parameter.brLevel - << ", brLogBase: " << fheContext.parameter.brLogBase - << ", ksLevel: " << fheContext.parameter.ksLevel - << ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n"; - - // Create the client parameters - auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheContext, jitFuncName, module); - - if (auto err = clientParameter.takeError()) { - mlir::zamalang::log_error() - << "cannot generate client parameters: " << err << "\n"; - return nullptr; - } - - mlir::zamalang::log_verbose() << "### Generate the key set\n"; - - auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, - 0); // TODO: seed - if (auto err = maybeKeySet.takeError()) { - llvm::errs() << err; - return nullptr; - } - - return std::move(maybeKeySet.get()); -} +} // namespace cmdline llvm::Expected buildFHEContext( llvm::Optional autoFHEConstraints, @@ -209,65 +149,48 @@ llvm::Expected buildFHEContext( llvm::Optional overrideMaxMANP) { if (!autoFHEConstraints.hasValue() && (!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) { - return llvm::make_error( + return mlir::zamalang::StreamStringError( "Maximum encrypted integer precision and maximum for the Minimal" "Arithmetic Noise Passing are required, but were neither specified" - "explicitly nor determined automatically", - llvm::inconvertibleErrorCode()); + "explicitly nor determined automatically"); } mlir::zamalang::V0FHEConstraint fheConstraints{ - .norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() - : autoFHEConstraints.getValue().norm2, - .p = overrideMaxEintPrecision.hasValue() - ? overrideMaxEintPrecision.getValue() - : autoFHEConstraints.getValue().p}; + overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() + : autoFHEConstraints.getValue().norm2, + overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue() + : autoFHEConstraints.getValue().p}; const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); if (!parameter) { - std::string buffer; - llvm::raw_string_ostream strs(buffer); - strs << "Could not determine V0 parameters for 2-norm of " - << fheConstraints.norm2 << " and p of " << fheConstraints.p; - - return llvm::make_error(strs.str(), - llvm::inconvertibleErrorCode()); + return mlir::zamalang::StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << fheConstraints.norm2 << " and p of " << fheConstraints.p; } return mlir::zamalang::V0FHEContext{fheConstraints, *parameter}; } -mlir::LogicalResult buildAssignFHEContext( - llvm::Optional &fheContext, - llvm::Optional autoFHEConstraints, - llvm::Optional overrideMaxEintPrecision, - llvm::Optional overrideMaxMANP) { +namespace llvm { +// This needs to be wrapped into the llvm namespace for proper +// operator lookup +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const llvm::ArrayRef arr) { + os << "("; + for (size_t i = 0; i < arr.size(); i++) { + os << arr[i]; - if (fheContext.hasValue()) - return mlir::success(); - - llvm::Expected fheContextOrErr = - buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision, - overrideMaxMANP); - - if (auto err = fheContextOrErr.takeError()) { - mlir::zamalang::log_error() << err; - return mlir::failure(); + if (i != arr.size() - 1) + os << ", "; } - fheContext.emplace(fheContextOrErr.get()); - - return mlir::success(); + return os; } +} // namespace llvm // Process a single source buffer // -// The parameter `entryDialect` must specify the FHE dialect to which -// belong all FHE operations used in the source buffer. The input -// program must only contain FHE operations from that single FHE -// dialect, otherwise processing might fail. -// // The parameter `action` specifies how the buffer should be processed // and thus defines the output. // @@ -276,15 +199,14 @@ mlir::LogicalResult buildAssignFHEContext( // using the parameters given in `jitArgs`. // // The parameter `parametrizeMidLFHE` defines, whether the -// parametrization pass for MidLFHE is executed. If the pair of -// `entryDialect` and `action` does not involve any MidlFHE -// manipulation, this parameter does not have any effect. +// parametrization pass for MidLFHE is executed. If the `action` does +// not involve any MidlFHE manipulation, this parameter does not have +// any effect. // // The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if // set, override the values for the maximum required precision of // encrypted integers and the maximum value for the Minimum Arithmetic -// Noise Padding otherwise determined automatically if the entry -// dialect is HLFHE.. +// Noise Padding otherwise determined automatically. // // If `verifyDiagnostics` is `true`, the procedure only checks if the // diagnostic messages provided in the source buffer using @@ -292,164 +214,103 @@ mlir::LogicalResult buildAssignFHEContext( // the procedure checks if the parsed module is valid and if all // requested transformations succeeded. // -// If `verbose` is true, debug messages are displayed throughout the -// compilation process. -// // Compilation output is written to the stream specified by `os`. -mlir::LogicalResult processInputBuffer( - mlir::MLIRContext &context, std::unique_ptr buffer, - enum EntryDialect entryDialect, enum Action action, - const std::string &jitFuncName, llvm::ArrayRef jitArgs, - bool parametrizeMidlHFE, llvm::Optional overrideMaxEintPrecision, - llvm::Optional overrideMaxMANP, bool verifyDiagnostics, - bool verbose, llvm::raw_ostream &os) { - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); +mlir::LogicalResult +processInputBuffer(std::unique_ptr buffer, + enum Action action, const std::string &jitFuncName, + llvm::ArrayRef jitArgs, bool parametrizeMidlHFE, + llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP, + bool verifyDiagnostics, llvm::raw_ostream &os) { + std::shared_ptr ccx = + mlir::zamalang::CompilationContext::createShared(); - mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, - &context); - mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); + mlir::zamalang::JitCompilerEngine ce{ccx}; - llvm::Optional fheConstraints; - llvm::Optional fheContext; + ce.setVerifyDiagnostics(verifyDiagnostics); + ce.setParametrizeMidLFHE(parametrizeMidlHFE); - std::unique_ptr keySet = nullptr; + if (overrideMaxEintPrecision.hasValue()) + ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue()); - if (verbose) - context.disableMultithreading(); + if (overrideMaxMANP.hasValue()) + ce.setMaxMANP(overrideMaxMANP.getValue()); - if (verifyDiagnostics) - return sourceMgrHandler.verify(); + if (action == Action::JIT_INVOKE) { + llvm::Expected lambdaOrErr = + ce.buildLambda(std::move(buffer), jitFuncName); - if (!moduleRef) - return mlir::failure(); - - mlir::ModuleOp module = moduleRef.get(); - - if (action == Action::ROUND_TRIP) { - module->print(os); - return mlir::success(); - } - - // Lowering pipeline. Each stage is represented as a label in the - // switch statement, from the most abstract dialect to the lowest - // level. Every labels acts as an entry point into the pipeline with - // a fallthrough mechanism to the next stage. Actions act as exit - // points from the pipeline. - switch (entryDialect) { - case EntryDialect::HLFHE: - if (action == Action::DUMP_HLFHE_MANP) { - if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) - .failed()) { - return mlir::failure(); - } - - module.print(os); - return mlir::success(); - } else { - llvm::Expected> - fheConstraintsOrErr = - mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context, - module); - if (auto err = fheConstraintsOrErr.takeError()) { - mlir::zamalang::log_error() << err; - return mlir::failure(); - } else { - fheConstraints = fheConstraintsOrErr.get(); - } - } - - if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) - .failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::MIDLFHE: - if (action == Action::DUMP_MIDLFHE) { - module.print(os); - return mlir::success(); - } - - if (buildAssignFHEContext(fheContext, fheConstraints, - overrideMaxEintPrecision, overrideMaxMANP) - .failed()) { - return mlir::failure(); - } - - if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( - context, module, fheContext.getValue(), parametrizeMidlHFE) - .failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::LOWLFHE: - if (action == Action::DUMP_LOWLFHE) { - module.print(os); - return mlir::success(); - } - - if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::STD: - if (action == Action::DUMP_STD) { - module.print(os); - return mlir::success(); - } else if (action == Action::JIT_INVOKE) { - if (buildAssignFHEContext(fheContext, fheConstraints, - overrideMaxEintPrecision, overrideMaxMANP) - .failed()) { - return mlir::failure(); - } - - keySet = generateKeySet(module, fheContext.getValue(), jitFuncName); - } - - if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, - verbose) - .failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::LLVM: { - if (action == Action::DUMP_LLVM_DIALECT) { - module.print(os); - return mlir::success(); - } else if (action == Action::JIT_INVOKE) { - return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet, - defaultOptPipeline, os); - } - - llvm::LLVMContext llvmContext; - std::unique_ptr llvmModule = - mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext, - module); - - if (!llvmModule) { + if (!lambdaOrErr) { mlir::zamalang::log_error() - << "Failed to translate LLVM dialect to LLVM IR\n"; + << "Failed to JIT-compile " << jitFuncName << ": " + << llvm::toString(std::move(lambdaOrErr.takeError())); return mlir::failure(); } - if (action == Action::DUMP_LLVM_IR) { - llvmModule->dump(); - return mlir::success(); - } + llvm::Expected resOrErr = (*lambdaOrErr)(jitArgs); - if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule) - .failed()) { - mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n"; + if (!resOrErr) { + mlir::zamalang::log_error() + << "Failed to JIT-invoke " << jitFuncName << " with arguments " + << jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError())); return mlir::failure(); } - if (action == Action::DUMP_OPTIMIZED_LLVM_IR) { - llvmModule->dump(); - return mlir::success(); + os << *resOrErr << "\n"; + } else { + enum mlir::zamalang::CompilerEngine::Target target; + + switch (action) { + case Action::ROUND_TRIP: + target = mlir::zamalang::CompilerEngine::Target::ROUND_TRIP; + break; + case Action::DUMP_HLFHE: + target = mlir::zamalang::CompilerEngine::Target::HLFHE; + break; + case Action::DUMP_HLFHE_MANP: + target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP; + break; + case Action::DUMP_MIDLFHE: + target = mlir::zamalang::CompilerEngine::Target::MIDLFHE; + break; + case Action::DUMP_LOWLFHE: + target = mlir::zamalang::CompilerEngine::Target::LOWLFHE; + break; + case Action::DUMP_STD: + target = mlir::zamalang::CompilerEngine::Target::STD; + break; + case Action::DUMP_LLVM_DIALECT: + target = mlir::zamalang::CompilerEngine::Target::LLVM; + break; + case Action::DUMP_LLVM_IR: + target = mlir::zamalang::CompilerEngine::Target::LLVM_IR; + break; + case Action::DUMP_OPTIMIZED_LLVM_IR: + target = mlir::zamalang::CompilerEngine::Target::OPTIMIZED_LLVM_IR; + break; + case JIT_INVOKE: + // Case just here to satisfy the compiler; already handled above + break; } - break; - } + llvm::Expected retOrErr = + ce.compile(std::move(buffer), target); + + if (!retOrErr) { + mlir::zamalang::log_error() + << llvm::toString(std::move(retOrErr.takeError())) << "\n"; + + return mlir::failure(); + } + + if (verifyDiagnostics) { + return mlir::success(); + } else if (action == Action::DUMP_LLVM_IR || + action == Action::DUMP_OPTIMIZED_LLVM_IR) { + retOrErr->llvmModule->print(os, nullptr); + } else { + retOrErr->mlirModuleRef->get().print(os); + } } return mlir::success(); @@ -459,44 +320,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { // Parse command line arguments llvm::cl::ParseCommandLineOptions(argc, argv); - // Initialize the MLIR context - mlir::MLIRContext context; - mlir::zamalang::setupLogging(cmdline::verbose); // String for error messages from library functions std::string errorMessage; - if (cmdline::action == Action::DUMP_HLFHE_MANP && - cmdline::entryDialect != EntryDialect::HLFHE) { - mlir::zamalang::log_error() - << "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs"; - return mlir::failure(); - } - - if (cmdline::action == Action::JIT_INVOKE && - cmdline::entryDialect != EntryDialect::HLFHE && - cmdline::entryDialect != EntryDialect::MIDLFHE && - cmdline::entryDialect != EntryDialect::LOWLFHE && - cmdline::entryDialect != EntryDialect::STD) { - mlir::zamalang::log_error() - << "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs"; - return mlir::failure(); - } - - // Load our Dialect in this MLIR Context. - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - - if (cmdline::verifyDiagnostics) - context.printOpOnDiagnostic(false); - auto output = mlir::openOutputFile(cmdline::output, &errorMessage); if (!output) { @@ -523,20 +351,20 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { return processInputBuffer( - context, std::move(inputBuffer), cmdline::entryDialect, - cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, + std::move(inputBuffer), cmdline::action, + cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, cmdline::verbose, os); + cmdline::verifyDiagnostics, os); }, output->os()))) return mlir::failure(); } else { return processInputBuffer( - context, std::move(file), cmdline::entryDialect, cmdline::action, - cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, + std::move(file), cmdline::action, cmdline::jitFuncName, + cmdline::jitArgs, cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, cmdline::verbose, output->os()); + cmdline::verifyDiagnostics, output->os()); } } diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir index bf08e2e90..fc460a435 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir index 49b8063a3..224270914 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir index 2b7a9a761..846572c00 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir index 23a7b3c28..9163d28f0 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir index c98974a2a..6947a3942 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s // CHECK: #map0 = affine_map<(d0) -> (d0)> // CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir index c83bac0ba..ff156d615 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir index a34343f21..f0da29950 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 23d8585ae..94892649d 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index a40f5df5a..13a8c7214 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 93a60d527..0e6ff2534 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index a3ec7b838..497ce0cd8 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index 42b77bcf4..a0c63723d 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 20ae4d5f0..5f917cd3b 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index 201a53669..e8bc3ad06 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index 24fc2ff96..17a86d8d5 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index 1359aaa1e..f352e0f99 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir index 302a75e0e..acbb19672 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s +// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s func @single_zero() -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir index 0e53d8d5a..939c52c71 100644 --- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir +++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // Incompatible shapes func @dot_incompatible_shapes( diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir index 6a6d4f962..800e8df69 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<8>) { diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir index bb43f441c..7e543efaa 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<0>) { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir index 1bb62e224..39b97ed31 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir index d43bc7194..5608ffdd0 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir index 205e7afe1..9e91eb794 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir index 0a8ae9a8c..79d014609 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir index 0a8d9cd48..d05921ccf 100644 --- a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument. func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir index 6a9e6e059..45b847b8e 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir index ee84b2a49..50f288ba1 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir index deded0859..5bf4c57af 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1 func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir index 207414189..3aa512584 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index f7653afb7..44fefdf8e 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @zero() -> !HLFHE.eint<2> func @zero() -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir index 1130181ea..573d03ac9 100644 --- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s //CHECK: #map0 = affine_map<(d0) -> (d0)> //CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Dialect/HLFHE/types.mlir b/compiler/tests/Dialect/HLFHE/types.mlir index 8e6b6bc85..2a9ad463c 100644 --- a/compiler/tests/Dialect/HLFHE/types.mlir +++ b/compiler/tests/Dialect/HLFHE/types.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>> func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) { diff --git a/compiler/tests/Dialect/LowLFHE/ops.mlir b/compiler/tests/Dialect/LowLFHE/ops.mlir index b909b0473..fc80bebb5 100644 --- a/compiler/tests/Dialect/LowLFHE/ops.mlir +++ b/compiler/tests/Dialect/LowLFHE/ops.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> { diff --git a/compiler/tests/Dialect/LowLFHE/types.mlir b/compiler/tests/Dialect/LowLFHE/types.mlir index 27552cb2a..07cf87134 100644 --- a/compiler/tests/Dialect/LowLFHE/types.mlir +++ b/compiler/tests/Dialect/LowLFHE/types.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir index 55df41028..46318b97a 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter result func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir index 3d1f81407..970c22942 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir index 97ead991e..cf743fb65 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir index ba6a37313..3b7d1d510 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir index ee8a78f6d..86bb8e44b 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // Bad dimension of the lookup table func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir index bfb720643..e42502600 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir index f21873208..4c3a2d238 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir index ae9daa983..3c4a01957 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir index 0903aeb00..debf54b4d 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir index 47fab99a6..47b850daa 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir index b66236c76..974fb3cc6 100644 --- a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=roundtrip 2>&1| FileCheck %s // CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 6e8bfb33a..54d7cdac3 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -56,7 +56,7 @@ def test_compile_and_run(mlir_input, args, expected_result): def test_compile_and_run_invalid_arg_number(mlir_input, args): engine = CompilerEngine() engine.compile_fhe(mlir_input) - with pytest.raises(RuntimeError, match=r"failed pushing integer argument"): + with pytest.raises(ValueError, match=r"wrong number of arguments"): engine.run(*args) diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index aa7cbd7fb..91365ec64 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -1,8 +1,11 @@ +#include #include +#include #include "zamalang/Support/CompilerEngine.h" +#include "zamalang/Support/JitCompilerEngine.h" -mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7}; +mlir::zamalang::V0FHEConstraint defaultV0Constraints = {10, 7}; #define ASSERT_LLVM_ERROR(err) \ if (err) { \ @@ -10,384 +13,405 @@ mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7}; ASSERT_TRUE(false); \ } +// Checks that the value `val` is not in an error state. Returns +// `true` if the test passes, otherwise `false`. +template +static bool assert_expected_success(llvm::Expected &val) { + if (!((bool)val)) { + llvm::errs() << llvm::toString(std::move(val.takeError())); + return false; + } + + return true; +} + +// Checks that the value `val` is not in an error state. Returns +// `true` if the test passes, otherwise `false`. +template +static bool assert_expected_success(llvm::Expected &&val) { + return assert_expected_success(val); +} + +// Checks that the value `val` of type `llvm::Expected` is not in +// an error state. +#define ASSERT_EXPECTED_SUCCESS(val) \ + do { \ + if (!assert_expected_success(val)) \ + GTEST_FATAL_FAILURE_("Expected contained in error state"); \ + } while (0) + +// Checks that the value `val` is not in an error state and is equal +// to the value given in `exp`. Returns `true` if the test passes, +// otherwise `false`. +template +static bool assert_expected_value(llvm::Expected &val, const V &exp) { + if (!assert_expected_success(val)) + return false; + + if (!(val.get() == static_cast(exp))) { + llvm::errs() << "Expected value " << exp << ", but got " << val.get() + << "\n"; + return false; + } + + return true; +} + +// Checks that the value `val` is not in an error state and is equal +// to the value given in `exp`. Returns `true` if the test passes, +// otherwise `false`. +template +static bool assert_expected_value(llvm::Expected &&val, const V &exp) { + return assert_expected_value(val, exp); +} + +// Checks that the value `val` of type `llvm::Expected` is not in +// an error state and is equal to the value of type `T` given in +// `exp`. +#define ASSERT_EXPECTED_VALUE(val, exp) \ + do { \ + if (!assert_expected_value(val, exp)) { \ + GTEST_FATAL_FAILURE_("Expected with wrong value"); \ + } \ + } while (0) + +// Jit-compiles the function specified by `func` from `src` and +// returns the corresponding lambda. Any compilation errors are caught +// and reult in abnormal termination. +template +mlir::zamalang::JitCompilerEngine::Lambda +internalCheckedJit(F checkfunc, llvm::StringRef src, + llvm::StringRef func = "main", + bool useDefaultFHEConstraints = false) { + mlir::zamalang::JitCompilerEngine engine; + + if (useDefaultFHEConstraints) + engine.setFHEConstraints(defaultV0Constraints); + + llvm::Expected lambdaOrErr = + engine.buildLambda(src, func); + + checkfunc(lambdaOrErr); + + return std::move(*lambdaOrErr); +} + +// Shorthands to create integer literals of a specific type +uint8_t operator"" _u8(unsigned long long int v) { return v; } +uint16_t operator"" _u16(unsigned long long int v) { return v; } +uint32_t operator"" _u32(unsigned long long int v) { return v; } +uint64_t operator"" _u64(unsigned long long int v) { return v; } + +// Evaluates to the number of elements of a statically initialized +// array +#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0])) + +// Wrapper around `internalCheckedJit` that causes +// `ASSERT_EXPECTED_SUCCESS` to use the file and line number of the +// caller instead of `internalCheckedJit`. +#define checkedJit(...) \ + internalCheckedJit( \ + [](llvm::Expected &lambda) { \ + ASSERT_EXPECTED_SUCCESS(lambda); \ + }, \ + __VA_ARGS__) + TEST(CompileAndRunHLFHE, add_eint) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> } -)XXX"; - ASSERT_FALSE(engine.compile(mlirStr)); - auto maybeResult = engine.run({1, 2}); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - ASSERT_EQ(result, 3); +)XXX"); + + ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), 3); + ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), 9); + ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), 2); +} + +// Same as CompileAndRunHLFHE::add_eint above, but using +// `LambdaArgument` instances +TEST(CompileAndRunHLFHE, add_eint_lambda_argument) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> +} +)XXX"); + + mlir::zamalang::IntLambdaArgument<> ila1(1); + mlir::zamalang::IntLambdaArgument<> ila2(2); + mlir::zamalang::IntLambdaArgument<> ila7(7); + mlir::zamalang::IntLambdaArgument<> ila9(9); + + ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila2}), 3); + ASSERT_EXPECTED_VALUE(lambda({&ila7, &ila9}), 16); + ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila7}), 8); + ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila9}), 10); + ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9); +} + +TEST(CompileAndRunHLFHE, add_u64) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%arg0: i64, %arg1: i64) -> i64 { + %1 = addi %arg0, %arg1 : i64 + return %1: i64 +} +)XXX", + "main", true); + + ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), (uint64_t)3); + ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), (uint64_t)9); + ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), (uint64_t)2); } TEST(CompileAndRunTensorStd, extract_64) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi64>, %i: index) -> i64{ %c = tensor.extract %t[%i] : tensor<10xi64> return %c : i64 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF, - 0, - 8978, - 2587490, - 90, - 197864, - 698735, - 72132, - 87474, - 42}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint64_t t_arg[] = {0xFFFFFFFFFFFFFFFF, + 0, + 8978, + 2587490, + 90, + 197864, + 698735, + 72132, + 87474, + 42}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_32) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi32>, %i: index) -> i32{ %c = tensor.extract %t[%i] : tensor<10xi32> return %c : i32 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90, - 197864, 698735, 72132, 87474, 42}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); +)XXX", + "main", "true"); + static uint32_t t_arg[] = {0xFFFFFFFF, 0, 8978, 2587490, 90, + 197864, 698735, 72132, 87474, 42}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); +} + +// Same as `CompileAndRunTensorStd::extract_32` above, but using +// `LambdaArgument` instances +TEST(CompileAndRunTensorStd, extract_32_lambda_argument) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<10xi32>, %i: index) -> i32{ + %c = tensor.extract %t[%i] : tensor<10xi32> + return %c : i32 +} +)XXX", + "main", "true"); + static std::vector t_arg{0xFFFFFFFF, 0, 8978, 2587490, 90, + 197864, 698735, 72132, 87474, 42}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + tla(t_arg); + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) { + mlir::zamalang::IntLambdaArgument idx(i); + ASSERT_EXPECTED_VALUE(lambda({&tla, &idx}), t_arg[i]); } } TEST(CompileAndRunTensorStd, extract_16) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi16>, %i: index) -> i16{ %c = tensor.extract %t[%i] : tensor<10xi16> return %c : i16 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227, - 63269, 36435, 52380, 7401, 13313}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + uint16_t t_arg[] = {0xFFFF, 0, 59589, 47826, 16227, + 63269, 36435, 52380, 7401, 13313}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_8) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi8>, %i: index) -> i8{ %c = tensor.extract %t[%i] : tensor<10xi8> return %c : i8 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint8_t t_arg[] = {0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi5>, %i: index) -> i5{ %c = tensor.extract %t[%i] : tensor<10xi5> return %c : i5 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_1) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi1>, %i: index) -> i1{ %c = tensor.extract %t[%i] : tensor<10xi1> return %c : i1 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint8_t t_arg[] = {0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorEncrypted, extract_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{ %c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> return %c : !HLFHE.eint<5> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX"); + + static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( -func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{ + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> +!HLFHE.eint<5>{ %ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> %tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>> - %c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5> - return %c : !HLFHE.eint<5> + %c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> + !HLFHE.eint<5> return %c : !HLFHE.eint<5> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - for (size_t i = 0; i < size; i++) { - for (size_t j = 0; j < size; j++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Set the %j argument - ASSERT_LLVM_ERROR(argument->setArg(2, j)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i] + t_arg[j]); - } - } +)XXX"); + + static uint8_t t_arg[] = {3, 0, 7, 12, 14, 6, 5, 4, 1, 2}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + for (size_t j = 0; j < ARRAY_SIZE(t_arg); j++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i, j), + t_arg[i] + t_arg[j]); } TEST(CompileAndRunTensorEncrypted, dim_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{ %c0 = constant 0 : index %c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>> return %c : index } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, size); +)XXX"); + + static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg)), ARRAY_SIZE(t_arg)); } TEST(CompileAndRunTensorEncrypted, from_elements_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> { %t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>> return %t: tensor<1x!HLFHE.eint<5>> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, 10)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - size_t size_res = 1; - uint64_t t_res[size_res]; - ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res)); - ASSERT_EQ(t_res[0], 10); +)XXX"); + + llvm::Expected> res = + lambda.operator()>(10_u64); + + ASSERT_EXPECTED_SUCCESS(res); + ASSERT_EQ(res->size(), (size_t)1); + ASSERT_EQ(res->at(0), 10_u64); } TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> { %c_0 = constant 0 : index %c_1 = constant 1 : index %a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>> %b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>> - %aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) - %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) - %bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) - %out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>> + %aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> + (!HLFHE.eint<5>) %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, + !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %bplusb = "HLFHE.add_eint"(%b, %b): + (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %out = + tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>> return %out: tensor<3x!HLFHE.eint<5>> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the argument - const size_t in_size = 2; - uint8_t in[in_size] = {2, 16}; - ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - const size_t size_res = 3; - uint64_t t_res[size_res]; - ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res)); - ASSERT_EQ(t_res[0], in[0] + in[0]); - ASSERT_EQ(t_res[1], in[0] + in[1]); - ASSERT_EQ(t_res[2], in[1] + in[1]); +)XXX"); + + static uint8_t in[] = {2, 16}; + + llvm::Expected> res = + lambda.operator()>(in, ARRAY_SIZE(in)); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (size_t)3); + ASSERT_EQ(res->at(0), (uint64_t)(in[0] + in[0])); + ASSERT_EQ(res->at(1), (uint64_t)(in[0] + in[1])); + ASSERT_EQ(res->at(2), (uint64_t)(in[1] + in[1])); } TEST(CompileAndRunTensorEncrypted, linalg_generic) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> (0)> -func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> { +func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: +!HLFHE.eint<7>) -> !HLFHE.eint<7> { %tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>> - %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) { - ^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors - %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7> - %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7> - linalg.yield %5 : !HLFHE.eint<7> + %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types + = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) + outs(%tacc : tensor<1x!HLFHE.eint<7>>) { ^bb0(%arg2: !HLFHE.eint<7>, %arg3: + i8, %arg4: !HLFHE.eint<7>): // no predecessors + %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> + !HLFHE.eint<7> %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, + !HLFHE.eint<7>) -> !HLFHE.eint<7> linalg.yield %5 : !HLFHE.eint<7> } -> tensor<1x!HLFHE.eint<7>> %c0 = constant 0 : index %ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>> return %ret : !HLFHE.eint<7> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set arg0, arg1, acc - const size_t in_size = 2; - uint8_t arg0[in_size] = {2, 8}; - ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size)); - uint8_t arg1[in_size] = {6, 8}; - ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size)); - ASSERT_LLVM_ERROR(argument->setArg(2, 0)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, 76); +)XXX", + "main", "true"); + + static uint8_t arg0[] = {2, 8}; + static uint8_t arg1[] = {6, 8}; + + llvm::Expected res = + lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1), 0_u64); + + ASSERT_EXPECTED_VALUE(res, 76); } TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7> { @@ -395,77 +419,70 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>, (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7> return %ret : !HLFHE.eint<7> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set arg0, arg1, acc - const size_t in_size = 4; - uint8_t arg0[in_size] = {0, 1, 2, 3}; - ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size)); - uint8_t arg1[in_size] = {0, 1, 2, 3}; - ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, 14); +)XXX"); + static uint8_t arg0[] = {0, 1, 2, 3}; + static uint8_t arg1[] = {0, 1, 2, 3}; + + llvm::Expected res = + lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1)); + + ASSERT_EXPECTED_VALUE(res, 14); } -class CompileAndRunWithPrecision : public ::testing::TestWithParam { -protected: - mlir::zamalang::CompilerEngine engine; - void compile(std::string mlirStr) { ASSERT_FALSE(engine.compile(mlirStr)); } - void run(std::vector args, uint64_t expected) { - auto maybeResult = engine.run(args); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - if (result == expected) { - ASSERT_TRUE(true); - } else { - // TODO: Better way to test the probability of exactness - llvm::errs() << "one fail retry\n"; - maybeResult = engine.run(args); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, expected); - } - } -}; +class CompileAndRunWithPrecision : public ::testing::TestWithParam {}; TEST_P(CompileAndRunWithPrecision, identity_func) { - int precision = GetParam(); + uint64_t precision = GetParam(); std::ostringstream mlirProgram; - auto sizeOfTLU = 1 << precision; - mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision - << ">) -> !HLFHE.eint<" << precision << "> { \n"; - mlirProgram << " %tlu = std.constant dense<[0"; - for (auto i = 1; i < sizeOfTLU; i++) { - mlirProgram << ", " << i; - } - mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n"; - mlirProgram << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): " - "(!HLFHE.eint<" - << precision << ">, tensor<" << sizeOfTLU - << "xi64>) -> (!HLFHE.eint<" << precision << ">)\n "; - mlirProgram << "return %1: !HLFHE.eint<" << precision << ">\n"; + uint64_t sizeOfTLU = 1 << precision; - mlirProgram << "}\n"; - llvm::errs() << mlirProgram.str(); - compile(mlirProgram.str()); - for (auto i = 0; i < sizeOfTLU; i++) { - run({(uint64_t)i}, i); + mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision + << ">) -> !HLFHE.eint<" << precision << "> { \n" + << " %tlu = std.constant dense<[0"; + + for (uint64_t i = 1; i < sizeOfTLU; i++) + mlirProgram << ", " << i; + + mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n" + << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): " + << "(!HLFHE.eint<" << precision << ">, tensor<" << sizeOfTLU + << "xi64>) -> (!HLFHE.eint<" << precision << ">)\n " + << "return %1: !HLFHE.eint<" << precision << ">\n" + << "}\n"; + + mlir::zamalang::JitCompilerEngine::Lambda lambda = + checkedJit(mlirProgram.str()); + + if (precision == 7) { + // Test fails with a probability of 5% for a precision of 7. The + // probability of the test failing 5 times in a row is .05^5, + // which is less than 1:10,000 and comparable to the probability + // of failure for the other values. + static const int max_tries = 3; + + for (uint64_t i = 0; i < sizeOfTLU; i++) { + for (int retry = 0; retry <= max_tries; retry++) { + if (retry == max_tries) + GTEST_FATAL_FAILURE_("Maximum number of tries exceeded"); + + llvm::Expected val = lambda(i); + ASSERT_EXPECTED_SUCCESS(val); + + if (*val == i) + break; + } + } + } else { + for (uint64_t i = 0; i < sizeOfTLU; i++) + ASSERT_EXPECTED_VALUE(lambda(i), i); } } -INSTANTIATE_TEST_CASE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision, - ::testing::Values(1, 2, 3, 4, 5, 6, 7)); +INSTANTIATE_TEST_SUITE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision, + ::testing::Values(1, 2, 3, 4, 5, 6, 7)); TEST(TestHLFHEApplyLookupTable, multiple_precision) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> { %tlu_7 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64> %tlu_3 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> @@ -474,45 +491,22 @@ func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> { %a_plus_b = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<6>, !HLFHE.eint<6>) -> (!HLFHE.eint<6>) return %a_plus_b: !HLFHE.eint<6> } -)XXX"; - ASSERT_FALSE(engine.compile(mlirStr)); - uint64_t arg0 = 23; - uint64_t arg1 = 7; - uint64_t expected = 30; - auto maybeResult = engine.run({arg0, arg1}); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - ASSERT_EQ(result, expected); +)XXX"); + + ASSERT_EXPECTED_VALUE(lambda(23_u64, 7_u64), 30); } TEST(CompileAndRunTLU, random_func) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: !HLFHE.eint<6>) -> !HLFHE.eint<6> { %tlu = std.constant dense<[16, 91, 16, 83, 80, 74, 21, 96, 1, 63, 49, 122, 76, 89, 74, 55, 109, 110, 103, 54, 105, 14, 66, 47, 52, 89, 7, 10, 73, 44, 119, 92, 25, 104, 123, 100, 108, 86, 29, 121, 118, 52, 107, 48, 34, 37, 13, 122, 107, 48, 74, 59, 96, 36, 50, 55, 120, 72, 27, 45, 12, 5, 96, 12]> : tensor<64xi64> %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<6>, tensor<64xi64>) -> (!HLFHE.eint<6>) return %1: !HLFHE.eint<6> } -)XXX"; - ASSERT_FALSE(engine.compile(mlirStr)); - // first value - auto maybeResult = engine.run({5}); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - ASSERT_EQ(result, 74); - // second value - maybeResult = engine.run({62}); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, 96); - // edge value low - maybeResult = engine.run({0}); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, 16); - // edge value high - maybeResult = engine.run({63}); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, 12); +)XXX"); + + ASSERT_EXPECTED_VALUE(lambda(5_u64), 74); + ASSERT_EXPECTED_VALUE(lambda(62_u64), 96); + ASSERT_EXPECTED_VALUE(lambda(0_u64), 16); + ASSERT_EXPECTED_VALUE(lambda(63_u64), 12); } From 41cba6311398e64b17aabe677e66be96df953d53 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 22 Oct 2021 09:41:23 +0200 Subject: [PATCH 17/19] refactor(compiler): Move the keyset generation from CompilerEngine to JitCompilerEngine --- .../include/zamalang/Support/CompilerEngine.h | 8 ++------ compiler/lib/Support/CompilerEngine.cpp | 18 ------------------ compiler/lib/Support/JitCompilerEngine.cpp | 17 +++++++++++++++-- 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index e9c496b1a..a2eb2ed28 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -8,7 +8,6 @@ #include #include #include -#include namespace mlir { namespace zamalang { @@ -43,7 +42,6 @@ public: llvm::Optional mlirModuleRef; llvm::Optional clientParameters; - std::unique_ptr keySet; std::unique_ptr llvmModule; llvm::Optional fheContext; @@ -93,8 +91,8 @@ public: CompilerEngine(std::shared_ptr compilationContext) : overrideMaxEintPrecision(), overrideMaxMANP(), clientParametersFuncName(), verifyDiagnostics(false), - generateKeySet(false), generateClientParameters(false), - parametrizeMidLFHE(true), compilationContext(compilationContext) {} + generateClientParameters(false), parametrizeMidLFHE(true), + compilationContext(compilationContext) {} llvm::Expected compile(llvm::StringRef s, Target target); @@ -107,7 +105,6 @@ public: void setMaxEintPrecision(size_t v); void setMaxMANP(size_t v); void setVerifyDiagnostics(bool v); - void setGenerateKeySet(bool v); void setGenerateClientParameters(bool v); void setParametrizeMidLFHE(bool v); void setClientParametersFuncName(const llvm::StringRef &name); @@ -117,7 +114,6 @@ protected: llvm::Optional overrideMaxMANP; llvm::Optional clientParametersFuncName; bool verifyDiagnostics; - bool generateKeySet; bool generateClientParameters; bool parametrizeMidLFHE; diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 50d98b721..6224569a5 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -74,8 +74,6 @@ void CompilerEngine::setVerifyDiagnostics(bool v) { this->verifyDiagnostics = v; } -void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; } - void CompilerEngine::setGenerateClientParameters(bool v) { this->generateClientParameters = v; } @@ -349,22 +347,6 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { res.clientParameters = clientParametersOrErr.get(); } - // Generate Key set if requested - if (this->generateKeySet) { - if (!res.clientParameters.hasValue()) { - return StreamStringError("Generation of keyset requested without request " - "for generation of client parameters"); - } - - llvm::Expected> keySetOrErr = - mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0); - - if (auto err = keySetOrErr.takeError()) - return std::move(err); - - res.keySet = std::move(*keySetOrErr); - } - // MLIR canonical dialects -> LLVM Dialect if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module, false) diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index 05359ac4f..95626182d 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -64,7 +64,6 @@ llvm::Expected JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); - this->setGenerateKeySet(true); this->setGenerateClientParameters(true); this->setClientParametersFuncName(funcName); @@ -95,11 +94,25 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { llvm::Expected> lambdaOrErr = mlir::zamalang::JITLambda::create(funcName, module, optPipeline); + // Generate the KeySet for encrypting lambda arguments, decrypting lambda + // results + if (!compResOrErr->clientParameters.hasValue()) { + return StreamStringError("Cannot generate the keySet since client " + "parameters has not been computed"); + } + + llvm::Expected> keySetOrErr = + mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0); + + if (auto err = keySetOrErr.takeError()) + return std::move(err); + if (!lambdaOrErr) return std::move(lambdaOrErr.takeError()); return Lambda{this->compilationContext, std::move(lambdaOrErr.get()), - std::move(compResOrErr->keySet)}; + std::move(*keySetOrErr)}; } + } // namespace zamalang } // namespace mlir From 85d102c9b243a01d69d813cdc1fe2a0305934d62 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 22 Oct 2021 15:08:26 +0200 Subject: [PATCH 18/19] refactor(compiler): Simplify the compiler flow and re enable --passes compiler option No more need to compute the fhe context at high level --- .../include/zamalang/Support/CompilerEngine.h | 22 +- compiler/include/zamalang/Support/Pipeline.h | 36 +-- compiler/include/zamalang/Support/logging.h | 1 + compiler/lib/Support/CompilerEngine.cpp | 238 +++++------------- compiler/lib/Support/JitCompilerEngine.cpp | 2 +- compiler/lib/Support/Pipeline.cpp | 167 ++++++------ compiler/lib/Support/logging.cpp | 1 + compiler/src/main.cpp | 31 ++- .../Conversion/HLFHEToMidLFHE/add_eint.mlir | 2 +- .../HLFHEToMidLFHE/add_eint_int.mlir | 2 +- .../HLFHEToMidLFHE/apply_univariate.mlir | 2 +- .../HLFHEToMidLFHE/apply_univariate_cst.mlir | 2 +- .../HLFHEToMidLFHE/linalg_generic.mlir | 2 +- .../HLFHEToMidLFHE/mul_eint_int.mlir | 2 +- .../HLFHEToMidLFHE/sub_int_eint.mlir | 2 +- .../LowLFHEToConcreteCAPI/bootstrap.mlir | 2 +- .../glwe_from_table.mlir | 2 +- .../LowLFHEToConcreteCAPI/keyswitch_lwe.mlir | 2 +- .../Conversion/MidLFHEToLowLFHE/add_glwe.mlir | 2 +- .../MidLFHEToLowLFHE/add_glwe_int.mlir | 2 +- .../MidLFHEToLowLFHE/apply_lookup_table.mlir | 2 +- .../apply_lookup_table_cst.mlir | 2 +- .../MidLFHEToLowLFHE/mul_glwe_int.mlir | 2 +- .../MidLFHEToLowLFHE/sub_int_glwe.mlir | 2 +- .../tests/Dialect/HLFHE/Analysis/MANP.mlir | 2 +- 25 files changed, 200 insertions(+), 332 deletions(-) diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index a2eb2ed28..db2a3543e 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -57,10 +58,6 @@ public: // Read sources and exit before any lowering HLFHE, - // Read sources and attempt to run the Minimal Arithmetic Noise - // Padding pass - HLFHE_MANP, - // Read sources and lower all HLFHE operations to MidLFHE // operations MIDLFHE, @@ -91,7 +88,8 @@ public: CompilerEngine(std::shared_ptr compilationContext) : overrideMaxEintPrecision(), overrideMaxMANP(), clientParametersFuncName(), verifyDiagnostics(false), - generateClientParameters(false), parametrizeMidLFHE(true), + generateClientParameters(false), + enablePass([](mlir::Pass *pass) { return true; }), compilationContext(compilationContext) {} llvm::Expected compile(llvm::StringRef s, Target target); @@ -106,8 +104,8 @@ public: void setMaxMANP(size_t v); void setVerifyDiagnostics(bool v); void setGenerateClientParameters(bool v); - void setParametrizeMidLFHE(bool v); void setClientParametersFuncName(const llvm::StringRef &name); + void setEnablePass(std::function enablePass); protected: llvm::Optional overrideMaxEintPrecision; @@ -115,18 +113,14 @@ protected: llvm::Optional clientParametersFuncName; bool verifyDiagnostics; bool generateClientParameters; - bool parametrizeMidLFHE; + std::function enablePass; std::shared_ptr compilationContext; - // Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`) - // or indicating that no FHE dialect is used (`NONE`). - enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE }; - static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module); - private: - llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res); - llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride); + llvm::Expected> + getV0FHEConstraint(CompilationResult &res); + llvm::Error determineFHEParameters(CompilationResult &res); }; } // namespace zamalang diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h index bdd921a0b..0a94e4f70 100644 --- a/compiler/include/zamalang/Support/Pipeline.h +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -4,43 +4,43 @@ #include #include #include +#include + #include namespace mlir { namespace zamalang { namespace pipeline { -mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context, - mlir::ModuleOp &module, bool debug); - llvm::Expected> -getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module); +getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); -mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, - mlir::ModuleOp &module, bool verbose); +mlir::LogicalResult +lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); -mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, - mlir::ModuleOp &module, - V0FHEContext &fheContext, - bool parametrize); +mlir::LogicalResult +lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + llvm::Optional &fheContext, + std::function enablePass); -mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context, - mlir::ModuleOp &module); +mlir::LogicalResult +lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); -mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, - mlir::ModuleOp &module, bool verbose); +mlir::LogicalResult +lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, llvm::Module &module); -mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context, - mlir::ModuleOp &module, - V0FHEContext &fheContext, bool verbose); - std::unique_ptr lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context, llvm::LLVMContext &llvmContext, mlir::ModuleOp &module); + } // namespace pipeline } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/logging.h b/compiler/include/zamalang/Support/logging.h index 426381bcb..6779cd0e1 100644 --- a/compiler/include/zamalang/Support/logging.h +++ b/compiler/include/zamalang/Support/logging.h @@ -32,6 +32,7 @@ private: StreamWrap &log_error(void); StreamWrap &log_verbose(void); void setupLogging(bool verbose); +bool isVerbose(); } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 6224569a5..f05f2a909 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -82,180 +82,61 @@ void CompilerEngine::setMaxEintPrecision(size_t v) { this->overrideMaxEintPrecision = v; } -void CompilerEngine::setParametrizeMidLFHE(bool v) { - this->parametrizeMidLFHE = v; -} - void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; } void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) { this->clientParametersFuncName = name.str(); } -// Helper function detecting the FHE dialect with the highest level of -// abstraction used in `module`. If no FHE dialect is used, the -// function returns `CompilerEngine::FHEDialect::NONE`. -CompilerEngine::FHEDialect -CompilerEngine::detectHighestFHEDialect(mlir::ModuleOp module) { - CompilerEngine::FHEDialect highestDialect = CompilerEngine::FHEDialect::NONE; - - mlir::TypeID hlfheID = - mlir::TypeID::get(); - mlir::TypeID midlfheID = - mlir::TypeID::get(); - mlir::TypeID lowlfheID = - mlir::TypeID::get(); - - // Helper lambda updating the currently highest dialect if necessary - // by dialect type ID - auto updateDialectFromDialectID = [&](mlir::TypeID dialectID) { - if (dialectID == hlfheID) { - highestDialect = CompilerEngine::FHEDialect::HLFHE; - return true; - } else if (dialectID == lowlfheID && - highestDialect == CompilerEngine::FHEDialect::NONE) { - highestDialect = CompilerEngine::FHEDialect::LOWLFHE; - } else if (dialectID == midlfheID && - (highestDialect == CompilerEngine::FHEDialect::NONE || - highestDialect == CompilerEngine::FHEDialect::LOWLFHE)) { - highestDialect = CompilerEngine::FHEDialect::MIDLFHE; - } - - return false; - }; - - // Helper lambda updating the currently highest dialect if necessary - // by value type - std::function updateDialectFromType = - [&](mlir::Type ty) -> bool { - if (updateDialectFromDialectID(ty.getDialect().getTypeID())) - return true; - - if (mlir::TensorType tensorTy = ty.dyn_cast_or_null()) - return updateDialectFromType(tensorTy.getElementType()); - - return false; - }; - - module.walk([&](mlir::Operation *op) { - // Check operation itself - if (updateDialectFromDialectID(op->getDialect()->getTypeID())) - return mlir::WalkResult::interrupt(); - - // Check types of operands - for (mlir::Value operand : op->getOperands()) { - if (updateDialectFromType(operand.getType())) - return mlir::WalkResult::interrupt(); - } - - // Check types of results - for (mlir::Value res : op->getResults()) { - if (updateDialectFromType(res.getType())) { - return mlir::WalkResult::interrupt(); - } - } - - return mlir::WalkResult::advance(); - }); - - return highestDialect; +void CompilerEngine::setEnablePass( + std::function enablePass) { + this->enablePass = enablePass; } -// Sets the FHE parameters of `res` either through autodetection or -// fixed constraints provided in -// `CompilerEngine::overrideMaxEintPrecision` and -// `CompilerEngine::overrideMaxMANP`. -// -// Autodetected values can be partially or fully overridden through -// `CompilerEngine::overrideMaxEintPrecision` and -// `CompilerEngine::overrideMaxMANP`. -// -// If `noOverrideAutodetected` is true, autodetected values are not -// overriden and used directly for `res`. -// -// Return an error if autodetection fails. -llvm::Error -CompilerEngine::determineFHEParameters(CompilationResult &res, - bool noOverrideAutodetected) { +// Returns the overwritten V0FHEConstraint or try to compute them from HLFHE +llvm::Expected> +CompilerEngine::getV0FHEConstraint(CompilationResult &res) { mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); mlir::ModuleOp module = res.mlirModuleRef->get(); llvm::Optional fheConstraints; - - // Determine FHE constraints either through autodetection or through - // overridden values + // If the values has been overwritten returns if (this->overrideMaxEintPrecision.hasValue() && - this->overrideMaxMANP.hasValue() && !noOverrideAutodetected) { - fheConstraints.emplace(mlir::zamalang::V0FHEConstraint{ + this->overrideMaxMANP.hasValue()) { + return mlir::zamalang::V0FHEConstraint{ this->overrideMaxMANP.getValue(), - this->overrideMaxEintPrecision.getValue()}); - - } else { - llvm::Expected> - fheConstraintsOrErr = - mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(mlirContext, - module); - - if (auto err = fheConstraintsOrErr.takeError()) - return std::move(err); - - if (!fheConstraintsOrErr.get().hasValue()) { - return StreamStringError("Could not determine maximum required precision " - "for encrypted integers and maximum value for " - "the Minimal Arithmetic Noise Padding"); - } - - if (noOverrideAutodetected) - return llvm::Error::success(); - - fheConstraints = fheConstraintsOrErr.get(); - - // Override individual values if requested - if (this->overrideMaxEintPrecision.hasValue()) - fheConstraints->p = this->overrideMaxEintPrecision.getValue(); - - if (this->overrideMaxMANP.hasValue()) - fheConstraints->norm2 = this->overrideMaxMANP.getValue(); + this->overrideMaxEintPrecision.getValue()}; } + // Else compute constraint from HLFHE + llvm::Expected> + fheConstraintsOrErr = + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE( + mlirContext, module, enablePass); + if (auto err = fheConstraintsOrErr.takeError()) + return std::move(err); + + return fheConstraintsOrErr.get(); +} + +// set the fheContext field if the v0Constraint can be computed +llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { + auto fheConstraintOrErr = getV0FHEConstraint(res); + if (auto err = fheConstraintOrErr.takeError()) + return std::move(err); + if (!fheConstraintOrErr.get().hasValue()) { + return llvm::Error::success(); + } const mlir::zamalang::V0Parameter *fheParams = - getV0Parameter(fheConstraints.getValue()); + getV0Parameter(fheConstraintOrErr.get().getValue()); if (!fheParams) { return StreamStringError() << "Could not determine V0 parameters for 2-norm of " - << fheConstraints->norm2 << " and p of " << fheConstraints->p; - } - - res.fheContext.emplace( - mlir::zamalang::V0FHEContext{*fheConstraints, *fheParams}); - - return llvm::Error::success(); -} - -// Performs all lowering from HLFHE to the FHE dialect with the lwoest -// level of abstraction that requires FHE parameters. -// -// Returns an error if any of the lowerings fails. -llvm::Error CompilerEngine::lowerParamDependentHalf(Target target, - CompilationResult &res) { - mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); - mlir::ModuleOp module = res.mlirModuleRef->get(); - - // HLFHE -> MidLFHE - if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module, false) - .failed()) { - return StreamStringError("Lowering from HLFHE to MidLFHE failed"); - } - - if (target == Target::MIDLFHE) - return llvm::Error::success(); - - // MidLFHE -> LowLFHE - if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( - mlirContext, module, *res.fheContext, this->parametrizeMidLFHE) - .failed()) { - return StreamStringError("Lowering from MidLFHE to LowLFHE failed"); + << (*fheConstraintOrErr)->norm2 << " and p of " + << (*fheConstraintOrErr)->p; } + res.fheContext.emplace(mlir::zamalang::V0FHEContext{ + (*fheConstraintOrErr).getValue(), *fheParams}); return llvm::Error::success(); } @@ -289,43 +170,40 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { res.mlirModuleRef = std::move(mlirModuleRef); mlir::ModuleOp module = res.mlirModuleRef->get(); - if (target == Target::HLFHE || target == Target::ROUND_TRIP) + if (target == Target::ROUND_TRIP) return res; - // Detect highest FHE dialect and check if FHE parameter - // autodetection / lowering of parameter-dependent dialects can be - // skipped - FHEDialect highestFHEDialect = this->detectHighestFHEDialect(module); - - if (highestFHEDialect == FHEDialect::HLFHE || - highestFHEDialect == FHEDialect::MIDLFHE || - this->generateClientParameters) { - bool noOverrideAutoDetected = (target == Target::HLFHE_MANP); - if (auto err = this->determineFHEParameters(res, noOverrideAutoDetected)) - return std::move(err); - } - - // return early if only the MANP pass was requested - if (target == Target::HLFHE_MANP) + // HLFHE High level pass to determine FHE parameters + if (auto err = this->determineFHEParameters(res)) + return std::move(err); + if (target == Target::HLFHE) return res; - if (highestFHEDialect == FHEDialect::HLFHE || - highestFHEDialect == FHEDialect::MIDLFHE) { - if (llvm::Error err = this->lowerParamDependentHalf(target, res)) - return std::move(err); + // HLFHE -> MidLFHE + if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module, + enablePass) + .failed()) { + return StreamStringError("Lowering from HLFHE to MidLFHE failed"); } + if (target == Target::MIDLFHE) + return res; - if (target == Target::HLFHE_MANP || target == Target::MIDLFHE || - target == Target::LOWLFHE) + // MidLFHE -> LowLFHE + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( + mlirContext, module, res.fheContext, this->enablePass) + .failed()) { + return StreamStringError("Lowering from MidLFHE to LowLFHE failed"); + } + if (target == Target::LOWLFHE) return res; // LowLFHE -> Canonical dialects - if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module) + if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module, + enablePass) .failed()) { return StreamStringError( "Lowering from LowLFHE to canonical MLIR dialects failed"); } - if (target == Target::STD) return res; @@ -336,6 +214,10 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { "Generation of client parameters requested, but no function name " "specified"); } + if (!res.fheContext.hasValue()) { + return StreamStringError( + "Cannot generate client parameters, the fhe context is empty"); + } llvm::Expected clientParametersOrErr = mlir::zamalang::createClientParametersForV0( @@ -349,7 +231,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { // MLIR canonical dialects -> LLVM Dialect if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module, - false) + enablePass) .failed()) { return StreamStringError("Failed to lower to LLVM dialect"); } diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index 95626182d..b1d8deef9 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -88,7 +88,7 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { llvm::InitializeNativeTargetAsmPrinter(); mlir::registerLLVMDialectTranslation(mlirContext); - std::function optPipeline = + llvm::function_ref optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); llvm::Expected> lambdaOrErr = diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 64880551f..e367d1f5a 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -19,8 +19,30 @@ namespace mlir { namespace zamalang { namespace pipeline { -static void addPotentiallyNestedPass(mlir::PassManager &pm, - std::unique_ptr pass) { + +static void pipelinePrinting(llvm::StringRef name, mlir::PassManager &pm, + mlir::MLIRContext &ctx) { + if (mlir::zamalang::isVerbose()) { + mlir::zamalang::log_verbose() + << "##################################################\n" + << "### " << name << " pipeline\n"; + auto isModule = [](mlir::Pass *, mlir::Operation *op) { + return mlir::isa(op); + }; + ctx.disableMultithreading(true); + pm.enableIRPrinting(isModule, isModule); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } +} + +static void +addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr pass, + std::function enablePass) { + if (!enablePass(pass.get())) { + return; + } if (!pass->getOpName() || *pass->getOpName() == "builtin.module") { pm.addPass(std::move(pass)); } else { @@ -29,26 +51,20 @@ static void addPotentiallyNestedPass(mlir::PassManager &pm, } } -// Creates an instance of the Minimal Arithmetic Noise Padding pass -// and invokes it for all functions of `module`. -mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context, - mlir::ModuleOp &module, bool debug) { - mlir::PassManager pm(&context); - pm.addNestedPass(mlir::zamalang::createMANPPass(debug)); - return pm.run(module); -} - llvm::Expected> -getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) { +getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { llvm::Optional oMax2norm; llvm::Optional oMaxWidth; mlir::PassManager pm(&context); - addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass()); + pipelinePrinting("ComputeFHEConstraintOnHLFHE", pm, context); + addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass(), enablePass); addPotentiallyNestedPass( - pm, mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP, - unsigned currMaxWidth) { + pm, + mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP, + unsigned currMaxWidth) { assert((uint64_t)currMaxWidth < std::numeric_limits::max() && "Maximum width does not fit into size_t"); @@ -64,15 +80,14 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) { if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width) oMaxWidth.emplace(width); - })); - + }), + enablePass); if (pm.run(module.getOperation()).failed()) { return llvm::make_error( "Failed to determine the maximum Arithmetic Noise Padding and maximum" "required precision", llvm::inconvertibleErrorCode()); } - llvm::Optional ret; if (oMax2norm.hasValue() && oMaxWidth.hasValue()) { @@ -84,86 +99,76 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) { return ret; } -mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, - mlir::ModuleOp &module, bool verbose) { +mlir::LogicalResult +lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { mlir::PassManager pm(&context); + pipelinePrinting("HLFHEToMidLFHE", pm, context); - if (verbose) { - mlir::zamalang::log_verbose() - << "##################################################\n" - << "### HLFHE to MidLFHE pipeline\n"; + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass); + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass); - pm.enableIRPrinting(); - pm.enableStatistics(); - pm.enableTiming(); - pm.enableVerifier(); + return pm.run(module.getOperation()); +} + +mlir::LogicalResult +lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + llvm::Optional &fheContext, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("MidLFHEToLowLFHE", pm, context); + + if (fheContext.hasValue()) { + addPotentiallyNestedPass( + pm, + mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass( + fheContext.getValue()), + enablePass); } addPotentiallyNestedPass( - pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg()); - addPotentiallyNestedPass(pm, - mlir::zamalang::createConvertHLFHEToMidLFHEPass()); + pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass); return pm.run(module.getOperation()); } -mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, - mlir::ModuleOp &module, - V0FHEContext &fheContext, - bool parametrize) { - mlir::PassManager pm(&context); - - if (parametrize) { - addPotentiallyNestedPass( - pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass( - fheContext)); - } - - addPotentiallyNestedPass(pm, - mlir::zamalang::createConvertMidLFHEToLowLFHEPass()); - - return pm.run(module.getOperation()); -} - -mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context, - mlir::ModuleOp &module) { +mlir::LogicalResult +lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { mlir::PassManager pm(&context); + pipelinePrinting("LowLFHEToStd", pm, context); pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()); return pm.run(module.getOperation()); } -mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, - mlir::ModuleOp &module, - bool verbose) { +mlir::LogicalResult +lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { mlir::PassManager pm(&context); - - if (verbose) { - mlir::zamalang::log_verbose() - << "##################################################\n" - << "### MlirStdsDialectToMlirLLVMDialect pipeline\n"; - context.disableMultithreading(); - pm.enableIRPrinting(); - pm.enableStatistics(); - pm.enableTiming(); - pm.enableVerifier(); - } + pipelinePrinting("StdToLLVM", pm, context); // Unparametrize LowLFHE addPotentiallyNestedPass( - pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass()); + pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(), enablePass); // Bufferize - addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass()); - addPotentiallyNestedPass(pm, mlir::createStdBufferizePass()); - addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass()); - addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass()); - addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass()); - addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass()); - addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass(), + enablePass); + addPotentiallyNestedPass(pm, mlir::createStdBufferizePass(), enablePass); + addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass(), enablePass); + addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass(), enablePass); + addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(), + enablePass); + addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass); + addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(), + enablePass); // Convert to MLIR LLVM Dialect addPotentiallyNestedPass( - pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()); + pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), + enablePass); return pm.run(module); } @@ -181,7 +186,7 @@ lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context, mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, llvm::Module &module) { - std::function optPipeline = + llvm::function_ref optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); if (optPipeline(&module)) @@ -190,18 +195,6 @@ mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, return mlir::success(); } -mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context, - mlir::ModuleOp &module, - V0FHEContext &fheContext, bool verbose) { - if (lowerHLFHEToMidLFHE(context, module, verbose).failed() || - lowerMidLFHEToLowLFHE(context, module, fheContext, true).failed() || - lowerLowLFHEToStd(context, module).failed()) { - return mlir::failure(); - } else { - return mlir::success(); - } -} - } // namespace pipeline } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/logging.cpp b/compiler/lib/Support/logging.cpp index 0f800c459..7a6b27651 100644 --- a/compiler/lib/Support/logging.cpp +++ b/compiler/lib/Support/logging.cpp @@ -18,5 +18,6 @@ StreamWrap &log_verbose(void) { // Sets up logging. If `verbose` is false, messages passed to // `log_verbose` will be discarded. void setupLogging(bool verbose) { ::mlir::zamalang::verbose = verbose; } +bool isVerbose() { return verbose; } } // namespace zamalang } // namespace mlir diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 1e81f2803..1c8aada22 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -32,7 +32,6 @@ enum Action { ROUND_TRIP, DUMP_HLFHE, - DUMP_HLFHE_MANP, DUMP_MIDLFHE, DUMP_LOWLFHE, DUMP_STD, @@ -76,10 +75,10 @@ llvm::cl::opt output("o", llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), llvm::cl::init(false)); -llvm::cl::opt parametrizeMidLFHE( - "parametrize-midlfhe", - llvm::cl::desc("Perform MidLFHE global parametrization pass"), - llvm::cl::init(true)); +llvm::cl::list passes( + "passes", + llvm::cl::desc("Specify the passes to run (use only for compiler tests)"), + llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore); static llvm::cl::opt action( "a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired, @@ -87,9 +86,6 @@ static llvm::cl::opt action( llvm::cl::values( clEnumValN(Action::ROUND_TRIP, "roundtrip", "Parse input module and regenerate textual representation")), - llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp", - "Dump HLFHE module after running the Minimal " - "Arithmetic Noise Padding pass")), llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe", "Dump HLFHE module")), llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", @@ -218,7 +214,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, mlir::LogicalResult processInputBuffer(std::unique_ptr buffer, enum Action action, const std::string &jitFuncName, - llvm::ArrayRef jitArgs, bool parametrizeMidlHFE, + llvm::ArrayRef jitArgs, llvm::Optional overrideMaxEintPrecision, llvm::Optional overrideMaxMANP, bool verifyDiagnostics, llvm::raw_ostream &os) { @@ -228,7 +224,13 @@ processInputBuffer(std::unique_ptr buffer, mlir::zamalang::JitCompilerEngine ce{ccx}; ce.setVerifyDiagnostics(verifyDiagnostics); - ce.setParametrizeMidLFHE(parametrizeMidlHFE); + if (cmdline::passes.size() != 0) { + ce.setEnablePass([](mlir::Pass *pass) { + return std::any_of( + cmdline::passes.begin(), cmdline::passes.end(), + [&](const std::string &p) { return pass->getArgument() == p; }); + }); + } if (overrideMaxEintPrecision.hasValue()) ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue()); @@ -267,9 +269,6 @@ processInputBuffer(std::unique_ptr buffer, case Action::DUMP_HLFHE: target = mlir::zamalang::CompilerEngine::Target::HLFHE; break; - case Action::DUMP_HLFHE_MANP: - target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP; - break; case Action::DUMP_MIDLFHE: target = mlir::zamalang::CompilerEngine::Target::MIDLFHE; break; @@ -353,7 +352,6 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return processInputBuffer( std::move(inputBuffer), cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, - cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, os); }, @@ -362,9 +360,8 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { } else { return processInputBuffer( std::move(file), cmdline::action, cmdline::jitFuncName, - cmdline::jitArgs, cmdline::parametrizeMidLFHE, - cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, output->os()); + cmdline::jitArgs, cmdline::assumeMaxEintPrecision, + cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, output->os()); } } diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir index fc460a435..233f641b8 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir index 224270914..5cc76254a 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir index 846572c00..060151796 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir index 9163d28f0..899e7b920 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir index 6947a3942..714fc3a52 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK: #map0 = affine_map<(d0) -> (d0)> // CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir index ff156d615..50682d056 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir index f0da29950..3de90a6a7 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 94892649d..fa9d099c5 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index 13a8c7214..923a73c64 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 0e6ff2534..d7590ac56 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index 497ce0cd8..e70ce888e 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index a0c63723d..8c9284546 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 5f917cd3b..eee5f305c 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index e8bc3ad06..79581bbce 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index 17a86d8d5..fa1e01b91 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index f352e0f99..fb09a1e39 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir index acbb19672..ad9f2e2e2 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s +// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s func @single_zero() -> !HLFHE.eint<2> { From e66fb20c4ff76eca17abc44cedaa99cc2c1d159f Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 22 Oct 2021 15:36:25 +0200 Subject: [PATCH 19/19] test(compiler): Re introduce tests that was removed by a prior refacto --- compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir | 7 +++++++ .../LowLFHEUnparametrize/unrealized_conversion_cast.mlir | 8 ++++++++ 2 files changed, 15 insertions(+) create mode 100644 compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir create mode 100644 compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir diff --git a/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir b/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir new file mode 100644 index 000000000..8374ce81e --- /dev/null +++ b/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir @@ -0,0 +1,7 @@ +// RUN: zamacompiler --passes lowlfhe-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_> +func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { + // CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_> + return %arg0: !LowLFHE.lwe_ciphertext<1024,4> +} \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir b/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir new file mode 100644 index 000000000..f44a8404f --- /dev/null +++ b/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir @@ -0,0 +1,8 @@ +// RUN: zamacompiler --passes lowlfhe-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s + +// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_> +func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<_,_> { + // CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_> + %0 = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> + return %0: !LowLFHE.lwe_ciphertext<_,_> +} \ No newline at end of file