Mirror of https://github.com/zama-ai/concrete.git
Merge branch 'master' into hlfhelinalg-binary-op-lowering
@@ -53,7 +53,7 @@ jobs:
cd /compiler
pip install pytest
export CONCRETE_PROJECT=/concrete
make -B BUILD_DIR=/build build
make -B BUILD_DIR=/build build-initialized
make BUILD_DIR=/build test

- name: Send Slack Notification
README.md (55 lines changed)
@@ -1,3 +1,56 @@

# Homomorphizer

The homomorphizer is a compiler that takes a high-level computation model and produces a program that evaluates the model homomorphically.

## Build tarball

The final tarball contains installation instructions. We only support Linux x86_64 for the moment. You can find the output tarball under `/tarballs`.

```bash
$ cd compiler
$ make release_tarballs
```

## Build the Python Package

Currently supported platforms:
- Linux x86_64 for python 3.8, 3.9, and 3.10

### Linux

We use the [manylinux](https://github.com/pypa/manylinux) docker images for building python packages for Linux. Those packages should work on distributions that have GLIBC >= 2.24.

You can use Make to build the python wheels using these docker images:

```bash
$ cd compiler
$ make package_py38 # package_py39 package_py310
```

This will build the image for the appropriate python version, then copy the wheels out under `/wheels`.

### Build wheels in your environment

#### Temporary MLIR issue

Due to an issue with MLIR, you will need to manually add `__init__.py` files to the `mlir` python package after the build.

```bash
$ make python-bindings
$ touch build/tools/zamalang/python_packages/zamalang_core/mlir/__init__.py
$ touch build/tools/zamalang/python_packages/zamalang_core/mlir/dialects/__init__.py
```

#### Build wheel

Building the wheels is straightforward:

```bash
$ pip wheel --no-deps -w ../wheels .
```

Depending on the platform you are using (especially Linux), you might need to use `auditwheel` to specify the platform this wheel is targeting. For example, in our build of the package for Linux x86_64 and GLIBC 2.24, we also run:

```bash
$ auditwheel repair ../wheels/*.whl --plat manylinux_2_24_x86_64 -w ../wheels
```
builders/Dockerfile.release_manylinux_2_24_x86_64 (new file, 23 lines)
@@ -0,0 +1,23 @@
FROM quay.io/pypa/manylinux_2_24_x86_64

RUN apt-get update
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y build-essential ninja-build
# Set the python path. Options: [cp38-cp38, cp39-cp39, cp310-cp310]
ARG python_tag=cp38-cp38
# Install python deps
RUN /opt/python/${python_tag}/bin/pip install numpy pybind11==2.6.2 PyYAML
# Setup LLVM
COPY /llvm-project /llvm-project
# Setup Concrete
COPY --from=ghcr.io/zama-ai/concrete-api-env:latest /target/release /concrete/target/release
ENV CONCRETE_PROJECT=/concrete
# Setup and build compiler
COPY /compiler /compiler
WORKDIR /compiler
RUN make Python3_EXECUTABLE=/opt/python/${python_tag}/bin/python python-bindings
# Fix MLIR package
RUN touch build/tools/zamalang/python_packages/zamalang_core/mlir/__init__.py
RUN touch build/tools/zamalang/python_packages/zamalang_core/mlir/dialects/__init__.py
# Build wheel
RUN /opt/python/${python_tag}/bin/pip wheel --no-deps -w /wheels .
RUN auditwheel repair /wheels/*.whl --plat manylinux_2_24_x86_64 -w /wheels
builders/Dockerfile.release_tarball_linux_x86_64 (new file, 23 lines)
@@ -0,0 +1,23 @@
FROM quay.io/pypa/manylinux_2_24_x86_64

RUN apt-get update
RUN DEBIAN_FRONTEND="noninteractive" apt-get install -y build-essential ninja-build
# Setup LLVM
COPY /llvm-project /llvm-project
# Setup Concrete
COPY --from=ghcr.io/zama-ai/concrete-api-env:latest /target/release /concrete/target/release
ENV CONCRETE_PROJECT=/concrete
# Setup and build compiler
COPY /compiler /compiler
WORKDIR /compiler
RUN make BINDINGS_PYTHON_ENABLED=OFF zamacompiler
# Build tarball
RUN mkdir -p /tarballs/zamacompiler/lib /tarballs/zamacompiler/bin && \
    cp /compiler/build/bin/zamacompiler /tarballs/zamacompiler/bin && \
    cp /compiler/build/lib/libZamalangRuntime.so /tarballs/zamacompiler/lib
RUN echo "# Installation\n"\
    "You can install the compiler by either:\n"\
    "1. Extracting the tarball as is somewhere of your choosing, and add /path/to/tarball/zamacompiler/bin to your \$PATH\n"\
    "2. Extracting the tarball and putting the bin/zamacompiler into a path already in your \$PATH, and lib/libZamalangRuntime.so into one of your lib folders (e.g /usr/lib)"\
    >> /tarballs/zamacompiler/Installation.md
RUN cd /tarballs && tar -czvf zamacompiler.tar.gz zamacompiler
||||
@@ -12,7 +12,6 @@ COPY /llvm-project /llvm-project
COPY /compiler /compiler
WORKDIR /compiler
RUN mkdir -p /build
RUN make BUILD_DIR=/build -B build
RUN make BUILD_DIR=/build zamacompiler python-bindings
ENV PYTHONPATH "$PYTHONPATH:/build/tools/zamalang/python_packages/zamalang_core:/build/tools/zamalang/python_packages/zamalang_core/mlir/_mlir_libs/"
ENV PATH "$PATH:/build/bin"
@@ -56,7 +56,18 @@ if(ZAMALANG_BINDINGS_PYTHON_ENABLED)
  message(STATUS "ZamaLang Python bindings are enabled.")

  include(MLIRDetectPythonEnv)
  find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
  # After CMake 3.18, we are able to limit the scope of the search to just
  # Development.Module. Searching for Development will fail in situations where
  # the Python libraries are not available. When possible, limit to just
  # Development.Module.
  # See https://pybind11.readthedocs.io/en/stable/compiling.html#findpython-mode
  if(CMAKE_VERSION VERSION_LESS "3.18.0")
    set(_python_development_component Development)
  else()
    set(_python_development_component Development.Module)
  endif()
  find_package(Python3 COMPONENTS Interpreter ${_python_development_component} REQUIRED)
  unset(_python_development_component)
  message(STATUS "Found Python include dirs: ${Python3_INCLUDE_DIRS}")
  message(STATUS "Found Python libraries: ${Python3_LIBRARIES}")
  message(STATUS "Found Python executable: ${Python3_EXECUTABLE}")
||||
@@ -1,27 +1,34 @@
BUILD_DIR=./build
Python3_EXECUTABLE=
BINDINGS_PYTHON_ENABLED=ON

build:
$(BUILD_DIR)/configured.stamp:
	cmake -B $(BUILD_DIR) -GNinja ../llvm-project/llvm/ \
		-DLLVM_ENABLE_PROJECTS=mlir \
		-DLLVM_BUILD_EXAMPLES=OFF \
		-DLLVM_TARGETS_TO_BUILD="host" \
		-DCMAKE_BUILD_TYPE=Release \
		-DLLVM_ENABLE_ASSERTIONS=ON \
		-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
		-DZAMALANG_BINDINGS_PYTHON_ENABLED=ON \
		-DMLIR_ENABLE_BINDINGS_PYTHON=$(BINDINGS_PYTHON_ENABLED) \
		-DZAMALANG_BINDINGS_PYTHON_ENABLED=$(BINDINGS_PYTHON_ENABLED) \
		-DCONCRETE_FFI_RELEASE=${CONCRETE_PROJECT}/target/release \
		-DLLVM_EXTERNAL_PROJECTS=zamalang \
		-DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=.
		-DLLVM_EXTERNAL_ZAMALANG_SOURCE_DIR=. \
		-DPython3_EXECUTABLE=${Python3_EXECUTABLE}
	touch $@

build-end-to-end-jit: build
build-initialized: $(BUILD_DIR)/configured.stamp

build-end-to-end-jit: build-initialized
	cmake --build $(BUILD_DIR) --target end_to_end_jit_test

zamacompiler: build
zamacompiler: build-initialized
	cmake --build $(BUILD_DIR) --target zamacompiler

python-bindings: build
	cmake --build $(BUILD_DIR) --target ZamalangMLIRPythonModules ZamalangPythonModules
python-bindings: build-initialized
	cmake --build $(BUILD_DIR) --target ZamalangMLIRPythonModules
	cmake --build $(BUILD_DIR) --target ZamalangPythonModules

test-check: zamacompiler file-check not
	$(BUILD_DIR)/bin/llvm-lit -v tests/

@@ -30,7 +37,7 @@ test-end-to-end-jit: build-end-to-end-jit
	$(BUILD_DIR)/bin/end_to_end_jit_test

test-python: python-bindings
	PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core/mlir/_mlir_libs/ LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python
	PYTHONPATH=${PYTHONPATH}:$(BUILD_DIR)/tools/zamalang/python_packages/zamalang_core LD_PRELOAD=$(BUILD_DIR)/lib/libZamalangRuntime.so pytest -vs tests/python

test: test-check test-end-to-end-jit test-python

@@ -42,3 +49,39 @@ file-check:
	cmake --build $(BUILD_DIR) --target FileCheck
not:
	cmake --build $(BUILD_DIR) --target not

# Python packages

define build_image_and_copy_wheels
	docker image build -t concretefhe-compiler-manylinux:$(1) --build-arg python_tag=$(1) -f ../builders/Dockerfile.release_manylinux_2_24_x86_64 ..
	docker container run --rm -v ${PWD}/../wheels:/wheels_volume concretefhe-compiler-manylinux:$(1) cp -r /wheels/. /wheels_volume/.
endef

package_py38:
	$(call build_image_and_copy_wheels,cp38-cp38)

package_py39:
	$(call build_image_and_copy_wheels,cp39-cp39)

package_py310:
	$(call build_image_and_copy_wheels,cp310-cp310)

release_tarballs:
	docker image build -t concretefhe-compiler-manylinux:linux_x86_64_tarball -f ../builders/Dockerfile.release_tarball_linux_x86_64 ..
	docker container run --rm -v ${PWD}/../tarballs:/tarballs_volume concretefhe-compiler-manylinux:linux_x86_64_tarball cp -r /tarballs/. /tarballs_volume/.

.PHONY: build-initialized \
	build-end-to-end-jit \
	zamacompiler \
	python-bindings \
	test-check \
	test-end-to-end-jit \
	test-python \
	test \
	add-deps \
	file-check \
	not \
	package_py38 \
	package_py39 \
	package_py310 \
	release_tarballs
||||
@@ -5,15 +5,17 @@
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct compilerEngine {
|
||||
mlir::zamalang::CompilerEngine *ptr;
|
||||
struct lambda {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda *ptr;
|
||||
};
|
||||
typedef struct compilerEngine compilerEngine;
|
||||
typedef struct lambda lambda;
|
||||
|
||||
struct executionArguments {
|
||||
mlir::zamalang::ExecutionArgument *data;
|
||||
@@ -21,13 +23,12 @@ struct executionArguments {
|
||||
};
|
||||
typedef struct executionArguments exectuionArguments;
|
||||
|
||||
// Compile an MLIR module
|
||||
MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine,
|
||||
const char *module);
|
||||
MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
|
||||
buildLambda(const char *module, const char *funcName);
|
||||
|
||||
// Run the compiled module
|
||||
MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e,
|
||||
executionArguments args);
|
||||
MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
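The two exported entry points above wrap the C++ engine behind plain structs so they can be reached from the Python bindings. A minimal sketch of driving them from C++ is given below; the MLIR source string, the entry-point name `main` and the argument values are placeholders, and the direct call on the returned `Lambda` relies on the call operator declared in `JitCompilerEngine.h` later in this commit.

```cpp
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Support/JitCompilerEngine.h"

// Compile a textual MLIR module through buildLambda() and invoke the result
// with two scalar arguments (illustrative values only).
uint64_t compileAndRun(const char *mlirSource) {
  mlir::zamalang::JitCompilerEngine::Lambda compiled =
      buildLambda(mlirSource, "main");

  // The Lambda's variadic call operator forwards scalar arguments and yields
  // the result as an llvm::Expected<uint64_t>.
  llvm::Expected<uint64_t> result = compiled(1u, 2u);
  if (!result) {
    llvm::consumeError(result.takeError());
    return 0;
  }
  return *result;
}
```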
|
||||
|
||||
@@ -1,49 +1,128 @@
|
||||
#ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "Jit.h"
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/MLIRContext.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <zamalang/Conversion/Utils/GlobalFHEContext.h>
|
||||
#include <zamalang/Support/ClientParameters.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
/// CompilerEngine provides the tools to implement the compilation
/// flow and to manage the compilation flow state.
|
||||
// Compilation context that acts as the root owner of LLVM and MLIR
|
||||
// data structures directly and indirectly referenced by artefacts
|
||||
// produced by the `CompilerEngine`.
|
||||
class CompilationContext {
|
||||
public:
|
||||
CompilationContext();
|
||||
~CompilationContext();
|
||||
|
||||
mlir::MLIRContext *getMLIRContext();
|
||||
llvm::LLVMContext *getLLVMContext();
|
||||
|
||||
static std::shared_ptr<CompilationContext> createShared();
|
||||
|
||||
protected:
|
||||
mlir::MLIRContext *mlirContext;
|
||||
llvm::LLVMContext *llvmContext;
|
||||
};
|
||||
|
||||
class CompilerEngine {
|
||||
public:
|
||||
CompilerEngine() {
|
||||
context = new mlir::MLIRContext();
|
||||
loadDialects();
|
||||
}
|
||||
~CompilerEngine() {
|
||||
if (context != nullptr)
|
||||
delete context;
|
||||
}
|
||||
// Result of an invocation of the `CompilerEngine` with optional
|
||||
// fields for the results produced by different stages.
|
||||
class CompilationResult {
|
||||
public:
|
||||
CompilationResult(std::shared_ptr<CompilationContext> compilationContext =
|
||||
CompilationContext::createShared())
|
||||
: compilationContext(compilationContext) {}
|
||||
|
||||
// Compile an MLIR program from its textual representation.
|
||||
llvm::Error compile(
|
||||
std::string mlirStr,
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints = {});
|
||||
llvm::Optional<mlir::OwningModuleRef> mlirModuleRef;
|
||||
llvm::Optional<mlir::zamalang::ClientParameters> clientParameters;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
|
||||
|
||||
// Build the jit lambda argument.
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();
|
||||
protected:
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
};
|
||||
|
||||
// Call the compiled function with an argument object.
|
||||
llvm::Error invoke(JITLambda::Argument &arg);
|
||||
// Specification of the exit stage of the compilation pipeline
|
||||
enum class Target {
|
||||
// Only read sources and produce corresponding MLIR module
|
||||
ROUND_TRIP,
|
||||
|
||||
// Call the compiled function with a list of integer arguments.
|
||||
llvm::Expected<uint64_t> run(std::vector<uint64_t> args);
|
||||
// Read sources and exit before any lowering
|
||||
HLFHE,
|
||||
|
||||
// Get a printable representation of the compiled module
|
||||
std::string getCompiledModule();
|
||||
// Read sources and lower all HLFHE operations to MidLFHE
|
||||
// operations
|
||||
MIDLFHE,
|
||||
|
||||
// Read sources and lower all HLFHE and MidLFHE operations to LowLFHE
|
||||
// operations
|
||||
LOWLFHE,
|
||||
|
||||
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
|
||||
// operations to canonical MLIR dialects. Cryptographic operations
|
||||
// are lowered to invocations of the concrete library.
|
||||
STD,
|
||||
|
||||
// Read sources and lower all HLFHE, MidLFHE and LowLFHE
|
||||
// operations to operations from the LLVM dialect. Cryptographic
|
||||
// operations are lowered to invocations of the concrete library.
|
||||
LLVM,
|
||||
|
||||
// Same as `LLVM`, but lowers to actual LLVM IR instead of the
|
||||
// LLVM dialect
|
||||
LLVM_IR,
|
||||
|
||||
// Same as `LLVM_IR`, but invokes the LLVM optimization pipeline
|
||||
// to produce optimized LLVM IR
|
||||
OPTIMIZED_LLVM_IR
|
||||
};
|
||||
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(),
|
||||
clientParametersFuncName(), verifyDiagnostics(false),
|
||||
generateClientParameters(false),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
llvm::Expected<CompilationResult> compile(llvm::StringRef s, Target target);
|
||||
|
||||
llvm::Expected<CompilationResult>
|
||||
compile(std::unique_ptr<llvm::MemoryBuffer> buffer, Target target);
|
||||
|
||||
llvm::Expected<CompilationResult> compile(llvm::SourceMgr &sm, Target target);
|
||||
|
||||
void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c);
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
void setVerifyDiagnostics(bool v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setClientParametersFuncName(const llvm::StringRef &name);
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision;
|
||||
llvm::Optional<size_t> overrideMaxMANP;
|
||||
llvm::Optional<std::string> clientParametersFuncName;
|
||||
bool verifyDiagnostics;
|
||||
bool generateClientParameters;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
|
||||
private:
|
||||
// Load the necessary dialects into the engine's context
|
||||
void loadDialects();
|
||||
|
||||
mlir::OwningModuleRef module_ref;
|
||||
mlir::MLIRContext *context;
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getV0FHEConstraint(CompilationResult &res);
|
||||
llvm::Error determineFHEParameters(CompilationResult &res);
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
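The `Target` values above choose how far the pipeline runs before `compile()` hands back a `CompilationResult`. As a reading aid, here is a minimal sketch that stops at `ROUND_TRIP` and prints the parsed module, modeled on the `roundTrip()` helper added later in this commit; the input is assumed to be a valid textual MLIR module.

```cpp
#include <llvm/Support/raw_ostream.h>
#include <zamalang/Support/CompilerEngine.h>

// Parse a module, stop the pipeline at ROUND_TRIP and return its printed form.
std::string parseAndPrint(llvm::StringRef module) {
  using namespace mlir::zamalang;

  std::shared_ptr<CompilationContext> ctx = CompilationContext::createShared();
  CompilerEngine engine(ctx);

  llvm::Expected<CompilerEngine::CompilationResult> result =
      engine.compile(module, CompilerEngine::Target::ROUND_TRIP);
  if (!result) {
    llvm::errs() << llvm::toString(result.takeError()) << "\n";
    return "";
  }

  std::string out;
  llvm::raw_string_ostream os(out);
  result->mlirModuleRef->get().print(os);
  return os.str();
}
```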
|
||||
|
||||
|
||||
compiler/include/zamalang/Support/Error.h (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
#ifndef ZAMALANG_SUPPORT_STRING_ERROR_H
|
||||
#define ZAMALANG_SUPPORT_STRING_ERROR_H
|
||||
|
||||
#include <llvm/Support/Error.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// Internal error class that allows for composing `llvm::Error`s
|
||||
// similar to `llvm::createStringError()`, but using stream-like
|
||||
// composition with `operator<<`.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// llvm::Error foo(int i, size_t s, ...) {
|
||||
// ...
|
||||
// if(...) {
|
||||
// return StreamStringError()
|
||||
// << "Some error message with an integer: "
|
||||
// << i << " and a size_t: " << s;
|
||||
// }
|
||||
// ...
|
||||
// }
|
||||
class StreamStringError {
|
||||
public:
|
||||
StreamStringError(const llvm::StringRef &s) : buffer(s.str()), os(buffer){};
|
||||
StreamStringError() : buffer(""), os(buffer){};
|
||||
|
||||
template <typename T> StreamStringError &operator<<(const T &v) {
|
||||
this->os << v;
|
||||
return *this;
|
||||
}
|
||||
|
||||
operator llvm::Error() {
|
||||
return llvm::make_error<llvm::StringError>(os.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
template <typename T> operator llvm::Expected<T>() {
|
||||
return this->operator llvm::Error();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream os;
|
||||
};
|
||||
|
||||
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err);
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
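Because `StreamStringError` converts implicitly to both `llvm::Error` and `llvm::Expected<T>`, the same streamed message can be returned from functions with either return type. A small sketch follows; the function and its limit check are invented for illustration.

```cpp
#include <zamalang/Support/Error.h>

// Return the requested size, or a streamed error if it exceeds the limit.
llvm::Expected<size_t> checkedSize(size_t requested, size_t limit) {
  if (requested > limit) {
    // The StreamStringError converts to llvm::Expected<size_t> on return.
    return mlir::zamalang::StreamStringError()
           << "requested size " << requested << " exceeds limit " << limit;
  }
  return requested;
}
```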
|
||||
@@ -9,11 +9,6 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
mlir::LogicalResult
|
||||
runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os);
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
@@ -53,6 +48,10 @@ public:
|
||||
// - or the size of the `res` buffer doesn't match the size of the tensor.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
@@ -97,7 +96,7 @@ public:
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
llvm::StringRef name;
|
||||
std::string name;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
};
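The new `getResultVectorSize()` accessor enables the retrieval pattern sketched below for tensor results: query the element count of result 0, then copy the values out. This mirrors the `typedResult<std::vector<uint64_t>>()` specialization added in `JitCompilerEngine.h` later in this commit; the result position 0 is only an example.

```cpp
#include <vector>

#include <zamalang/Support/Jit.h>

// Copy the tensor result at position 0 out of an invoked JITLambda::Argument.
llvm::Expected<std::vector<uint64_t>>
readVectorResult(mlir::zamalang::JITLambda::Argument &args) {
  llvm::Expected<size_t> n = args.getResultVectorSize(0);
  if (!n)
    return n.takeError();

  std::vector<uint64_t> out(*n);
  if (llvm::Error err = args.getResult(0, out.data(), out.size()))
    return std::move(err);

  return out;
}
```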
|
||||
|
||||
|
||||
compiler/include/zamalang/Support/JitCompilerEngine.h (new file, 292 lines)
@@ -0,0 +1,292 @@
|
||||
#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <zamalang/Support/CompilerEngine.h>
|
||||
#include <zamalang/Support/Error.h>
|
||||
#include <zamalang/Support/Jit.h>
|
||||
#include <zamalang/Support/LambdaArgument.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
// `typedResult` must be declared at namespace scope due to return
|
||||
// type template specialization
|
||||
|
||||
// Helper function for `JitCompilerEngine::Lambda::operator()`
|
||||
// implementing type-dependent preparation of the result.
|
||||
template <typename ResT>
|
||||
llvm::Expected<ResT> typedResult(JITLambda::Argument &arguments);
|
||||
|
||||
// Specialization of `typedResult()` for scalar results, forwarding
|
||||
// scalar value to caller
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
|
||||
uint64_t res = 0;
|
||||
|
||||
if (auto err = arguments.getResult(0, res))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Specialization of `typedResult()` for vector results, initializing
|
||||
// an `std::vector` of the right size with the results and forwarding
|
||||
// it to the caller with move semantics.
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
|
||||
|
||||
if (auto err = n.takeError())
|
||||
return std::move(err);
|
||||
|
||||
std::vector<uint64_t> res(*n);
|
||||
|
||||
if (auto err = arguments.getResult(0, res.data(), res.size()))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
// Adaptor class that adds arguments specified as instances of
|
||||
// `LambdaArgument` to `JitLambda::Argument`.
|
||||
class JITLambdaArgumentAdaptor {
|
||||
public:
|
||||
// Checks if the argument `arg` is a plaintext / encrypted integer
|
||||
// argument or a plaintext / encrypted tensor argument with a
|
||||
// backing integer type `IntT` and adds the argument to `jla` at
|
||||
// position `pos`.
|
||||
//
|
||||
// Returns `true` if `arg` has one of the types above and its value
|
||||
// was successfully added to `jla`, `false` if none of the types
|
||||
// matches or an error if a type matched, but adding the argument to
|
||||
// `jla` failed.
|
||||
template <typename IntT>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
|
||||
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
|
||||
if (llvm::Error err = jla.setArg(pos, ila->getValue()))
|
||||
return std::move(err);
|
||||
else
|
||||
return true;
|
||||
} else if (auto tla = arg.dyn_cast<
|
||||
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
|
||||
if (llvm::Error err =
|
||||
jla.setArg(pos, tla->getValue(), tla->getDimensions()))
|
||||
return std::move(err);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recursive case for `tryAddArg<IntT>(...)`
|
||||
template <typename IntT, typename NextIntT, typename... IntTs>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) {
|
||||
llvm::Expected<bool> successOrError = tryAddArg<IntT>(jla, pos, arg);
|
||||
|
||||
if (!successOrError)
|
||||
return std::move(successOrError.takeError());
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return tryAddArg<NextIntT, IntTs...>(jla, pos, arg);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
// Attempts to add a single argument `arg` to `jla` at position
|
||||
// `pos`. Returns an error if either the argument type is
|
||||
// unsupported or if the argument types is supported, but adding it
|
||||
// to `jla` failed.
|
||||
static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos,
|
||||
const LambdaArgument &arg) {
|
||||
llvm::Expected<bool> successOrError =
|
||||
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
|
||||
uint8_t>(jla, pos, arg);
|
||||
|
||||
if (!successOrError)
|
||||
return std::move(successOrError.takeError());
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return StreamStringError("Unknown argument type");
|
||||
else
|
||||
return llvm::Error::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// A compiler engine that JIT-compiles a source and produces a lambda
|
||||
// object directly invocable through its call operator.
|
||||
class JitCompilerEngine : public CompilerEngine {
|
||||
public:
|
||||
// Wrapper class around `JITLambda` and `JITLambda::Argument` that
|
||||
// allows for direct invocation of a compiled function through
|
||||
// `operator ()`.
|
||||
class Lambda {
|
||||
public:
|
||||
Lambda(Lambda &&other)
|
||||
: innerLambda(std::move(other.innerLambda)),
|
||||
keySet(std::move(other.keySet)),
|
||||
compilationContext(other.compilationContext) {}
|
||||
|
||||
Lambda(std::shared_ptr<CompilationContext> compilationContext,
|
||||
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet)
|
||||
: innerLambda(std::move(lambda)), keySet(std::move(keySet)),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
// Returns the number of arguments required for an invocation of
|
||||
// the lambda
|
||||
size_t getNumArguments() { return this->keySet->numInputs(); }
|
||||
|
||||
// Returns the number of results an invocation of the lambda
|
||||
// produces
|
||||
size_t getNumResults() { return this->keySet->numOutputs(); }
|
||||
|
||||
// Invocation with a dynamic list of arguments of different
|
||||
// types, specified as `LambdaArgument`s
|
||||
template <typename ResT = uint64_t>
|
||||
llvm::Expected<ResT>
|
||||
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
|
||||
for (size_t i = 0; i < lambdaArgs.size(); i++) {
|
||||
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
|
||||
*arguments, i, *lambdaArgs[i])) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
}
|
||||
|
||||
// Invocation with an array of arguments of the same type
|
||||
template <typename T, typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, args[i])) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push argument " << i << ": " << err;
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
}
|
||||
|
||||
// Invocation with arguments of different types
|
||||
template <typename ResT = uint64_t, typename... Ts>
|
||||
llvm::Expected<ResT> operator()(const Ts... ts) {
|
||||
// Create the arguments of the JIT lambda
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>> argsOrErr =
|
||||
mlir::zamalang::JITLambda::Argument::create(*this->keySet.get());
|
||||
|
||||
if (llvm::Error err = argsOrErr.takeError())
|
||||
return StreamStringError("Could not create lambda arguments");
|
||||
|
||||
// Set the arguments
|
||||
std::unique_ptr<JITLambda::Argument> arguments =
|
||||
std::move(argsOrErr.get());
|
||||
|
||||
if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...))
|
||||
return std::move(err);
|
||||
|
||||
// Invoke the lambda
|
||||
if (auto err = this->innerLambda->invoke(*arguments))
|
||||
return StreamStringError() << "Cannot invoke lambda:" << err;
|
||||
|
||||
return std::move(typedResult<ResT>(*arguments));
|
||||
}
|
||||
|
||||
protected:
|
||||
template <int pos>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs) {
|
||||
// base case -- nothing to do
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Recursive case for scalars: extract first scalar argument from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg,
|
||||
Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push scalar argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
// Recursive case for tensors: extract pointer and size from
|
||||
// parameter pack and forward rest
|
||||
template <int pos, typename ArgT, typename... Ts>
|
||||
inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg,
|
||||
size_t size, Ts... remainder) {
|
||||
if (auto err = jitArgs->setArg(pos, arg, size)) {
|
||||
return StreamStringError()
|
||||
<< "Cannot push tensor argument " << pos << ": " << err;
|
||||
}
|
||||
|
||||
return this->addArgs<pos + 1>(jitArgs, remainder...);
|
||||
}
|
||||
|
||||
std::unique_ptr<JITLambda> innerLambda;
|
||||
std::unique_ptr<KeySet> keySet;
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
};
|
||||
|
||||
JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
|
||||
CompilationContext::createShared(),
|
||||
unsigned int optimizationLevel = 3);
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(llvm::StringRef src,
|
||||
llvm::StringRef funcName = "main");
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
llvm::StringRef funcName = "main");
|
||||
|
||||
llvm::Expected<Lambda> buildLambda(llvm::SourceMgr &sm,
|
||||
llvm::StringRef funcName = "main");
|
||||
|
||||
protected:
|
||||
llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
|
||||
llvm::StringRef name);
|
||||
unsigned int optimizationLevel;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
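Putting the pieces above together, a caller builds a `Lambda` from a textual module and invokes it through the call operator. A minimal sketch, assuming a module whose entry function `main` takes two scalar inputs:

```cpp
#include <zamalang/Support/JitCompilerEngine.h>

// JIT-compile `main` from the given MLIR source and call it with two scalars.
llvm::Expected<uint64_t> jitAndRun(llvm::StringRef mlirSource) {
  mlir::zamalang::JitCompilerEngine engine;

  llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
      engine.buildLambda(mlirSource, "main");
  if (!lambdaOrErr)
    return lambdaOrErr.takeError();

  // Scalar arguments go through the variadic call operator; the result type
  // defaults to uint64_t.
  return (*lambdaOrErr)(1u, 2u);
}
```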
|
||||
compiler/include/zamalang/Support/LambdaArgument.h (new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
#ifndef ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
|
||||
#define ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include <llvm/ADT/ArrayRef.h>
|
||||
#include <llvm/Support/Casting.h>
|
||||
#include <llvm/Support/ExtensibleRTTI.h>
|
||||
#include <zamalang/Support/Error.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// Abstract base class for lambda arguments
|
||||
class LambdaArgument
|
||||
: public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
|
||||
public:
|
||||
LambdaArgument(LambdaArgument &) = delete;
|
||||
|
||||
template <typename T> bool isa() const { return llvm::isa<T>(*this); }
|
||||
|
||||
// Cast functions on constant instances
|
||||
template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
|
||||
template <typename T> const T *dyn_cast() const {
|
||||
return llvm::dyn_cast<T>(this);
|
||||
}
|
||||
|
||||
// Cast functions for mutable instances
|
||||
template <typename T> T &cast() { return llvm::cast<T>(*this); }
|
||||
template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
LambdaArgument(){};
|
||||
};
|
||||
|
||||
// Class for integer arguments. `BackingIntType` is used as the data
|
||||
// type to hold the argument's value. The precision is the actual
|
||||
// precision of the value, which might be different from the precision
|
||||
// of the backing integer type.
|
||||
template <typename BackingIntType = uint64_t>
|
||||
class IntLambdaArgument
|
||||
: public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
|
||||
LambdaArgument> {
|
||||
public:
|
||||
typedef BackingIntType value_type;
|
||||
|
||||
IntLambdaArgument(BackingIntType value,
|
||||
unsigned int precision = 8 * sizeof(BackingIntType))
|
||||
: precision(precision) {
|
||||
if (precision < 8 * sizeof(BackingIntType)) {
|
||||
this->value = value & (1 << (this->precision - 1));
|
||||
} else {
|
||||
this->value = value;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned int getPrecision() const { return this->precision; }
|
||||
BackingIntType getValue() const { return this->value; }
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
unsigned int precision;
|
||||
BackingIntType value;
|
||||
};
|
||||
|
||||
template <typename BackingIntType>
|
||||
char IntLambdaArgument<BackingIntType>::ID = 0;
|
||||
|
||||
// Class for encrypted integer arguments. `BackingIntType` is used as
|
||||
// the data type to hold the argument's plaintext value. The precision
|
||||
// is the actual precision of the value, which might be different from
|
||||
// the precision of the backing integer type.
|
||||
template <typename BackingIntType = uint64_t>
|
||||
class EIntLambdaArgument
|
||||
: public llvm::RTTIExtends<EIntLambdaArgument<BackingIntType>,
|
||||
IntLambdaArgument<BackingIntType>> {
|
||||
public:
|
||||
static char ID;
|
||||
};
|
||||
|
||||
template <typename BackingIntType>
|
||||
char EIntLambdaArgument<BackingIntType>::ID = 0;
|
||||
|
||||
namespace {
|
||||
// Calculates `accu *= factor` or returns an error if the result
|
||||
// would overflow
|
||||
template <typename AccuT, typename ValT>
|
||||
llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
|
||||
static_assert(std::numeric_limits<AccuT>::is_integer &&
|
||||
std::numeric_limits<ValT>::is_integer &&
|
||||
!std::numeric_limits<AccuT>::is_signed &&
|
||||
!std::numeric_limits<ValT>::is_signed,
|
||||
"Only unsigned integers are supported");
|
||||
|
||||
const AccuT left = std::numeric_limits<AccuT>::max() / accu;
|
||||
|
||||
if (left > factor) {
|
||||
accu *= factor;
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
return StreamStringError("Multiplying value ")
|
||||
<< accu << " with " << factor << " would cause an overflow";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Class for Tensor arguments. This can either be plaintext tensors
|
||||
// (for `ScalarArgumentT = IntLambaArgument<T>`) or tensors
|
||||
// representing encrypted integers (for `ScalarArgumentT =
|
||||
// EIntLambaArgument<T>`).
|
||||
template <typename ScalarArgumentT>
|
||||
class TensorLambdaArgument
|
||||
: public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
|
||||
LambdaArgument> {
|
||||
public:
|
||||
typedef ScalarArgumentT scalar_type;
|
||||
|
||||
// Construct tensor argument from the one-dimensional array `value`,
|
||||
// but interpreting the array's values as a linearized
|
||||
// multi-dimensional tensor with the sizes of the dimensions
|
||||
// specified in `dimensions`.
|
||||
TensorLambdaArgument(
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value,
|
||||
llvm::ArrayRef<int64_t> dimensions)
|
||||
: value(value), dimensions(dimensions.vec()) {}
|
||||
|
||||
// Construct a one-dimensional tensor argument from the
|
||||
// array `value`.
|
||||
TensorLambdaArgument(
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value)
|
||||
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
|
||||
|
||||
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
|
||||
|
||||
// Returns the total number of elements in the tensor. If the number
|
||||
// of elements cannot be represented as a `size_t`, the method
|
||||
// returns an error.
|
||||
llvm::Expected<size_t> getNumElements() const {
|
||||
size_t accu = 1;
|
||||
|
||||
for (unsigned int dimSize : dimensions)
|
||||
if (llvm::Error err = safeUnsignedMul(accu, dimSize))
|
||||
return std::move(err);
|
||||
|
||||
return accu;
|
||||
}
|
||||
|
||||
// Returns a bare pointer to the linearized values of the tensor.
|
||||
typename ScalarArgumentT::value_type *getValue() const {
|
||||
return this->value.data();
|
||||
}
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
llvm::MutableArrayRef<typename ScalarArgumentT::value_type> value;
|
||||
std::vector<int64_t> dimensions;
|
||||
};
|
||||
|
||||
template <typename ScalarArgumentT>
|
||||
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
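These argument classes plug into `Lambda::operator()(llvm::ArrayRef<LambdaArgument *>)` from `JitCompilerEngine.h`, the same way the C wrapper `invokeLambda()` uses them later in this commit. A sketch with one scalar and one 2x2 tensor argument; the data and shape are illustrative.

```cpp
#include <cstdint>
#include <vector>

#include <zamalang/Support/JitCompilerEngine.h>
#include <zamalang/Support/LambdaArgument.h>

// Invoke a compiled lambda with a mixed list of scalar and tensor arguments.
llvm::Expected<uint64_t>
invokeWithMixedArguments(mlir::zamalang::JitCompilerEngine::Lambda &lambda) {
  using namespace mlir::zamalang;

  // Scalar argument backed by a uint64_t (the default backing type).
  IntLambdaArgument<> scalarArg(5);

  // 2x2 tensor argument, passed as a linearized buffer plus its dimensions.
  std::vector<uint8_t> data = {1, 2, 3, 4};
  TensorLambdaArgument<IntLambdaArgument<uint8_t>> tensorArg(
      llvm::MutableArrayRef<uint8_t>(data), {2, 2});

  LambdaArgument *args[] = {&scalarArg, &tensorArg};
  return lambda(llvm::ArrayRef<LambdaArgument *>(args));
}
```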
|
||||
@@ -4,43 +4,43 @@
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include <zamalang/Support/V0Parameters.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace pipeline {
|
||||
|
||||
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool debug);
|
||||
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module);
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
mlir::LogicalResult
|
||||
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext,
|
||||
bool parametrize);
|
||||
mlir::LogicalResult
|
||||
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module);
|
||||
mlir::LogicalResult
|
||||
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose);
|
||||
mlir::LogicalResult
|
||||
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
|
||||
llvm::Module &module);
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext, bool verbose);
|
||||
|
||||
std::unique_ptr<llvm::Module>
|
||||
lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
|
||||
llvm::LLVMContext &llvmContext,
|
||||
mlir::ModuleOp &module);
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
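Every lowering entry point above now threads an `enablePass` predicate through the pipeline, and `CompilerEngine::setEnablePass()` installs it. A caller-side sketch follows; the pass name used in the filter is hypothetical.

```cpp
#include <zamalang/Support/CompilerEngine.h>

// Install a predicate that runs every pass except one singled out by name.
void configureEngine(mlir::zamalang::CompilerEngine &engine) {
  engine.setEnablePass([](mlir::Pass *pass) {
    return pass->getName() != "SomeExpensivePass";
  });
}
```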
|
||||
|
||||
@@ -32,6 +32,7 @@ private:
|
||||
StreamWrap<llvm::raw_ostream> &log_error(void);
|
||||
StreamWrap<llvm::raw_ostream> &log_verbose(void);
|
||||
void setupLogging(bool verbose);
|
||||
bool isVerbose();
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ declare_mlir_python_sources(ZamalangBindingsPythonSources
|
||||
SOURCES
|
||||
zamalang/__init__.py
|
||||
zamalang/compiler.py
|
||||
zamalang/dialects/__init__.py
|
||||
zamalang/dialects/_ods_common.py)
|
||||
|
||||
################################################################################
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
#include "CompilerAPIModule.h"
|
||||
#include "zamalang-c/Support/CompilerEngine.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
#include <mlir/Dialect/StandardOps/IR/Ops.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
@@ -14,27 +15,15 @@
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
using mlir::zamalang::ExecutionArgument;
|
||||
using mlir::zamalang::JitCompilerEngine;
|
||||
|
||||
/// Populate the compiler API python module.
|
||||
void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
m.doc() = "Zamalang compiler python API";
|
||||
|
||||
m.def("round_trip", [](std::string mlir_input) {
|
||||
mlir::MLIRContext context;
|
||||
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context.getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
auto module_ref = mlir::parseSourceString(mlir_input, &context);
|
||||
if (!module_ref) {
|
||||
throw std::logic_error("mlir parsing failed");
|
||||
}
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
});
|
||||
m.def("round_trip",
|
||||
[](std::string mlir_input) { return roundTrip(mlir_input.c_str()); });
|
||||
|
||||
pybind11::class_<ExecutionArgument, std::shared_ptr<ExecutionArgument>>(
|
||||
m, "ExecutionArgument")
|
||||
@@ -45,20 +34,19 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
.def("is_tensor", &ExecutionArgument::isTensor)
|
||||
.def("is_int", &ExecutionArgument::isInt);
|
||||
|
||||
pybind11::class_<CompilerEngine>(m, "CompilerEngine")
|
||||
pybind11::class_<JitCompilerEngine>(m, "JitCompilerEngine")
|
||||
.def(pybind11::init())
|
||||
.def("run",
|
||||
[](CompilerEngine &engine, std::vector<ExecutionArgument> args) {
|
||||
// wrap and call CAPI
|
||||
compilerEngine e{&engine};
|
||||
exectuionArguments a{args.data(), args.size()};
|
||||
return compilerEngineRun(e, a);
|
||||
})
|
||||
.def("compile_fhe",
|
||||
[](CompilerEngine &engine, std::string mlir_input) {
|
||||
// wrap and call CAPI
|
||||
compilerEngine e{&engine};
|
||||
compilerEngineCompile(e, mlir_input.c_str());
|
||||
})
|
||||
.def("get_compiled_module", &CompilerEngine::getCompiledModule);
|
||||
.def_static("build_lambda",
|
||||
[](std::string mlir_input, std::string func_name) {
|
||||
return buildLambda(mlir_input.c_str(), func_name.c_str());
|
||||
});
|
||||
|
||||
pybind11::class_<JitCompilerEngine::Lambda>(m, "Lambda")
|
||||
.def("invoke", [](JitCompilerEngine::Lambda &py_lambda,
|
||||
std::vector<ExecutionArgument> args) {
|
||||
// wrap and call CAPI
|
||||
lambda c_lambda{&py_lambda};
|
||||
exectuionArguments a{args.data(), args.size()};
|
||||
return invokeLambda(c_lambda, a);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""Zamalang python module"""
|
||||
from _zamalang import *
|
||||
from mlir._mlir_libs._zamalang import *
|
||||
from .compiler import CompilerEngine
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Compiler submodule"""
|
||||
from typing import List, Union
|
||||
from _zamalang._compiler import CompilerEngine as _CompilerEngine
|
||||
from _zamalang._compiler import ExecutionArgument as _ExecutionArgument
|
||||
from _zamalang._compiler import round_trip as _round_trip
|
||||
|
||||
from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine
|
||||
from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument
|
||||
from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip
|
||||
|
||||
def round_trip(mlir_str: str) -> str:
|
||||
"""Parse the MLIR input, then return it back.
|
||||
@@ -49,25 +48,24 @@ def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgume
|
||||
|
||||
class CompilerEngine:
|
||||
def __init__(self, mlir_str: str = None):
|
||||
self._engine = _CompilerEngine()
|
||||
self._engine = _JitCompilerEngine()
|
||||
self._lambda = None
|
||||
if mlir_str is not None:
|
||||
self.compile_fhe(mlir_str)
|
||||
|
||||
def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
|
||||
"""Compile the MLIR input and build a CompilerEngine.
|
||||
def compile_fhe(self, mlir_str: str, func_name: str = "main"):
|
||||
"""Compile the MLIR input.
|
||||
|
||||
Args:
|
||||
mlir_str (str): MLIR to compile.
|
||||
func_name (str): name of the function to set as entrypoint.
|
||||
|
||||
Raises:
|
||||
TypeError: if the argument is not an str.
|
||||
|
||||
Returns:
|
||||
CompilerEngine: engine used for execution.
|
||||
"""
|
||||
if not isinstance(mlir_str, str):
|
||||
raise TypeError("input must be an `str`")
|
||||
return self._engine.compile_fhe(mlir_str)
|
||||
self._lambda = self._engine.build_lambda(mlir_str, func_name)
|
||||
|
||||
def run(self, *args: List[Union[int, List[int]]]) -> int:
|
||||
"""Run the compiled code.
|
||||
@@ -77,17 +75,12 @@ class CompilerEngine:
|
||||
|
||||
Raises:
|
||||
TypeError: if execution arguments can't be constructed
|
||||
RuntimeError: if the engine has not compiled any code yet
|
||||
|
||||
Returns:
|
||||
int: result of execution.
|
||||
"""
|
||||
if self._lambda is None:
|
||||
raise RuntimeError("need to compile an MLIR code first")
|
||||
execution_arguments = [create_execution_argument(arg) for arg in args]
|
||||
return self._engine.run(execution_arguments)
|
||||
|
||||
def get_compiled_module(self) -> str:
|
||||
"""Compiled module in printable form.
|
||||
|
||||
Returns:
|
||||
str: Compiled module in printable form.
|
||||
"""
|
||||
return self._engine.get_compiled_module()
|
||||
return self._lambda.invoke(execution_arguments)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""HLFHE dialect module"""
|
||||
from ._HLFHE_ops_gen import *
|
||||
from _zamalang._hlfhe import *
|
||||
from mlir._mlir_libs._zamalang._hlfhe import *
|
||||
|
||||
@@ -1,62 +1,83 @@
|
||||
#include "zamalang-c/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/CompilerEngine.h"
|
||||
#include "zamalang/Support/ExecutionArgument.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
#include "zamalang/Support/logging.h"
|
||||
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
// using mlir::zamalang::CompilerEngine;
|
||||
using mlir::zamalang::ExecutionArgument;
|
||||
using mlir::zamalang::JitCompilerEngine;
|
||||
|
||||
void compilerEngineCompile(compilerEngine engine, const char *module) {
|
||||
auto error = engine.ptr->compile(module);
|
||||
if (error) {
|
||||
llvm::errs() << "Compilation failed: " << error << "\n";
|
||||
llvm::consumeError(std::move(error));
|
||||
mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
|
||||
const char *funcName) {
|
||||
mlir::zamalang::JitCompilerEngine engine;
|
||||
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
|
||||
engine.buildLambda(module, funcName);
|
||||
if (!lambdaOrErr) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Compilation failed: "
|
||||
<< llvm::toString(std::move(lambdaOrErr.takeError())) << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed compiling, see previous logs for more info");
|
||||
}
|
||||
return std::move(*lambdaOrErr);
|
||||
}
|
||||
|
||||
uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) {
|
||||
auto args_size = args.size;
|
||||
auto maybeArgument = engine.ptr->buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error(
|
||||
"failed building arguments, see previous logs for more info");
|
||||
uint64_t invokeLambda(lambda l, executionArguments args) {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
|
||||
(mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;
|
||||
|
||||
if (args.size != lambda_ptr->getNumArguments()) {
|
||||
throw std::invalid_argument("wrong number of arguments");
|
||||
}
|
||||
// Set the integer/tensor arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args_size; i++) {
|
||||
std::vector<mlir::zamalang::LambdaArgument *> lambdaArgumentsRef;
|
||||
for (auto i = 0; i < args.size; i++) {
|
||||
if (args.data[i].isInt()) { // integer argument
|
||||
if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error("failed pushing integer argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
lambdaArgumentsRef.push_back(new mlir::zamalang::IntLambdaArgument<>(
|
||||
args.data[i].getIntegerArgument()));
|
||||
} else { // tensor argument
|
||||
assert(args.data[i].isTensor() && "should be tensor argument");
|
||||
if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(),
|
||||
args.data[i].getTensorSize())) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error("failed pushing tensor argument, see "
|
||||
"previous logs for more info");
|
||||
}
|
||||
llvm::MutableArrayRef<uint8_t> tensor(args.data[i].getTensorArgument(),
|
||||
args.data[i].getTensorSize());
|
||||
lambdaArgumentsRef.push_back(
|
||||
new mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>(tensor));
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = engine.ptr->invoke(*arguments)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error("failed running, see previous logs for more info");
|
||||
}
|
||||
uint64_t result = 0;
|
||||
if (auto err = arguments->getResult(0, result)) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
// Run lambda
|
||||
llvm::Expected<uint64_t> resOrError = (*lambda_ptr)(
|
||||
llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(lambdaArgumentsRef));
|
||||
// Free heap
|
||||
for (size_t i = 0; i < lambdaArgumentsRef.size(); i++)
|
||||
delete lambdaArgumentsRef[i];
|
||||
|
||||
if (!resOrError) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "Lambda invokation failed: "
|
||||
<< llvm::toString(std::move(resOrError.takeError())) << "\n";
|
||||
throw std::runtime_error(
|
||||
"failed getting result, see previous logs for more info");
|
||||
"failed invoking lambda, see previous logs for more info");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return *resOrError;
|
||||
}
|
||||
|
||||
std::string roundTrip(const char *module) {
|
||||
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
|
||||
mlir::zamalang::CompilationContext::createShared();
|
||||
mlir::zamalang::JitCompilerEngine ce{ccx};
|
||||
|
||||
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
|
||||
ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP);
|
||||
if (!retOrErr) {
|
||||
mlir::zamalang::log_error()
|
||||
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
|
||||
throw std::runtime_error(
|
||||
"mlir parsing failed, see previous logs for more info");
|
||||
}
|
||||
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
retOrErr->mlirModuleRef->get().print(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -23,6 +23,51 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace {
|
||||
|
||||
// Returns `true` if the given value is a scalar or tensor argument of
|
||||
// a function, for which a MANP of 1 can be assumed.
|
||||
static bool isEncryptedFunctionParameter(mlir::Value value) {
|
||||
if (!value.isa<mlir::BlockArgument>())
|
||||
return false;
|
||||
|
||||
mlir::Block *block = value.cast<mlir::BlockArgument>().getOwner();
|
||||
|
||||
if (!block || !block->getParentOp() ||
|
||||
!llvm::isa<mlir::FuncOp>(block->getParentOp())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return (value.getType().isa<mlir::zamalang::HLFHE::EncryptedIntegerType>() ||
|
||||
(value.getType().isa<mlir::TensorType>() &&
|
||||
value.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()));
|
||||
}
|
||||
|
||||
// Returns the bit width of `value` if `value` is an encrypted integer
|
||||
// or the bit width of the elements if `value` is a tensor of
|
||||
// encrypted integers.
|
||||
static unsigned int getEintPrecision(mlir::Value value) {
|
||||
if (auto ty = value.getType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType>()) {
|
||||
return ty.getWidth();
|
||||
} else if (auto tensorTy =
|
||||
value.getType().dyn_cast_or_null<mlir::TensorType>()) {
|
||||
if (auto ty = tensorTy.getElementType()
|
||||
.dyn_cast_or_null<
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType>())
|
||||
return ty.getWidth();
|
||||
}
|
||||
|
||||
assert(false &&
|
||||
"Value is neither an encrypted integer nor a tensor of encrypted "
|
||||
"integers");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The `MANPLatticeValue` represents the squared Minimal Arithmetic
|
||||
// Noise Padding for an operation using the squared 2-norm of an
|
||||
// equivalent dot operation. This can either be an actual value if the
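As an informal reading aid only (a paraphrase of the comment above, not a statement of the exact analysis): the lattice tracks the squared Minimal Arithmetic Noise Padding, with the base case for encrypted function parameters set in the next hunk.

```latex
% Base case: encrypted function parameters start with a squared MANP of 1.
\mathrm{MANP}^2(\text{encrypted function parameter}) = 1

% Propagation (paraphrase): an operation is summarized by an equivalent dot
% product with clear coefficients w_1, \dots, w_n, and the tracked value is
% that dot product's squared 2-norm.
\mathrm{MANP}^2 \approx \lVert w \rVert_2^2 = \sum_{i=1}^{n} w_i^2
```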
|
||||
@@ -41,13 +86,7 @@ struct MANPLatticeValue {
|
||||
//
|
||||
// TODO: Provide a mechanism to propagate Minimal Arithmetic Noise
|
||||
// Padding across function calls.
|
||||
if (value.isa<mlir::BlockArgument>() &&
|
||||
(value.getType().isa<mlir::zamalang::HLFHE::EncryptedIntegerType>() ||
|
||||
(value.getType().isa<mlir::TensorType>() &&
|
||||
value.getType()
|
||||
.cast<mlir::TensorType>()
|
||||
.getElementType()
|
||||
.isa<mlir::zamalang::HLFHE::EncryptedIntegerType>()))) {
|
||||
if (isEncryptedFunctionParameter(value)) {
|
||||
return MANPLatticeValue(llvm::APInt{1, 1, false});
|
||||
} else {
|
||||
// All other operations have an unknown Minimal Arithmetic Noise
|
||||
@@ -450,7 +489,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
bool isDummy = false;
|
||||
llvm::APInt norm2SqEquiv;
|
||||
|
||||
// HLFHE Operaors
|
||||
// HLFHE Operators
|
||||
if (auto dotOp = llvm::dyn_cast<mlir::zamalang::HLFHE::Dot>(op)) {
|
||||
norm2SqEquiv = getSqMANP(dotOp, operands);
|
||||
} else if (auto addEintIntOp =
|
||||
@@ -599,6 +638,29 @@ struct MaxMANPPass : public MaxMANPBase<MaxMANPPass> {
|
||||
|
||||
protected:
|
||||
void processOperation(mlir::Operation *op) {
|
||||
static const llvm::APInt one{1, 1, false};
|
||||
bool upd = false;
|
||||
|
||||
// Process all function arguments and use the default value of 1
// for MANP and the declared precision
|
||||
if (mlir::FuncOp func = llvm::dyn_cast_or_null<mlir::FuncOp>(op)) {
|
||||
for (mlir::BlockArgument blockArg : func.getBody().getArguments()) {
|
||||
if (isEncryptedFunctionParameter(blockArg)) {
|
||||
unsigned int width = getEintPrecision(blockArg);
|
||||
|
||||
if (this->maxEintWidth < width) {
|
||||
this->maxEintWidth = width;
|
||||
}
|
||||
|
||||
if (APIntWidthExtendULT(this->maxMANP, one)) {
|
||||
this->maxMANP = one;
|
||||
upd = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process all results using the MANP attribute from the MANP pass
|
||||
for (mlir::OpResult res : op->getResults()) {
|
||||
mlir::zamalang::HLFHE::EncryptedIntegerType eTy =
|
||||
res.getType()
|
||||
@@ -613,7 +675,6 @@ protected:
|
||||
}
|
||||
|
||||
if (eTy) {
|
||||
bool upd = false;
|
||||
if (this->maxEintWidth < eTy.getWidth()) {
|
||||
this->maxEintWidth = eTy.getWidth();
|
||||
upd = true;
|
||||
@@ -630,11 +691,11 @@ protected:
|
||||
this->maxMANP = MANP.getValue();
|
||||
upd = true;
|
||||
}
|
||||
|
||||
if (upd)
|
||||
this->updateMax(this->maxMANP, this->maxEintWidth);
|
||||
}
|
||||
}
|
||||
|
||||
if (upd)
|
||||
this->updateMax(this->maxMANP, this->maxEintWidth);
|
||||
}
|
||||
|
||||
std::function<void(const llvm::APInt &, unsigned)> updateMax;
|
||||
|
||||
@@ -1,7 +1,10 @@
add_mlir_library(ZamalangSupport
  Error.cpp
  Pipeline.cpp
  Jit.cpp
  CompilerEngine.cpp
  JitCompilerEngine.cpp
  LambdaArgument.cpp
  V0Parameters.cpp
  V0Curves.cpp
  ClientParameters.cpp
@@ -28,8 +28,13 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
width = type.getIntOrFloatBitWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
.encryption = llvm::None,
|
||||
.shape = {.width = width, .size = 0},
|
||||
/*.encryption = */ llvm::None,
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
@@ -41,7 +46,12 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
.variance = variance,
|
||||
.encoding = {.precision = precision},
|
||||
}),
|
||||
.shape = {.width = precision, .size = 0},
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ precision,
|
||||
/*.dimensions = */ std::vector<int64_t>(),
|
||||
/*.size = */ 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
@@ -70,34 +80,33 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
v0Curve->getVariance(1, 1 << v0Param.polynomialSize, 64);
|
||||
Variance keyswitchVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
||||
// Static client parameters from global parameters for v0
|
||||
ClientParameters c{
|
||||
.secretKeys{
|
||||
{"small", {.size = v0Param.nSmall}},
|
||||
{"big", {.size = v0Param.getNBigGlweSize()}},
|
||||
},
|
||||
.bootstrapKeys{
|
||||
ClientParameters c = {};
|
||||
c.secretKeys = {
|
||||
{"small", {/*.size = */ v0Param.nSmall}},
|
||||
{"big", {/*.size = */ v0Param.getNBigGlweSize()}},
|
||||
};
|
||||
c.bootstrapKeys = {
|
||||
{
|
||||
"bsk_v0",
|
||||
{
|
||||
"bsk_v0",
|
||||
{
|
||||
.inputSecretKeyID = "small",
|
||||
.outputSecretKeyID = "big",
|
||||
.level = v0Param.brLevel,
|
||||
.baseLog = v0Param.brLogBase,
|
||||
.k = v0Param.k,
|
||||
.variance = encryptionVariance,
|
||||
},
|
||||
/*.inputSecretKeyID = */ "small",
|
||||
/*.outputSecretKeyID = */ "big",
|
||||
/*.level = */ v0Param.brLevel,
|
||||
/*.baseLog = */ v0Param.brLogBase,
|
||||
/*.k = */ v0Param.k,
|
||||
/*.variance = */ encryptionVariance,
|
||||
},
|
||||
},
|
||||
.keyswitchKeys{
|
||||
};
|
||||
c.keyswitchKeys = {
|
||||
{
|
||||
"ksk_v0",
|
||||
{
|
||||
"ksk_v0",
|
||||
{
|
||||
.inputSecretKeyID = "big",
|
||||
.outputSecretKeyID = "small",
|
||||
.level = v0Param.ksLevel,
|
||||
.baseLog = v0Param.ksLogBase,
|
||||
.variance = keyswitchVariance,
|
||||
},
|
||||
/*.inputSecretKeyID = */ "big",
|
||||
/*.outputSecretKeyID = */ "small",
|
||||
/*.level = */ v0Param.ksLevel,
|
||||
/*.baseLog = */ v0Param.ksLogBase,
|
||||
/*.variance = */ keyswitchVariance,
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -134,4 +143,4 @@ createClientParametersForV0(V0FHEContext fheContext, llvm::StringRef name,
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
} // namespace mlir
|
||||
|
||||
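A pattern worth noting in the hunks above: designated initializers (`.field = value`) are replaced by positional aggregate initialization with the field names kept as comments. The diff does not state the motivation, but a plausible one is that designated initializers for aggregates are only standard C++ from C++20 onward. A minimal sketch of the pattern, using a hypothetical struct:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical aggregate, loosely modelled on the CircuitGate shape above.
struct GateShape {
  unsigned width;
  std::vector<int64_t> dimensions;
  size_t size;
};

// Pre-C++20 friendly: positional initialization, field names as comments.
GateShape shape{
    /*.width = */ 8,
    /*.dimensions = */ std::vector<int64_t>(),
    /*.size = */ 0,
};
```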
@@ -1,3 +1,5 @@
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SMLoc.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/Dialect/MemRef/IR/MemRef.h>
|
||||
@@ -10,156 +12,285 @@
|
||||
#include <zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h>
|
||||
#include <zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h>
|
||||
#include <zamalang/Support/CompilerEngine.h>
|
||||
#include <zamalang/Support/Error.h>
|
||||
#include <zamalang/Support/Jit.h>
|
||||
#include <zamalang/Support/Pipeline.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
void CompilerEngine::loadDialects() {
|
||||
context->getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
|
||||
context->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
context->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
context->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
context->getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
context->getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
context->getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
context->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
// Creates a new compilation context that can be shared across
|
||||
// compilation engines and results
|
||||
std::shared_ptr<CompilationContext> CompilationContext::createShared() {
|
||||
return std::make_shared<CompilationContext>();
|
||||
}
|
||||
|
||||
std::string CompilerEngine::getCompiledModule() {
|
||||
std::string compiledModule;
|
||||
llvm::raw_string_ostream os(compiledModule);
|
||||
module_ref->print(os);
|
||||
return os.str();
|
||||
CompilationContext::CompilationContext()
|
||||
: mlirContext(nullptr), llvmContext(nullptr) {}
|
||||
|
||||
CompilationContext::~CompilationContext() {
|
||||
delete this->mlirContext;
|
||||
delete this->llvmContext;
|
||||
}
|
||||
|
||||
llvm::Error CompilerEngine::compile(
|
||||
std::string mlirStr,
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints) {
|
||||
module_ref = mlir::parseSourceString(mlirStr, context);
|
||||
if (!module_ref) {
|
||||
return llvm::make_error<llvm::StringError>("mlir parsing failed",
|
||||
llvm::inconvertibleErrorCode());
|
||||
// Returns the MLIR context for a compilation context. Creates and
|
||||
// initializes a new MLIR context if necessary.
|
||||
mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
if (this->mlirContext == nullptr) {
|
||||
this->mlirContext = new mlir::MLIRContext();
|
||||
|
||||
this->mlirContext->getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
this->mlirContext
|
||||
->getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
this->mlirContext
|
||||
->getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
|
||||
this->mlirContext
|
||||
->getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::StandardOpsDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::memref::MemRefDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::linalg::LinalgDialect>();
|
||||
this->mlirContext->getOrLoadDialect<mlir::LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
mlir::ModuleOp module = module_ref.get();
|
||||
return this->mlirContext;
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraintsOpt =
|
||||
overrideConstraints;
|
||||
// Returns the LLVM context for a compilation context. Creates and
|
||||
// initializes a new LLVM context if necessary.
|
||||
llvm::LLVMContext *CompilationContext::getLLVMContext() {
|
||||
if (this->llvmContext == nullptr)
|
||||
this->llvmContext = new llvm::LLVMContext();
|
||||
|
||||
if (!fheConstraintsOpt.hasValue()) {
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
fheConstraintsOrErr =
|
||||
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context,
|
||||
module);
|
||||
return this->llvmContext;
|
||||
}
|
||||
|
||||
if (auto err = fheConstraintsOrErr.takeError())
|
||||
return std::move(err);
|
||||
// Sets the FHE constraints for the compilation. Overrides any
|
||||
// automatically detected configuration and prevents the autodetection
|
||||
// pass from running.
|
||||
void CompilerEngine::setFHEConstraints(
|
||||
const mlir::zamalang::V0FHEConstraint &c) {
|
||||
this->overrideMaxEintPrecision = c.p;
|
||||
this->overrideMaxMANP = c.norm2;
|
||||
}
|
||||
|
||||
if (!fheConstraintsOrErr.get().hasValue()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"Could not determine maximum required precision for encrypted "
|
||||
"integers "
|
||||
"and maximum value for the Minimal Arithmetic Noise Padding",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
void CompilerEngine::setVerifyDiagnostics(bool v) {
|
||||
this->verifyDiagnostics = v;
|
||||
}
|
||||
|
||||
fheConstraintsOpt = fheConstraintsOrErr.get();
|
||||
void CompilerEngine::setGenerateClientParameters(bool v) {
|
||||
this->generateClientParameters = v;
|
||||
}
|
||||
|
||||
void CompilerEngine::setMaxEintPrecision(size_t v) {
|
||||
this->overrideMaxEintPrecision = v;
|
||||
}
|
||||
|
||||
void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; }
|
||||
|
||||
void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) {
|
||||
this->clientParametersFuncName = name.str();
|
||||
}
|
||||
|
||||
void CompilerEngine::setEnablePass(
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
this->enablePass = enablePass;
|
||||
}
|
||||
|
||||
// Returns the overridden V0FHEConstraint or tries to compute it from HLFHE
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
CompilerEngine::getV0FHEConstraint(CompilationResult &res) {
|
||||
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
mlir::ModuleOp module = res.mlirModuleRef->get();
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
|
||||
// If the values have been overridden, return them
|
||||
if (this->overrideMaxEintPrecision.hasValue() &&
|
||||
this->overrideMaxMANP.hasValue()) {
|
||||
return mlir::zamalang::V0FHEConstraint{
|
||||
this->overrideMaxMANP.getValue(),
|
||||
this->overrideMaxEintPrecision.getValue()};
|
||||
}
|
||||
// Else compute constraint from HLFHE
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
fheConstraintsOrErr =
|
||||
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(
|
||||
mlirContext, module, enablePass);
|
||||
|
||||
mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue();
|
||||
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
|
||||
|
||||
if (!parameter) {
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream strs(buffer);
|
||||
strs << "Could not determine V0 parameters for 2-norm of "
|
||||
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
|
||||
|
||||
return llvm::make_error<llvm::StringError>(strs.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter};
|
||||
|
||||
// Lower to MLIR Std
|
||||
if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext,
|
||||
false)
|
||||
.failed()) {
|
||||
return llvm::make_error<llvm::StringError>("failed to lower to MLIR Std",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
fheContext, "main", module_ref.get());
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
if (auto err = fheConstraintsOrErr.takeError())
|
||||
return std::move(err);
|
||||
}
|
||||
auto maybeKeySet =
|
||||
mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0);
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
keySet = std::move(maybeKeySet.get());
|
||||
|
||||
// Lower to MLIR LLVM Dialect
|
||||
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false)
|
||||
.failed()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"failed to lower to LLVM dialect", llvm::inconvertibleErrorCode());
|
||||
return fheConstraintsOrErr.get();
|
||||
}
|
||||
|
||||
// set the fheContext field if the v0Constraint can be computed
|
||||
llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) {
|
||||
auto fheConstraintOrErr = getV0FHEConstraint(res);
|
||||
if (auto err = fheConstraintOrErr.takeError())
|
||||
return std::move(err);
|
||||
if (!fheConstraintOrErr.get().hasValue()) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
const mlir::zamalang::V0Parameter *fheParams =
|
||||
getV0Parameter(fheConstraintOrErr.get().getValue());
|
||||
|
||||
if (!fheParams) {
|
||||
return StreamStringError()
|
||||
<< "Could not determine V0 parameters for 2-norm of "
|
||||
<< (*fheConstraintOrErr)->norm2 << " and p of "
|
||||
<< (*fheConstraintOrErr)->p;
|
||||
}
|
||||
res.fheContext.emplace(mlir::zamalang::V0FHEContext{
|
||||
(*fheConstraintOrErr).getValue(), *fheParams});
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda::Argument>>
|
||||
CompilerEngine::buildArgument() {
|
||||
if (keySet.get() == nullptr) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"CompilerEngine::buildArgument: invalid engine state, the keySet has "
|
||||
"not been generated",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
return JITLambda::Argument::create(*keySet);
|
||||
}
|
||||
// Compile the sources managed by the source manager `sm` to the
|
||||
// target dialect `target`. If successful, the result can be retrieved
|
||||
// using `getModule()` and `getLLVMModule()`, respectively depending
|
||||
// on the target dialect.
|
||||
llvm::Expected<CompilerEngine::CompilationResult>
|
||||
CompilerEngine::compile(llvm::SourceMgr &sm, Target target) {
|
||||
CompilationResult res(this->compilationContext);
|
||||
|
||||
llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) {
|
||||
// Create the JIT lambda
|
||||
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
auto module = module_ref.get();
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
|
||||
if (auto err = maybeLambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = maybeLambda.get()->invoke(arg)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
|
||||
llvm::Expected<uint64_t> CompilerEngine::run(std::vector<uint64_t> args) {
|
||||
// Build the argument of the JIT lambda.
|
||||
auto maybeArgument = buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
return std::move(err);
|
||||
mlir::SourceMgrDiagnosticVerifierHandler smHandler(sm, &mlirContext);
|
||||
mlirContext.printOpOnDiagnostic(false);
|
||||
|
||||
mlir::OwningModuleRef mlirModuleRef =
|
||||
mlir::parseSourceFile<mlir::ModuleOp>(sm, &mlirContext);
|
||||
|
||||
if (this->verifyDiagnostics) {
|
||||
if (smHandler.verify().failed())
|
||||
return StreamStringError("Verification of diagnostics failed");
|
||||
else
|
||||
return res;
|
||||
}
|
||||
// Set the integer arguments
|
||||
auto arguments = std::move(maybeArgument.get());
|
||||
for (auto i = 0; i < args.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, args[i])) {
|
||||
return std::move(err);
|
||||
|
||||
if (!mlirModuleRef)
|
||||
return StreamStringError("Could not parse source");
|
||||
|
||||
res.mlirModuleRef = std::move(mlirModuleRef);
|
||||
mlir::ModuleOp module = res.mlirModuleRef->get();
|
||||
|
||||
if (target == Target::ROUND_TRIP)
|
||||
return res;
|
||||
|
||||
// HLFHE High level pass to determine FHE parameters
|
||||
if (auto err = this->determineFHEParameters(res))
|
||||
return std::move(err);
|
||||
if (target == Target::HLFHE)
|
||||
return res;
|
||||
|
||||
// HLFHE -> MidLFHE
|
||||
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Lowering from HLFHE to MidLFHE failed");
|
||||
}
|
||||
if (target == Target::MIDLFHE)
|
||||
return res;
|
||||
|
||||
// MidLFHE -> LowLFHE
|
||||
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
|
||||
mlirContext, module, res.fheContext, this->enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Lowering from MidLFHE to LowLFHE failed");
|
||||
}
|
||||
if (target == Target::LOWLFHE)
|
||||
return res;
|
||||
|
||||
// LowLFHE -> Canonical dialects
|
||||
if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from LowLFHE to canonical MLIR dialects failed");
|
||||
}
|
||||
if (target == Target::STD)
|
||||
return res;
|
||||
|
||||
// Generate client parameters if requested
|
||||
if (this->generateClientParameters) {
|
||||
if (!this->clientParametersFuncName.hasValue()) {
|
||||
return StreamStringError(
|
||||
"Generation of client parameters requested, but no function name "
|
||||
"specified");
|
||||
}
|
||||
if (!res.fheContext.hasValue()) {
|
||||
return StreamStringError(
|
||||
"Cannot generate client parameters, the fhe context is empty");
|
||||
}
|
||||
|
||||
llvm::Expected<mlir::zamalang::ClientParameters> clientParametersOrErr =
|
||||
mlir::zamalang::createClientParametersForV0(
|
||||
*res.fheContext, *this->clientParametersFuncName, module);
|
||||
|
||||
if (llvm::Error err = clientParametersOrErr.takeError())
|
||||
return std::move(err);
|
||||
|
||||
res.clientParameters = clientParametersOrErr.get();
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = invoke(*arguments)) {
|
||||
return std::move(err);
|
||||
|
||||
// MLIR canonical dialects -> LLVM Dialect
|
||||
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module,
|
||||
enablePass)
|
||||
.failed()) {
|
||||
return StreamStringError("Failed to lower to LLVM dialect");
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
return std::move(err);
|
||||
|
||||
if (target == Target::LLVM)
|
||||
return res;
|
||||
|
||||
// Lowering to actual LLVM IR (i.e., not the LLVM dialect)
|
||||
llvm::LLVMContext &llvmContext = *this->compilationContext->getLLVMContext();
|
||||
|
||||
res.llvmModule = mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(
|
||||
mlirContext, llvmContext, module);
|
||||
|
||||
if (!res.llvmModule)
|
||||
return StreamStringError("Failed to convert from LLVM dialect to LLVM IR");
|
||||
|
||||
if (target == Target::LLVM_IR)
|
||||
return res;
|
||||
|
||||
if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *res.llvmModule)
|
||||
.failed()) {
|
||||
return StreamStringError("Failed to optimize LLVM IR");
|
||||
}
|
||||
|
||||
if (target == Target::OPTIMIZED_LLVM_IR)
|
||||
return res;
|
||||
|
||||
return res;
|
||||
} // namespace zamalang
|
||||
|
||||
// Compile the source `s` to the target dialect `target`. If successful, the
// result can be retrieved using `getModule()` and `getLLVMModule()`,
// respectively depending on the target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(llvm::StringRef s, Target target) {
  std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
  llvm::Expected<CompilationResult> res = this->compile(std::move(mb), target);

  return std::move(res);
}

// Compile the sources contained in `buffer` to the target dialect
// `target`. If successful, the result can be retrieved using
// `getModule()` and `getLLVMModule()`, respectively depending on the
// target dialect.
llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
                        Target target) {
  llvm::SourceMgr sm;

  sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());

  llvm::Expected<CompilationResult> res = this->compile(sm, target);

  return std::move(res);
}

} // namespace zamalang
} // namespace mlir
12
compiler/lib/Support/Error.cpp
Normal file
@@ -0,0 +1,12 @@
#include <zamalang/Support/Error.h>

namespace mlir {
namespace zamalang {
// Specialized `operator<<` for `llvm::Error` that marks the error
// as checked through `std::move` and `llvm::toString`
StreamStringError &operator<<(StreamStringError &se, llvm::Error &err) {
  se << llvm::toString(std::move(err));
  return se;
}
} // namespace zamalang
} // namespace mlir
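For context, `StreamStringError` is used throughout the new code to build error values by streaming, e.g. `return StreamStringError() << "..."`. The function below is a hypothetical example of that idiom, assuming the header above provides the streaming operators and the conversion to `llvm::Error` relied upon elsewhere in this diff:

```cpp
#include <llvm/Support/Error.h>
#include <zamalang/Support/Error.h>

// Hypothetical helper, not part of the patch.
static llvm::Error checkPrecision(unsigned width, unsigned maxWidth) {
  if (width > maxWidth) {
    // Build the message by streaming, then let it convert to llvm::Error.
    return mlir::zamalang::StreamStringError()
           << "Precision of " << width << " bits exceeds the maximum of "
           << maxWidth;
  }
  return llvm::Error::success();
}
```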
@@ -1,3 +1,4 @@
|
||||
#include "llvm/Support/Error.h"
|
||||
#include <llvm/ADT/ArrayRef.h>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include <llvm/ADT/StringRef.h>
|
||||
@@ -12,56 +13,6 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// JIT-compiles `module` invokes `func` with the arguments passed in
|
||||
// `jitArguments` and `keySet`
|
||||
mlir::LogicalResult
|
||||
runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create(func, module, optPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
::mlir::zamalang::log_error()
|
||||
<< "Cannot create lambda arguments: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (size_t i = 0; i < funcArgs.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, funcArgs[i])) {
|
||||
::mlir::zamalang::log_error()
|
||||
<< "Cannot push argument " << i << ": " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = lambda->invoke(*arguments)) {
|
||||
::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
llvm::errs() << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
|
||||
@@ -379,6 +330,20 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Result at pos %zu is not a tensor", pos);
|
||||
}
|
||||
|
||||
return info.shape.size;
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
|
||||
size_t size) {
|
||||
|
||||
|
||||
118
compiler/lib/Support/JitCompilerEngine.cpp
Normal file
@@ -0,0 +1,118 @@
#include "llvm/Support/Error.h"
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <zamalang/Support/JitCompilerEngine.h>

namespace mlir {
namespace zamalang {

JitCompilerEngine::JitCompilerEngine(
    std::shared_ptr<CompilationContext> compilationContext,
    unsigned int optimizationLevel)
    : CompilerEngine(compilationContext), optimizationLevel(optimizationLevel) {
}

// Returns the `LLVMFuncOp` operation in the compiled module with the
// specified name. If no LLVMFuncOp with that name exists or if there
// was no prior call to `compile()` resulting in an MLIR module in the
// LLVM dialect, an error is returned.
llvm::Expected<mlir::LLVM::LLVMFuncOp>
JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
  auto funcOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
  auto funcOp = llvm::find_if(
      funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; });

  if (funcOp == funcOps.end()) {
    return StreamStringError()
           << "Module does not contain function named '" << name.str() << "'";
  }

  return *funcOp;
}

// Build a lambda from the function with the name given in
// `funcName` from the sources in `buffer`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
                               llvm::StringRef funcName) {
  llvm::SourceMgr sm;

  sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());

  llvm::Expected<JitCompilerEngine::Lambda> res =
      this->buildLambda(sm, funcName);

  return std::move(res);
}

// Build a lambda from the function with the name given in `funcName`
// from the source string `s`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) {
  std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
  llvm::Expected<JitCompilerEngine::Lambda> res =
      this->buildLambda(std::move(mb), funcName);

  return std::move(res);
}

// Build a lambda from the function with the name given in
// `funcName` from the sources managed by the source manager `sm`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
  MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();

  this->setGenerateClientParameters(true);
  this->setClientParametersFuncName(funcName);

  // First, compile to LLVM Dialect
  llvm::Expected<CompilerEngine::CompilationResult> compResOrErr =
      this->compile(sm, Target::LLVM_IR);

  if (!compResOrErr)
    return std::move(compResOrErr.takeError());

  mlir::ModuleOp module = compResOrErr->mlirModuleRef->get();

  // Locate function to JIT-compile
  llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
      this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName);

  if (!funcOrError)
    return std::move(funcOrError.takeError());

  // Prepare LLVM infrastructure for JIT compilation
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::registerLLVMDialectTranslation(mlirContext);

  llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline =
      mlir::makeOptimizingTransformer(3, 0, nullptr);

  llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
      mlir::zamalang::JITLambda::create(funcName, module, optPipeline);

  // Generate the KeySet for encrypting lambda arguments, decrypting lambda
  // results
  if (!compResOrErr->clientParameters.hasValue()) {
    return StreamStringError("Cannot generate the keySet since client "
                             "parameters has not been computed");
  }

  llvm::Expected<std::unique_ptr<mlir::zamalang::KeySet>> keySetOrErr =
      mlir::zamalang::KeySet::generate(*compResOrErr->clientParameters, 0, 0);

  if (auto err = keySetOrErr.takeError())
    return std::move(err);

  if (!lambdaOrErr)
    return std::move(lambdaOrErr.takeError());

  return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
                std::move(*keySetOrErr)};
}

} // namespace zamalang
} // namespace mlir
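For orientation, a hedged sketch of how this engine is driven end to end; it mirrors the flow of the updated zamalangc driver further below and assumes the `Lambda` returned by `buildLambda` is callable with an `llvm::ArrayRef<uint64_t>`, as shown there. The module source itself is a placeholder:

```cpp
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/Error.h>
#include <zamalang/Support/JitCompilerEngine.h>

// Sketch only: `source` is assumed to hold a valid textual HLFHE module
// that defines a function named "main".
static llvm::Expected<uint64_t> runMain(llvm::StringRef source,
                                        llvm::ArrayRef<uint64_t> args) {
  auto ccx = mlir::zamalang::CompilationContext::createShared();
  mlir::zamalang::JitCompilerEngine ce{ccx};

  // JIT-compile "main" and generate the matching key set.
  llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
      ce.buildLambda(source, "main");
  if (!lambdaOrErr)
    return lambdaOrErr.takeError();

  // Encrypt the arguments, run the compiled circuit, decrypt the result.
  return (*lambdaOrErr)(args);
}
```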
7
compiler/lib/Support/LambdaArgument.cpp
Normal file
@@ -0,0 +1,7 @@
#include <zamalang/Support/LambdaArgument.h>

namespace mlir {
namespace zamalang {
char LambdaArgument::ID = 0;
} // namespace zamalang
} // namespace mlir
@@ -19,35 +19,52 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace pipeline {
|
||||
static void addPotentiallyNestedPass(mlir::PassManager &pm,
|
||||
std::unique_ptr<Pass> pass) {
|
||||
if (!pass->getOpName() || *pass->getOpName() == "builtin.module") {
|
||||
pm.addPass(std::move(pass));
|
||||
} else {
|
||||
pm.nest(*pass->getOpName()).addPass(std::move(pass));
|
||||
|
||||
static void pipelinePrinting(llvm::StringRef name, mlir::PassManager &pm,
|
||||
mlir::MLIRContext &ctx) {
|
||||
if (mlir::zamalang::isVerbose()) {
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "##################################################\n"
|
||||
<< "### " << name << " pipeline\n";
|
||||
auto isModule = [](mlir::Pass *, mlir::Operation *op) {
|
||||
return mlir::isa<mlir::ModuleOp>(op);
|
||||
};
|
||||
ctx.disableMultithreading(true);
|
||||
pm.enableIRPrinting(isModule, isModule);
|
||||
pm.enableStatistics();
|
||||
pm.enableTiming();
|
||||
pm.enableVerifier();
|
||||
}
|
||||
}
|
||||
|
||||
// Creates an instance of the Minimal Arithmetic Noise Padding pass
|
||||
// and invokes it for all functions of `module`.
|
||||
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool debug) {
|
||||
mlir::PassManager pm(&context);
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::zamalang::createMANPPass(debug));
|
||||
return pm.run(module);
|
||||
static void
|
||||
addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr<Pass> pass,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
if (!enablePass(pass.get())) {
|
||||
return;
|
||||
}
|
||||
if (!pass->getOpName() || *pass->getOpName() == "builtin.module") {
|
||||
pm.addPass(std::move(pass));
|
||||
} else {
|
||||
mlir::OpPassManager &p = pm.nest(*pass->getOpName());
|
||||
p.addPass(std::move(pass));
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) {
|
||||
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
llvm::Optional<size_t> oMax2norm;
|
||||
llvm::Optional<size_t> oMaxWidth;
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass());
|
||||
pipelinePrinting("ComputeFHEConstraintOnHLFHE", pm, context);
|
||||
addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP,
|
||||
unsigned currMaxWidth) {
|
||||
pm,
|
||||
mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP,
|
||||
unsigned currMaxWidth) {
|
||||
assert((uint64_t)currMaxWidth < std::numeric_limits<size_t>::max() &&
|
||||
"Maximum width does not fit into size_t");
|
||||
|
||||
@@ -63,105 +80,95 @@ getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) {
|
||||
|
||||
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
|
||||
oMaxWidth.emplace(width);
|
||||
}));
|
||||
|
||||
}),
|
||||
enablePass);
|
||||
if (pm.run(module.getOperation()).failed()) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"Failed to determine the maximum Arithmetic Noise Padding and maximum"
|
||||
"required precision",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> ret;
|
||||
|
||||
if (oMax2norm.hasValue() && oMaxWidth.hasValue()) {
|
||||
ret = llvm::Optional<mlir::zamalang::V0FHEConstraint>(
|
||||
{.norm2 = ceilLog2(oMax2norm.getValue()), .p = oMaxWidth.getValue()});
|
||||
{/*.norm2 = */ ceilLog2(oMax2norm.getValue()),
|
||||
/*.p = */ oMaxWidth.getValue()});
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module, bool verbose) {
|
||||
mlir::LogicalResult
|
||||
lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("HLFHEToMidLFHE", pm, context);
|
||||
|
||||
if (verbose) {
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "##################################################\n"
|
||||
<< "### HLFHE to MidLFHE pipeline\n";
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
|
||||
|
||||
pm.enableIRPrinting();
|
||||
pm.enableStatistics();
|
||||
pm.enableTiming();
|
||||
pm.enableVerifier();
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
llvm::Optional<V0FHEContext> &fheContext,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("MidLFHEToLowLFHE", pm, context);
|
||||
|
||||
if (fheContext.hasValue()) {
|
||||
addPotentiallyNestedPass(
|
||||
pm,
|
||||
mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(
|
||||
fheContext.getValue()),
|
||||
enablePass);
|
||||
}
|
||||
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
|
||||
addPotentiallyNestedPass(pm,
|
||||
mlir::zamalang::createConvertHLFHEToMidLFHEPass());
|
||||
pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), enablePass);
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext,
|
||||
bool parametrize) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
if (parametrize) {
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(
|
||||
fheContext));
|
||||
}
|
||||
|
||||
addPotentiallyNestedPass(pm,
|
||||
mlir::zamalang::createConvertMidLFHEToLowLFHEPass());
|
||||
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module) {
|
||||
mlir::LogicalResult
|
||||
lowerLowLFHEToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("LowLFHEToStd", pm, context);
|
||||
pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass());
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
bool verbose) {
|
||||
mlir::LogicalResult
|
||||
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass) {
|
||||
mlir::PassManager pm(&context);
|
||||
|
||||
if (verbose) {
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "##################################################\n"
|
||||
<< "### MlirStdsDialectToMlirLLVMDialect pipeline\n";
|
||||
context.disableMultithreading();
|
||||
pm.enableIRPrinting();
|
||||
pm.enableStatistics();
|
||||
pm.enableTiming();
|
||||
pm.enableVerifier();
|
||||
}
|
||||
pipelinePrinting("StdToLLVM", pm, context);
|
||||
|
||||
// Unparametrize LowLFHE
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass());
|
||||
pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(), enablePass);
|
||||
|
||||
// Bufferize
|
||||
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass());
|
||||
addPotentiallyNestedPass(pm, mlir::createStdBufferizePass());
|
||||
addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass());
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass());
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass());
|
||||
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass());
|
||||
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass());
|
||||
addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createStdBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass(),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass(), enablePass);
|
||||
addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass(),
|
||||
enablePass);
|
||||
|
||||
// Convert to MLIR LLVM Dialect
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass());
|
||||
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
|
||||
enablePass);
|
||||
|
||||
return pm.run(module);
|
||||
}
|
||||
@@ -179,7 +186,7 @@ lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context,
|
||||
|
||||
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
|
||||
llvm::Module &module) {
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline =
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
if (optPipeline(&module))
|
||||
@@ -188,18 +195,6 @@ mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
V0FHEContext &fheContext, bool verbose) {
|
||||
if (lowerHLFHEToMidLFHE(context, module, verbose).failed() ||
|
||||
lowerMidLFHEToLowLFHE(context, module, fheContext, true).failed() ||
|
||||
lowerLowLFHEToStd(context, module).failed()) {
|
||||
return mlir::failure();
|
||||
} else {
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace pipeline
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -18,5 +18,6 @@ StreamWrap<llvm::raw_ostream> &log_verbose(void) {
// Sets up logging. If `verbose` is false, messages passed to
// `log_verbose` will be discarded.
void setupLogging(bool verbose) { ::mlir::zamalang::verbose = verbose; }
bool isVerbose() { return verbose; }
} // namespace zamalang
} // namespace mlir
55
compiler/setup.py
Normal file
@@ -0,0 +1,55 @@
import os
import subprocess
import setuptools

from setuptools import Extension
from setuptools.command.build_ext import build_ext


def read(fname):
    return open(os.path.join(os.path.dirname(__file__), fname)).read()


class MakeExtension(Extension):
    def __init__(self, name, sourcedir=""):
        Extension.__init__(self, name, sources=[])
        self.sourcedir = os.path.abspath(sourcedir)


class MakeBuild(build_ext):
    def run(self):
        for ext in self.extensions:
            self.build_extension(ext)

    def build_extension(self, ext):
        subprocess.check_call(["make", "python-bindings"])


setuptools.setup(
    name="concretefhe-compiler",
    version="0.1.0",
    author="Zama Team",
    author_email="hello@zama.ai",
    description="Concrete Compiler",
    license="",
    keywords="homomorphic encryption compiler",
    long_description=read("README.md"),
    long_description_content_type="text/markdown",
    url="https://github.com/zama-ai/homomorphizer",
    packages=setuptools.find_packages(
        where="build/tools/zamalang/python_packages/zamalang_core",
        include=["zamalang", "zamalang.*", "mlir", "mlir.*"],
    ),
    package_dir={"": "build/tools/zamalang/python_packages/zamalang_core"},
    include_package_data=True,
    package_data={"": ["*.so"]},
    classifiers=[
        "Programming Language :: C++",
        "Programming Language :: Python :: 3",
        "Topic :: Software Development :: Compilers",
        "Topic :: Security :: Cryptography",
    ],
    ext_modules=[MakeExtension("python-bindings")],
    cmdclass=dict(build_ext=MakeBuild),
    zip_safe=False,
)
@@ -1,3 +1,4 @@
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
#include <llvm/Support/CommandLine.h>
|
||||
@@ -18,21 +19,19 @@
|
||||
#include "zamalang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
#include "zamalang/Support/Jit.h"
|
||||
#include "zamalang/Support/Error.h"
|
||||
#include "zamalang/Support/JitCompilerEngine.h"
|
||||
#include "zamalang/Support/KeySet.h"
|
||||
#include "zamalang/Support/Pipeline.h"
|
||||
#include "zamalang/Support/logging.h"
|
||||
|
||||
enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM };
|
||||
|
||||
enum Action {
|
||||
ROUND_TRIP,
|
||||
DUMP_HLFHE_MANP,
|
||||
DUMP_HLFHE,
|
||||
DUMP_MIDLFHE,
|
||||
DUMP_LOWLFHE,
|
||||
DUMP_STD,
|
||||
@@ -76,30 +75,10 @@ llvm::cl::opt<std::string> output("o",
|
||||
llvm::cl::opt<bool> verbose("verbose", llvm::cl::desc("verbose logs"),
|
||||
llvm::cl::init<bool>(false));
|
||||
|
||||
llvm::cl::opt<bool> parametrizeMidLFHE(
|
||||
"parametrize-midlfhe",
|
||||
llvm::cl::desc("Perform MidLFHE global parametrization pass"),
|
||||
llvm::cl::init<bool>(true));
|
||||
|
||||
static llvm::cl::opt<enum EntryDialect> entryDialect(
|
||||
"e", "entry-dialect", llvm::cl::desc("Entry dialect"),
|
||||
llvm::cl::init<enum EntryDialect>(EntryDialect::HLFHE),
|
||||
llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required,
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::HLFHE, "hlfhe",
|
||||
"Input module is composed of HLFHE operations")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::MIDLFHE, "midlfhe",
|
||||
"Input module is composed of MidLFHE operations")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::LOWLFHE, "lowlfhe",
|
||||
"Input module is composed of LowLFHE operations")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::STD, "std",
|
||||
"Input module is composed of operations from std")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(EntryDialect::LLVM, "llvm",
|
||||
"Input module is composed of operations from llvm")));
|
||||
llvm::cl::list<std::string> passes(
|
||||
"passes",
|
||||
llvm::cl::desc("Specify the passes to run (use only for compiler tests)"),
|
||||
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
|
||||
|
||||
static llvm::cl::opt<enum Action> action(
|
||||
"a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired,
|
||||
@@ -107,9 +86,8 @@ static llvm::cl::opt<enum Action> action(
|
||||
llvm::cl::values(
|
||||
clEnumValN(Action::ROUND_TRIP, "roundtrip",
|
||||
"Parse input module and regenerate textual representation")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp",
|
||||
"Dump HLFHE module after running the Minimal "
|
||||
"Arithmetic Noise Padding pass")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe",
|
||||
"Dump HLFHE module")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe",
|
||||
"Lower to MidLFHE and dump result")),
|
||||
llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe",
|
||||
@@ -159,50 +137,7 @@ llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
|
||||
llvm::cl::desc(
|
||||
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
|
||||
|
||||
}; // namespace cmdline
|
||||
|
||||
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
std::unique_ptr<mlir::zamalang::KeySet>
|
||||
generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
|
||||
const std::string &jitFuncName) {
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet;
|
||||
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2
|
||||
<< ", p:" << fheContext.constraint.p << "}\n";
|
||||
mlir::zamalang::log_verbose()
|
||||
<< "### FHE parameters for the atomic pattern: {k: "
|
||||
<< fheContext.parameter.k
|
||||
<< ", polynomialSize: " << fheContext.parameter.polynomialSize
|
||||
<< ", nSmall: " << fheContext.parameter.nSmall
|
||||
<< ", brLevel: " << fheContext.parameter.brLevel
|
||||
<< ", brLogBase: " << fheContext.parameter.brLogBase
|
||||
<< ", ksLevel: " << fheContext.parameter.ksLevel
|
||||
<< ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n";
|
||||
|
||||
// Create the client parameters
|
||||
auto clientParameter = mlir::zamalang::createClientParametersForV0(
|
||||
fheContext, jitFuncName, module);
|
||||
|
||||
if (auto err = clientParameter.takeError()) {
|
||||
mlir::zamalang::log_error()
|
||||
<< "cannot generate client parameters: " << err << "\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mlir::zamalang::log_verbose() << "### Generate the key set\n";
|
||||
|
||||
auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0,
|
||||
0); // TODO: seed
|
||||
if (auto err = maybeKeySet.takeError()) {
|
||||
llvm::errs() << err;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::move(maybeKeySet.get());
|
||||
}
|
||||
} // namespace cmdline
|
||||
|
||||
llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
|
||||
@@ -210,65 +145,48 @@ llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
|
||||
llvm::Optional<size_t> overrideMaxMANP) {
|
||||
if (!autoFHEConstraints.hasValue() &&
|
||||
(!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
return mlir::zamalang::StreamStringError(
|
||||
"Maximum encrypted integer precision and maximum for the Minimal"
|
||||
"Arithmetic Noise Passing are required, but were neither specified"
|
||||
"explicitly nor determined automatically",
|
||||
llvm::inconvertibleErrorCode());
|
||||
"explicitly nor determined automatically");
|
||||
}
|
||||
|
||||
mlir::zamalang::V0FHEConstraint fheConstraints{
|
||||
.norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
|
||||
: autoFHEConstraints.getValue().norm2,
|
||||
.p = overrideMaxEintPrecision.hasValue()
|
||||
? overrideMaxEintPrecision.getValue()
|
||||
: autoFHEConstraints.getValue().p};
|
||||
overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
|
||||
: autoFHEConstraints.getValue().norm2,
|
||||
overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue()
|
||||
: autoFHEConstraints.getValue().p};
|
||||
|
||||
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
|
||||
|
||||
if (!parameter) {
|
||||
std::string buffer;
|
||||
llvm::raw_string_ostream strs(buffer);
|
||||
strs << "Could not determine V0 parameters for 2-norm of "
|
||||
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
|
||||
|
||||
return llvm::make_error<llvm::StringError>(strs.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
return mlir::zamalang::StreamStringError()
|
||||
<< "Could not determine V0 parameters for 2-norm of "
|
||||
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
|
||||
}
|
||||
|
||||
return mlir::zamalang::V0FHEContext{fheConstraints, *parameter};
|
||||
}
|
||||
|
||||
mlir::LogicalResult buildAssignFHEContext(
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> &fheContext,
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP) {
|
||||
namespace llvm {
|
||||
// This needs to be wrapped into the llvm namespace for proper
|
||||
// operator lookup
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
|
||||
const llvm::ArrayRef<uint64_t> arr) {
|
||||
os << "(";
|
||||
for (size_t i = 0; i < arr.size(); i++) {
|
||||
os << arr[i];
|
||||
|
||||
if (fheContext.hasValue())
|
||||
return mlir::success();
|
||||
|
||||
llvm::Expected<mlir::zamalang::V0FHEContext> fheContextOrErr =
|
||||
buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision,
|
||||
overrideMaxMANP);
|
||||
|
||||
if (auto err = fheContextOrErr.takeError()) {
|
||||
mlir::zamalang::log_error() << err;
|
||||
return mlir::failure();
|
||||
if (i != arr.size() - 1)
|
||||
os << ", ";
|
||||
}
|
||||
|
||||
fheContext.emplace(fheContextOrErr.get());
|
||||
|
||||
return mlir::success();
|
||||
return os;
|
||||
}
|
||||
} // namespace llvm
|
||||
|
||||
// Process a single source buffer
|
||||
//
|
||||
// The parameter `entryDialect` must specify the FHE dialect to which
|
||||
// belong all FHE operations used in the source buffer. The input
|
||||
// program must only contain FHE operations from that single FHE
|
||||
// dialect, otherwise processing might fail.
|
||||
//
|
||||
// The parameter `action` specifies how the buffer should be processed
|
||||
// and thus defines the output.
|
||||
//
|
||||
@@ -277,15 +195,14 @@ mlir::LogicalResult buildAssignFHEContext(
|
||||
// using the parameters given in `jitArgs`.
|
||||
//
|
||||
// The parameter `parametrizeMidLFHE` defines, whether the
|
||||
// parametrization pass for MidLFHE is executed. If the pair of
|
||||
// `entryDialect` and `action` does not involve any MidlFHE
|
||||
// manipulation, this parameter does not have any effect.
|
||||
// parametrization pass for MidLFHE is executed. If the `action` does
|
||||
// not involve any MidLFHE manipulation, this parameter does not have
|
||||
// any effect.
|
||||
//
|
||||
// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
|
||||
// set, override the values for the maximum required precision of
|
||||
// encrypted integers and the maximum value for the Minimum Arithmetic
|
||||
// Noise Padding otherwise determined automatically if the entry
|
||||
// dialect is HLFHE..
|
||||
// Noise Padding otherwise determined automatically.
|
||||
//
|
||||
// If `verifyDiagnostics` is `true`, the procedure only checks if the
|
||||
// diagnostic messages provided in the source buffer using
|
||||
@@ -293,164 +210,106 @@ mlir::LogicalResult buildAssignFHEContext(
|
||||
// the procedure checks if the parsed module is valid and if all
|
||||
// requested transformations succeeded.
|
||||
//
|
||||
// If `verbose` is true, debug messages are displayed throughout the
|
||||
// compilation process.
|
||||
//
|
||||
// Compilation output is written to the stream specified by `os`.
|
||||
mlir::LogicalResult processInputBuffer(
|
||||
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
enum EntryDialect entryDialect, enum Action action,
|
||||
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
|
||||
bool parametrizeMidlHFE, llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
|
||||
bool verbose, llvm::raw_ostream &os) {
|
||||
llvm::SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
mlir::LogicalResult
|
||||
processInputBuffer(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
enum Action action, const std::string &jitFuncName,
|
||||
llvm::ArrayRef<uint64_t> jitArgs,
|
||||
llvm::Optional<size_t> overrideMaxEintPrecision,
|
||||
llvm::Optional<size_t> overrideMaxMANP,
|
||||
bool verifyDiagnostics, llvm::raw_ostream &os) {
|
||||
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
|
||||
mlir::zamalang::CompilationContext::createShared();
|
||||
|
||||
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
|
||||
&context);
|
||||
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
|
||||
mlir::zamalang::JitCompilerEngine ce{ccx};
|
||||
|
||||
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
|
||||
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
|
||||
|
||||
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
|
||||
|
||||
if (verbose)
|
||||
context.disableMultithreading();
|
||||
|
||||
if (verifyDiagnostics)
|
||||
return sourceMgrHandler.verify();
|
||||
|
||||
if (!moduleRef)
|
||||
return mlir::failure();
|
||||
|
||||
mlir::ModuleOp module = moduleRef.get();
|
||||
|
||||
if (action == Action::ROUND_TRIP) {
|
||||
module->print(os);
|
||||
return mlir::success();
|
||||
ce.setVerifyDiagnostics(verifyDiagnostics);
|
||||
if (cmdline::passes.size() != 0) {
|
||||
ce.setEnablePass([](mlir::Pass *pass) {
|
||||
return std::any_of(
|
||||
cmdline::passes.begin(), cmdline::passes.end(),
|
||||
[&](const std::string &p) { return pass->getArgument() == p; });
|
||||
});
|
||||
}
|
||||
|
// Lowering pipeline. Each stage is represented as a label in the
// switch statement, from the most abstract dialect to the lowest
// level. Every label acts as an entry point into the pipeline with
// a fallthrough mechanism to the next stage. Actions act as exit
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
if (action == Action::DUMP_HLFHE_MANP) {
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
.failed()) {
return mlir::failure();
}
if (overrideMaxEintPrecision.hasValue())
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());

module.print(os);
return mlir::success();
} else {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context,
module);
if (auto err = fheConstraintsOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
} else {
fheConstraints = fheConstraintsOrErr.get();
}
}
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());

if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
.failed())
return mlir::failure();
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), jitFuncName);

// fallthrough
case EntryDialect::MIDLFHE:
if (action == Action::DUMP_MIDLFHE) {
module.print(os);
return mlir::success();
}

if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}

if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
context, module, fheContext.getValue(), parametrizeMidlHFE)
.failed())
return mlir::failure();

// fallthrough
case EntryDialect::LOWLFHE:
if (action == Action::DUMP_LOWLFHE) {
module.print(os);
return mlir::success();
}

if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed())
return mlir::failure();

// fallthrough
case EntryDialect::STD:
if (action == Action::DUMP_STD) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}

keySet = generateKeySet(module, fheContext.getValue(), jitFuncName);
}

if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
verbose)
.failed())
return mlir::failure();

// fallthrough
case EntryDialect::LLVM: {
if (action == Action::DUMP_LLVM_DIALECT) {
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet,
defaultOptPipeline, os);
}

llvm::LLVMContext llvmContext;
std::unique_ptr<llvm::Module> llvmModule =
mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext,
module);

if (!llvmModule) {
if (!lambdaOrErr) {
mlir::zamalang::log_error()
<< "Failed to translate LLVM dialect to LLVM IR\n";
<< "Failed to JIT-compile " << jitFuncName << ": "
<< llvm::toString(std::move(lambdaOrErr.takeError()));
return mlir::failure();
}

if (action == Action::DUMP_LLVM_IR) {
llvmModule->dump();
return mlir::success();
}
llvm::Expected<uint64_t> resOrErr = (*lambdaOrErr)(jitArgs);

if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule)
.failed()) {
mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n";
if (!resOrErr) {
mlir::zamalang::log_error()
<< "Failed to JIT-invoke " << jitFuncName << " with arguments "
<< jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError()));
return mlir::failure();
}

if (action == Action::DUMP_OPTIMIZED_LLVM_IR) {
llvmModule->dump();
return mlir::success();
os << *resOrErr << "\n";
} else {
enum mlir::zamalang::CompilerEngine::Target target;

switch (action) {
case Action::ROUND_TRIP:
target = mlir::zamalang::CompilerEngine::Target::ROUND_TRIP;
break;
case Action::DUMP_HLFHE:
target = mlir::zamalang::CompilerEngine::Target::HLFHE;
break;
case Action::DUMP_MIDLFHE:
target = mlir::zamalang::CompilerEngine::Target::MIDLFHE;
break;
case Action::DUMP_LOWLFHE:
target = mlir::zamalang::CompilerEngine::Target::LOWLFHE;
break;
case Action::DUMP_STD:
target = mlir::zamalang::CompilerEngine::Target::STD;
break;
case Action::DUMP_LLVM_DIALECT:
target = mlir::zamalang::CompilerEngine::Target::LLVM;
break;
case Action::DUMP_LLVM_IR:
target = mlir::zamalang::CompilerEngine::Target::LLVM_IR;
break;
case Action::DUMP_OPTIMIZED_LLVM_IR:
target = mlir::zamalang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
break;
case JIT_INVOKE:
// Case just here to satisfy the compiler; already handled above
break;
}

break;
}
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
ce.compile(std::move(buffer), target);

if (!retOrErr) {
mlir::zamalang::log_error()
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";

return mlir::failure();
}

if (verifyDiagnostics) {
return mlir::success();
} else if (action == Action::DUMP_LLVM_IR ||
action == Action::DUMP_OPTIMIZED_LLVM_IR) {
retOrErr->llvmModule->print(os, nullptr);
} else {
retOrErr->mlirModuleRef->get().print(os);
}
}

return mlir::success();
@@ -460,45 +319,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
// Parse command line arguments
llvm::cl::ParseCommandLineOptions(argc, argv);

// Initialize the MLIR context
mlir::MLIRContext context;

mlir::zamalang::setupLogging(cmdline::verbose);

// String for error messages from library functions
std::string errorMessage;

if (cmdline::action == Action::DUMP_HLFHE_MANP &&
cmdline::entryDialect != EntryDialect::HLFHE) {
mlir::zamalang::log_error()
<< "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs";
return mlir::failure();
}

if (cmdline::action == Action::JIT_INVOKE &&
cmdline::entryDialect != EntryDialect::HLFHE &&
cmdline::entryDialect != EntryDialect::MIDLFHE &&
cmdline::entryDialect != EntryDialect::LOWLFHE &&
cmdline::entryDialect != EntryDialect::STD) {
mlir::zamalang::log_error()
<< "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs";
return mlir::failure();
}

// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::zamalang::HLFHELinalg::HLFHELinalgDialect>();
context.getOrLoadDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
context.getOrLoadDialect<mlir::zamalang::LowLFHE::LowLFHEDialect>();
context.getOrLoadDialect<mlir::StandardOpsDialect>();
context.getOrLoadDialect<mlir::memref::MemRefDialect>();
context.getOrLoadDialect<mlir::linalg::LinalgDialect>();
context.getOrLoadDialect<mlir::tensor::TensorDialect>();
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();

if (cmdline::verifyDiagnostics)
context.printOpOnDiagnostic(false);

auto output = mlir::openOutputFile(cmdline::output, &errorMessage);

if (!output) {
@@ -525,20 +350,18 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
[&](std::unique_ptr<llvm::MemoryBuffer> inputBuffer,
llvm::raw_ostream &os) {
return processInputBuffer(
context, std::move(inputBuffer), cmdline::entryDialect,
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE,
std::move(inputBuffer), cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, os);
cmdline::verifyDiagnostics, os);
},
output->os())))
return mlir::failure();
} else {
return processInputBuffer(
context, std::move(file), cmdline::entryDialect, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
std::move(file), cmdline::action, cmdline::jitFuncName,
cmdline::jitArgs, cmdline::assumeMaxEintPrecision,
cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, output->os());
}
}

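For orientation, the net effect of the driver rewrite above is that `processInputBuffer` now hands the work to `JitCompilerEngine` instead of invoking each lowering pass by hand. Below is a minimal sketch of that flow, using only the calls that appear in the diff (`CompilationContext::createShared`, `buildLambda`, `compile`, the lambda call operator); the wrapper name `compileOrJit` and its exact signature are illustrative, not part of the commit.

```cpp
// Illustrative sketch only: assumes the zamalang headers and the Action /
// Target enums exactly as they appear in the diff above.
static mlir::LogicalResult
compileOrJit(std::unique_ptr<llvm::MemoryBuffer> buffer, enum Action action,
             const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
             llvm::raw_ostream &os) {
  auto ccx = mlir::zamalang::CompilationContext::createShared();
  mlir::zamalang::JitCompilerEngine ce{ccx};

  if (action == Action::JIT_INVOKE) {
    // JIT path: build a lambda for the requested entry point and invoke it.
    auto lambdaOrErr = ce.buildLambda(std::move(buffer), jitFuncName);
    if (!lambdaOrErr)
      return mlir::failure();
    auto resOrErr = (*lambdaOrErr)(jitArgs);
    if (!resOrErr)
      return mlir::failure();
    os << *resOrErr << "\n";
    return mlir::success();
  }

  // Dump path: compile down to the target matching the requested action
  // (LLVM shown here) and print the resulting module.
  auto retOrErr = ce.compile(std::move(buffer),
                             mlir::zamalang::CompilerEngine::Target::LLVM);
  if (!retOrErr)
    return mlir::failure();
  retOrErr->mlirModuleRef->get().print(os);
  return mlir::success();
}
```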
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
|
||||
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --passes hlfhe-to-midlfhe --action=dump-midlfhe 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}>
|
||||
func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api --action=dump-std %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module
|
||||
// CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list)
|
||||
|
||||
7
compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir
Normal file
@@ -0,0 +1,7 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> {
|
||||
// CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_>
|
||||
return %arg0: !LowLFHE.lwe_ciphertext<1024,4>
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
// RUN: zamacompiler --passes lowlfhe-unparametrize --action=dump-llvm-dialect %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_>
|
||||
func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<_,_> {
|
||||
// CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_>
|
||||
%0 = builtin.unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_>
|
||||
return %0: !LowLFHE.lwe_ciphertext<_,_>
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
|
||||
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4>
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
|
||||
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
|
||||
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --passes midlfhe-to-lowlfhe --action=dump-lowlfhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
|
||||
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
|
||||
// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
func @single_zero() -> !HLFHE.eint<2>
|
||||
{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s
|
||||
// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
func @tensor_from_elements_1(%a: !HLFHE.eint<2>, %b: !HLFHE.eint<2>, %c: !HLFHE.eint<2>, %d: !HLFHE.eint<2>) -> tensor<4x!HLFHE.eint<2>>
|
||||
{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// Incompatible shapes
|
||||
func @dot_incompatible_shapes(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: eint support only precision in ]0;7]
|
||||
func @test(%arg0: !HLFHE.eint<8>) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: eint support only precision in ]0;7]
|
||||
func @test(%arg0: !HLFHE.eint<0>) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals
|
||||
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals
|
||||
func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals
|
||||
func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
|
||||
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals
|
||||
func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1
|
||||
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals
|
||||
func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @zero() -> !HLFHE.eint<2>
|
||||
func @zero() -> !HLFHE.eint<2> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
|
||||
|
||||
//CHECK: #map0 = affine_map<(d0) -> (d0)>
|
||||
//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>
|
||||
func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.add_eint_int
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.add_eint_int
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
|
||||
func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen
|
||||
func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter result
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter
|
||||
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// Bad dimension of the lookup table
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}>
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter
|
||||
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s
|
||||
// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
// GLWE p parameter
|
||||
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s
|
||||
// RUN: zamacompiler %s --action=roundtrip 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}>
|
||||
func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> {
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_compile_and_run(mlir_input, args, expected_result):
|
||||
def test_compile_and_run_invalid_arg_number(mlir_input, args):
|
||||
engine = CompilerEngine()
|
||||
engine.compile_fhe(mlir_input)
|
||||
with pytest.raises(RuntimeError, match=r"failed pushing integer argument"):
|
||||
with pytest.raises(ValueError, match=r"wrong number of arguments"):
|
||||
engine.run(*args)
|
||||
|
||||
|
||||
|
||||
@@ -2,14 +2,23 @@ enable_testing()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
|
||||
add_executable(
|
||||
end_to_end_jit_test
|
||||
end_to_end_jit_clear_tensor.cc
|
||||
end_to_end_jit_encrypted_tensor.cc
|
||||
end_to_end_jit_hlfhelinalg.cc
|
||||
end_to_end_jit_test.cc
|
||||
)
|
||||
|
||||
set_source_files_properties(
|
||||
end_to_end_jit_test.cc
|
||||
end_to_end_jit_clear_tensor.cc
|
||||
end_to_end_jit_encrypted_tensor.cc
|
||||
end_to_end_jit_hlfhelinalg.cc
|
||||
|
||||
PROPERTIES COMPILE_FLAGS "-fno-rtti"
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
end_to_end_jit_test
|
||||
gtest_main
|
||||
|
||||
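The end-to-end JIT tests that follow replace the manual `CompilerEngine` / `buildArgument` / `invoke` sequence with the `checkedJit` helper. Here is a condensed sketch of the new pattern; the test name is made up, and `checkedJit`, `ARRAY_SIZE` and the `ASSERT_EXPECTED_*` macros are the test-suite helpers referenced in the diff, not defined here.

```cpp
// Sketch of the rewritten test shape: JIT-compile an MLIR snippet, call the
// resulting lambda, and compare against the plain C++ inputs.
TEST(End2EndJit_Sketch, identity) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
  return %t : tensor<10xi64>
}
)XXX",
                                                                "main", true);

  uint64_t arg[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

  // Tensor results come back as a flat std::vector<uint64_t>.
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>(arg, ARRAY_SIZE(arg));
  ASSERT_EXPECTED_SUCCESS(res);
  for (size_t i = 0; i < res->size(); i++)
    EXPECT_EQ(arg[i], (*res)[i]) << "result differs at index " << i;
}
```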
@@ -5,382 +5,324 @@
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, identity) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(
|
||||
R"XXX(
|
||||
func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
|
||||
return %t : tensor<10xi64>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint64_t arg[size]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, arg, size));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t result[size];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, result, size));
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
EXPECT_EQ(arg[i], result[i]) << "result differ at index " << i;
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint64_t arg[]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>(arg, ARRAY_SIZE(arg));
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (size_t)10);
|
||||
|
||||
for (size_t i = 0; i < res->size(); i++) {
|
||||
EXPECT_EQ(arg[i], res->operator[](i)) << "result differ at index " << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, extract_64) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi64>, %i: index) -> i64{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi64>
|
||||
return %c : i64
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint64_t arg[]{0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(arg); i++) {
|
||||
ASSERT_EXPECTED_VALUE(lambda(arg, ARRAY_SIZE(arg), i), arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, extract_32) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi32>, %i: index) -> i32{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi32>
|
||||
return %c : i32
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint32_t arg[]{0xFFFFFFFF, 0, 8978, 2587490, 90,
|
||||
197864, 698735, 72132, 87474, 42};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(arg); i++) {
|
||||
ASSERT_EXPECTED_VALUE(lambda(arg, ARRAY_SIZE(arg), i), arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, extract_16) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi16>, %i: index) -> i16{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi16>
|
||||
return %c : i16
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
|
||||
63269, 36435, 52380, 7401, 13313};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint16_t arg[]{0xFFFF, 0, 59589, 47826, 16227,
|
||||
63269, 36435, 52380, 7401, 13313};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(arg); i++) {
|
||||
ASSERT_EXPECTED_VALUE(lambda(arg, ARRAY_SIZE(arg), i), arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, extract_8) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi8>, %i: index) -> i8{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi8>
|
||||
return %c : i8
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint8_t arg[]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(arg); i++) {
|
||||
ASSERT_EXPECTED_VALUE(lambda(arg, ARRAY_SIZE(arg), i), arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, extract_5) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi5>, %i: index) -> i5{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi5>
|
||||
return %c : i5
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint8_t arg[]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(arg); i++) {
|
||||
ASSERT_EXPECTED_VALUE(lambda(arg, ARRAY_SIZE(arg), i), arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_1D, extract_1) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<10xi1>, %i: index) -> i1{
|
||||
%c = tensor.extract %t[%i] : tensor<10xi1>
|
||||
return %c : i1
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
const size_t size = 10;
|
||||
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
|
||||
// Set the %i argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, t_arg[i]);
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
uint8_t arg[]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
|
||||
|
||||
for (size_t i = 0; i < ARRAY_SIZE(arg); i++) {
|
||||
ASSERT_EXPECTED_VALUE(lambda(arg, ARRAY_SIZE(arg), i), arg[i]);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// 2D tensor //////////////////////////////////////////////////////////////////
|
||||
// 2D tensor
|
||||
//////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
const size_t numDim = 2;
|
||||
const int64_t dim0 = 2;
|
||||
const int64_t dim1 = 10;
|
||||
const int64_t dims[numDim]{dim0, dim1};
|
||||
const uint64_t tensor2D[dim0][dim1]{
|
||||
{0xFFFFFFFFFFFFFFFF, 0, 8978, 2587490, 90, 197864, 698735, 72132, 87474,
|
||||
42},
|
||||
{986, 1873, 298493, 34939, 443, 59874, 43, 743, 8409, 9433},
|
||||
static std::vector<uint64_t> tensor2D{
|
||||
0xFFFFFFFFFFFFFFFF,
|
||||
0,
|
||||
8978,
|
||||
2587490,
|
||||
90,
|
||||
197864,
|
||||
698735,
|
||||
72132,
|
||||
87474,
|
||||
42,
|
||||
986,
|
||||
1873,
|
||||
298493,
|
||||
34939,
|
||||
443,
|
||||
59874,
|
||||
43,
|
||||
743,
|
||||
8409,
|
||||
9433,
|
||||
};
|
||||
const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
|
||||
#define GET_2D(tensor, i, j) (tensor)[i * dims[1] + j]
|
||||
|
||||
#define TENSOR2D_GET(i, j) GET_2D(tensor2D, i, j)
|
||||
|
||||
TEST(End2EndJit_ClearTensor_2D, identity) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<2x10xi64>) -> tensor<2x10xi64> {
|
||||
return %t : tensor<2x10xi64>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t result[dims[0]][dims[1]];
|
||||
ASSERT_LLVM_ERROR(
|
||||
argument->getResult(0, (uint64_t *)result, dims[0] * dims[1]));
|
||||
for (size_t i = 0; i < dims[0]; i++) {
|
||||
for (size_t j = 0; j < dims[1]; j++) {
|
||||
EXPECT_EQ(tensor2D[i][j], result[i][j])
|
||||
<< "result differ at pos " << i << "," << j;
|
||||
}
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>
|
||||
arg(tensor2D, shape2D);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), tensor2D.size());
|
||||
|
||||
for (size_t i = 0; i < tensor2D.size(); i++) {
|
||||
EXPECT_EQ(tensor2D[i], (*res)[i]) << "result differ at pos " << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_2D, extract) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<2x10xi64>, %i: index, %j: index) -> i64 {
|
||||
%c = tensor.extract %t[%i, %j] : tensor<2x10xi64>
|
||||
return %c : i64
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>
|
||||
arg(tensor2D, shape2D);
|
||||
|
||||
for (size_t i = 0; i < dims[0]; i++) {
|
||||
for (size_t j = 0; j < dims[1]; j++) {
|
||||
// Set %i, %j
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
ASSERT_LLVM_ERROR(argument->setArg(2, j));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, tensor2D[i][j]);
|
||||
auto pos = i * dims[1] + j;
|
||||
mlir::zamalang::IntLambdaArgument<size_t> argi(i);
|
||||
mlir::zamalang::IntLambdaArgument<size_t> argj(j);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&arg, &argi, &argj}), TENSOR2D_GET(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_2D, extract_slice) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
|
||||
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10xi64> to tensor<1x5xi64>
|
||||
return %r : tensor<1x5xi64>
|
||||
}
|
||||
)XXX";
|
||||
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t result[1][5];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 1 * 5));
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
|
||||
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10xi64> to
|
||||
tensor<1x5xi64> return %r : tensor<1x5xi64>
|
||||
}
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>
|
||||
arg(tensor2D, shape2D);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), 1 * 5);
|
||||
|
||||
// Check the sub slice
|
||||
for (size_t i = 0; i < 1; i++) {
|
||||
for (size_t j = 0; j < 5; j++) {
|
||||
// Get and assert the result
|
||||
ASSERT_EQ(result[i][j], tensor2D[i + 1][j + 5]);
|
||||
}
|
||||
for (size_t j = 0; j < 5; j++) {
|
||||
// Get and assert the result
|
||||
ASSERT_EQ((*res)[j], TENSOR2D_GET(1, j + 5));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_2D, extract_slice_stride) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
|
||||
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10xi64> to tensor<1x5xi64>
|
||||
return %r : tensor<1x5xi64>
|
||||
}
|
||||
)XXX";
|
||||
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t result[1][5];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 1 * 5));
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<2x10xi64>) -> tensor<1x5xi64> {
|
||||
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10xi64> to
|
||||
tensor<1x5xi64> return %r : tensor<1x5xi64>
|
||||
}
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>
|
||||
arg(tensor2D, shape2D);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), 1 * 5);
|
||||
|
||||
// Check the sub slice
|
||||
for (size_t i = 0; i < 1; i++) {
|
||||
for (size_t j = 0; j < 5; j++) {
|
||||
// Get and assert the result
|
||||
ASSERT_EQ(result[i][j], tensor2D[i + 1][j * 2]);
|
||||
}
|
||||
for (size_t j = 0; j < 5; j++) {
|
||||
// Get and assert the result
|
||||
ASSERT_EQ((*res)[j], TENSOR2D_GET(1, j * 2));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_ClearTensor_2D, insert_slice) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
|
||||
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2xi64> into tensor<2x10xi64>
|
||||
return %r : tensor<2x10xi64>
|
||||
}
|
||||
)XXX";
|
||||
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t0 argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint64_t *)tensor2D, shape2D));
|
||||
// Set the %t1 argument
|
||||
int64_t t1_dim[2] = {2, 2};
|
||||
uint64_t t1[2][2]{{6, 9}, {4, 0}};
|
||||
ASSERT_LLVM_ERROR(
|
||||
argument->setArg(1, (uint64_t *)t1, llvm::ArrayRef<int64_t>(t1_dim, 2)));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t result[dim0][dim1];
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, dim0 * dim1));
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t0: tensor<2x10xi64>, %t1: tensor<2x2xi64>) -> tensor<2x10xi64> {
|
||||
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2xi64>
|
||||
into tensor<2x10xi64> return %r : tensor<2x10xi64>
|
||||
}
|
||||
)XXX",
|
||||
"main", true);
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>
|
||||
t0(tensor2D, shape2D);
|
||||
int64_t t1Shape[] = {2, 2};
|
||||
uint64_t t1Buffer[]{6, 9, 4, 0};
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>
|
||||
t1(t1Buffer, t1Shape);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&t0, &t1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), tensor2D.size());
|
||||
|
||||
// Check the sub slice
|
||||
for (size_t i = 0; i < dim0; i++) {
|
||||
for (size_t j = 0; j < dim1; j++) {
|
||||
if (j < 5 || j >= 5 + 2) {
|
||||
ASSERT_EQ(result[i][j], tensor2D[i][j])
|
||||
ASSERT_EQ(GET_2D(*res, i, j), TENSOR2D_GET(i, j))
|
||||
<< "at indexes (" << i << "," << j << ")";
|
||||
} else {
|
||||
// Get and assert the result
|
||||
ASSERT_EQ(result[i][j], t1[i][j - 5])
|
||||
ASSERT_EQ(GET_2D(*res, i, j), t1Buffer[i * 2 + j - 5])
|
||||
<< "at indexes (" << i << "," << j << ")";
|
||||
;
|
||||
}
|
||||
|
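For multi-dimensional inputs, the rewritten tests above wrap a flat buffer plus its shape in a `TensorLambdaArgument` and pass pointers to the lambda. A condensed sketch of just that calling convention (the helper name and data are illustrative):

```cpp
// Passing a 2x3 tensor to an already-built JIT lambda; only the constructor
// and call forms that appear in the tests above are used.
void callWith2DTensor(mlir::zamalang::JitCompilerEngine::Lambda &lambda) {
  static std::vector<uint64_t> data{1, 2, 3, 4, 5, 6};
  static const int64_t dims[]{2, 3};

  mlir::zamalang::TensorLambdaArgument<
      mlir::zamalang::IntLambdaArgument<uint64_t>>
      arg(data, llvm::ArrayRef<int64_t>(dims, 2));

  // Arguments are passed as a list of pointers; the result tensor comes back
  // as a flat vector in row-major order.
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg});
  ASSERT_EXPECTED_SUCCESS(res);
}
```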
||||
@@ -8,164 +8,155 @@ const size_t numDim = 2;
|
||||
const int64_t dim0 = 2;
|
||||
const int64_t dim1 = 10;
|
||||
const int64_t dims[numDim]{dim0, dim1};
|
||||
const uint8_t tensor2D[dim0][dim1]{
|
||||
{63, 12, 7, 43, 52, 9, 26, 34, 22, 0},
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
static std::vector<uint8_t> tensor2D{
|
||||
63, 12, 7, 43, 52, 9, 26, 34, 22, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
|
||||
};
|
||||
const llvm::ArrayRef<int64_t> shape2D(dims, numDim);
|
||||
#define GET_2D(tensor, i, j) (tensor)[i * dims[1] + j]
|
||||
|
||||
#define TENSOR2D_GET(i, j) GET_2D(tensor2D, i, j)
|
||||
|
||||
TEST(End2EndJit_EncryptedTensor_2D, identity) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
|
||||
func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<2x10x!HLFHE.eint<6>> {
|
||||
return %t : tensor<2x10x!HLFHE.eint<6>>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
|
||||
)XXX");
|
||||
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t result[dims[0]][dims[1]];
|
||||
ASSERT_LLVM_ERROR(
|
||||
argument->getResult(0, (uint64_t *)result, dims[0] * dims[1]));
|
||||
for (size_t i = 0; i < dims[0]; i++) {
|
||||
for (size_t j = 0; j < dims[1]; j++) {
|
||||
EXPECT_EQ(tensor2D[i][j], result[i][j])
|
||||
<< "result differ at pos " << i << "," << j;
|
||||
}
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
arg(tensor2D, shape2D);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), tensor2D.size());
|
||||
|
||||
for (size_t i = 0; i < tensor2D.size(); i++) {
|
||||
EXPECT_EQ(tensor2D[i], (*res)[i]) << "result differ at pos " << i;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_EncryptedTensor_2D, extract) {
|
||||
mlir::zamalang::CompilerEngine engine;
|
||||
auto mlirStr = R"XXX(
|
||||
func @main(%t: tensor<2x10x!HLFHE.eint<6>>, %i: index, %j: index) -> !HLFHE.eint<6> {
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
func @main(%t: tensor<2x10x!HLFHE.eint<6>>, %i: index, %j: index) ->
|
||||
!HLFHE.eint<6> {
|
||||
%c = tensor.extract %t[%i, %j] : tensor<2x10x!HLFHE.eint<6>>
|
||||
return %c : !HLFHE.eint<6>
|
||||
}
|
||||
)XXX";
|
||||
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
|
||||
auto maybeArgument = engine.buildArgument();
|
||||
ASSERT_LLVM_ERROR(maybeArgument.takeError());
|
||||
auto argument = std::move(maybeArgument.get());
|
||||
// Set the %t argument
|
||||
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
|
||||
)XXX");
|
||||
|
||||
mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>
|
||||
arg(tensor2D, shape2D);
|
||||
|
||||
for (size_t i = 0; i < dims[0]; i++) {
|
||||
for (size_t j = 0; j < dims[1]; j++) {
|
||||
// Set %i, %j
|
||||
ASSERT_LLVM_ERROR(argument->setArg(1, i));
|
||||
ASSERT_LLVM_ERROR(argument->setArg(2, j));
|
||||
// Invoke the function
|
||||
ASSERT_LLVM_ERROR(engine.invoke(*argument));
|
||||
// Get and assert the result
|
||||
uint64_t res = 0;
|
||||
ASSERT_LLVM_ERROR(argument->getResult(0, res));
|
||||
ASSERT_EQ(res, tensor2D[i][j]);
|
||||
auto pos = i * dims[1] + j;
|
||||
mlir::zamalang::IntLambdaArgument<size_t> argi(i);
|
||||
mlir::zamalang::IntLambdaArgument<size_t> argj(j);
|
||||
ASSERT_EXPECTED_VALUE(lambda({&arg, &argi, &argj}), TENSOR2D_GET(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_EncryptedTensor_2D, extract_slice) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] : tensor<2x10x!HLFHE.eint<6>> to tensor<1x5x!HLFHE.eint<6>>
return %r : tensor<1x5x!HLFHE.eint<6>>
%r = tensor.extract_slice %t[1, 5][1, 5][1, 1] :
tensor<2x10x!HLFHE.eint<6>> to tensor<1x5x!HLFHE.eint<6>> return %r :
tensor<1x5x!HLFHE.eint<6>>
}
)XXX";
)XXX");

mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
arg(tensor2D, shape2D);

llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg});

ASSERT_EXPECTED_SUCCESS(res);

ASSERT_EQ(res->size(), 1 * 5);

ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[1][5];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 1 * 5));
// Check the sub slice
for (size_t i = 0; i < 1; i++) {
for (size_t j = 0; j < 5; j++) {
// Get and assert the result
ASSERT_EQ(result[i][j], tensor2D[i + 1][j + 5]);
}
for (size_t j = 0; j < 5; j++) {
// Get and assert the result
ASSERT_EQ((*res)[j], TENSOR2D_GET(1, j + 5));
}
}
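The assertions in this slice test and the stride variant below follow directly from the offsets/sizes/strides triple of `tensor.extract_slice`: result element `(i, j)` is read from `t[offset0 + i * stride0][offset1 + j * stride1]`. A minimal sketch of that per-dimension mapping (illustrative helper, not part of the test file):

```cpp
#include <cstddef>

// Source index along one dimension of an extract_slice result.
static inline size_t sliceSrcIndex(size_t offset, size_t stride, size_t i) {
  return offset + i * stride;
}
// %t[1, 5][1, 5][1, 1] -> result[0][j] == t[1][5 + j]   (offsets 1,5; strides 1,1)
// %t[1, 0][1, 5][1, 2] -> result[0][j] == t[1][2 * j]   (offsets 1,0; strides 1,2)
```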
TEST(End2EndJit_EncryptedTensor_2D, extract_slice_stride) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] : tensor<2x10x!HLFHE.eint<6>> to tensor<1x5x!HLFHE.eint<6>>
return %r : tensor<1x5x!HLFHE.eint<6>>
}
)XXX";
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<2x10x!HLFHE.eint<6>>) -> tensor<1x5x!HLFHE.eint<6>> {
%r = tensor.extract_slice %t[1, 0][1, 5][1, 2] :
tensor<2x10x!HLFHE.eint<6>> to tensor<1x5x!HLFHE.eint<6>> return %r :
tensor<1x5x!HLFHE.eint<6>>
}
)XXX");

mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
arg(tensor2D, shape2D);

llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg});

ASSERT_EXPECTED_SUCCESS(res);

ASSERT_EQ(res->size(), 1 * 5);

ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[1][5];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, 1 * 5));
// Check the sub slice
for (size_t i = 0; i < 1; i++) {
for (size_t j = 0; j < 5; j++) {
// Get and assert the result
ASSERT_EQ(result[i][j], tensor2D[i + 1][j * 2]);
}
for (size_t j = 0; j < 5; j++) {
// Get and assert the result
ASSERT_EQ((*res)[j], TENSOR2D_GET(1, j * 2));
}
}
TEST(End2EndJit_EncryptedTensor_2D, insert_slice) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>) -> tensor<2x10x!HLFHE.eint<6>> {
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] : tensor<2x2x!HLFHE.eint<6>> into tensor<2x10x!HLFHE.eint<6>>
return %r : tensor<2x10x!HLFHE.eint<6>>
}
)XXX";
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t0: tensor<2x10x!HLFHE.eint<6>>, %t1: tensor<2x2x!HLFHE.eint<6>>)
-> tensor<2x10x!HLFHE.eint<6>> {
%r = tensor.insert_slice %t1 into %t0[0, 5][2, 2][1, 1] :
tensor<2x2x!HLFHE.eint<6>> into tensor<2x10x!HLFHE.eint<6>> return %r :
tensor<2x10x!HLFHE.eint<6>>
}
)XXX");

mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
t0(tensor2D, shape2D);
int64_t t1Shape[] = {2, 2};
uint8_t t1Buffer[]{6, 9, 4, 0};
mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>
t1(t1Buffer, t1Shape);

llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&t0, &t1});

ASSERT_EXPECTED_SUCCESS(res);

ASSERT_EQ(res->size(), tensor2D.size());

ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t0 argument
ASSERT_LLVM_ERROR(argument->setArg(0, (uint8_t *)tensor2D, shape2D));
// Set the %t1 argument
int64_t t1_dim[2] = {2, 2};
uint8_t t1[2][2]{{6, 9}, {4, 0}};
ASSERT_LLVM_ERROR(
argument->setArg(1, (uint8_t *)t1, llvm::ArrayRef<int64_t>(t1_dim, 2)));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t result[dim0][dim1];
ASSERT_LLVM_ERROR(argument->getResult(0, (uint64_t *)result, dim0 * dim1));
// Check the sub slice
for (size_t i = 0; i < dim0; i++) {
for (size_t j = 0; j < dim1; j++) {
if (j < 5 || j >= 5 + 2) {
ASSERT_EQ(result[i][j], tensor2D[i][j])
ASSERT_EQ(GET_2D(*res, i, j), TENSOR2D_GET(i, j))
<< "at indexes (" << i << "," << j << ")";
} else {
// Get and assert the result
ASSERT_EQ(result[i][j], t1[i][j - 5])
ASSERT_EQ(GET_2D(*res, i, j), t1Buffer[i * 2 + j - 5])
<< "at indexes (" << i << "," << j << ")";
;
}
}
}
}
}
File diff suppressed because it is too large
@@ -1,224 +1,308 @@
#include <cstdint>
#include <gtest/gtest.h>
#include <type_traits>

#include "end_to_end_jit_test.h"

mlir::zamalang::V0FHEConstraint defaultV0Constraints() {
return {.norm2 = 10, .p = 7};
}
mlir::zamalang::V0FHEConstraint defaultV0Constraints() { return {10, 7}; }
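The new one-line `defaultV0Constraints` drops the designated initializers; positional `{10, 7}` produces the same constraint object only because `norm2` and `p` come first, in that order, in `V0FHEConstraint`. A minimal illustration under that assumption (the struct here is a stand-in, not the real type):

```cpp
struct ConstraintSketch { int norm2; int p; };  // assumed member order
ConstraintSketch a{.norm2 = 10, .p = 7};        // designated initializers
ConstraintSketch b{10, 7};                      // positional aggregate init
// Both yield norm2 == 10 and p == 7.
```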
TEST(CompileAndRunHLFHE, add_eint) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
auto maybeResult = engine.run({1, 2});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 3);
)XXX");

ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), 3);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), 9);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), 2);
}
TEST(CompileAndRunHLFHE, add_eint_2) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<4> {
%cst = constant dense<[2, 1, 3, 10]> : tensor<4xi64>
%0 = "HLFHE.apply_lookup_table"(%arg0, %cst) : (!HLFHE.eint<2>, tensor<4xi64>) -> !HLFHE.eint<4>
return %0 : !HLFHE.eint<4>
// Same as CompileAndRunHLFHE::add_eint above, but using
// `LambdaArgument` instances
TEST(CompileAndRunHLFHE, add_eint_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
%1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>)
return %1: !HLFHE.eint<7>
}
)XXX");

mlir::zamalang::IntLambdaArgument<> ila1(1);
mlir::zamalang::IntLambdaArgument<> ila2(2);
mlir::zamalang::IntLambdaArgument<> ila7(7);
mlir::zamalang::IntLambdaArgument<> ila9(9);

ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila2}), 3);
ASSERT_EXPECTED_VALUE(lambda({&ila7, &ila9}), 16);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila7}), 8);
ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila9}), 10);
ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9);
}

TEST(CompileAndRunHLFHE, add_u64) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: i64, %arg1: i64) -> i64 {
%1 = addi %arg0, %arg1 : i64
return %1: i64
}
)XXX",
"main", true);

ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), (uint64_t)3);
ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), (uint64_t)9);
ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), (uint64_t)2);
}
TEST(CompileAndRunTensorStd, extract_64) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi64>, %i: index) -> i64{
%c = tensor.extract %t[%i] : tensor<10xi64>
return %c : i64
}
)XXX",
"main", "true");

static uint64_t t_arg[] = {0xFFFFFFFFFFFFFFFF,
0,
8978,
2587490,
90,
197864,
698735,
72132,
87474,
42};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}

TEST(CompileAndRunTensorStd, extract_32) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX",
"main", "true");
static uint32_t t_arg[] = {0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}

// Same as `CompileAndRunTensorStd::extract_32` above, but using
// `LambdaArgument` instances
TEST(CompileAndRunTensorStd, extract_32_lambda_argument) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi32>, %i: index) -> i32{
%c = tensor.extract %t[%i] : tensor<10xi32>
return %c : i32
}
)XXX",
"main", "true");
static std::vector<uint32_t> t_arg{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};

mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint32_t>>
tla(t_arg);

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) {
mlir::zamalang::IntLambdaArgument<size_t> idx(i);
ASSERT_EXPECTED_VALUE(lambda({&tla, &idx}), t_arg[i]);
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
auto maybeResult = engine.run({0});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 2);
}
TEST(CompileAndRunTensorStd, extract_16) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi16>, %i: index) -> i16{
%c = tensor.extract %t[%i] : tensor<10xi16>
return %c : i16
}
)XXX",
"main", "true");

uint16_t t_arg[] = {0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}

TEST(CompileAndRunTensorStd, extract_8) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi8>, %i: index) -> i8{
%c = tensor.extract %t[%i] : tensor<10xi8>
return %c : i8
}
)XXX",
"main", "true");

static uint8_t t_arg[] = {0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}

TEST(CompileAndRunTensorStd, extract_5) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi5>, %i: index) -> i5{
%c = tensor.extract %t[%i] : tensor<10xi5>
return %c : i5
}
)XXX",
"main", "true");

static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}

TEST(CompileAndRunTensorStd, extract_1) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10xi1>, %i: index) -> i1{
%c = tensor.extract %t[%i] : tensor<10xi1>
return %c : i1
}
)XXX",
"main", "true");

static uint8_t t_arg[] = {0, 0, 1, 0, 1, 1, 0, 1, 1, 0};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}
TEST(CompileAndRunTensorEncrypted, extract_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
%c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i]);
}
)XXX");

static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]);
}

TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) ->
!HLFHE.eint<5>{
%ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>>
%tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5>
return %c : !HLFHE.eint<5>
%c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
!HLFHE.eint<5> return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
for (size_t j = 0; j < size; j++) {
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Set the %i argument
ASSERT_LLVM_ERROR(argument->setArg(1, i));
// Set the %j argument
ASSERT_LLVM_ERROR(argument->setArg(2, j));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, t_arg[i] + t_arg[j]);
}
}
)XXX");

static uint8_t t_arg[] = {3, 0, 7, 12, 14, 6, 5, 4, 1, 2};

for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++)
for (size_t j = 0; j < ARRAY_SIZE(t_arg); j++)
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i, j),
t_arg[i] + t_arg[j]);
}
TEST(CompileAndRunTensorEncrypted, dim_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
%c0 = constant 0 : index
%c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>>
return %c : index
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res = 0;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, size);
)XXX");

static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg)), ARRAY_SIZE(t_arg));
}

TEST(CompileAndRunTensorEncrypted, from_elements_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
%t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>>
return %t: tensor<1x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the %t argument
ASSERT_LLVM_ERROR(argument->setArg(0, 10));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
const size_t size_res = 1;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], 10);
)XXX");

llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>(10_u64);

ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)1);
ASSERT_EQ(res->at(0), 10_u64);
}
TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
%c_0 = constant 0 : index
%c_1 = constant 1 : index
%a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>>
%b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>)
%out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
%aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) ->
(!HLFHE.eint<5>) %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>,
!HLFHE.eint<5>) -> (!HLFHE.eint<5>) %bplusb = "HLFHE.add_eint"(%b, %b):
(!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %out =
tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>>
return %out: tensor<3x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set the argument
const size_t in_size = 2;
uint8_t in[in_size] = {2, 16};
ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
const size_t size_res = 3;
uint64_t t_res[size_res];
ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res));
ASSERT_EQ(t_res[0], in[0] + in[0]);
ASSERT_EQ(t_res[1], in[0] + in[1]);
ASSERT_EQ(t_res[2], in[1] + in[1]);
)XXX");

static uint8_t in[] = {2, 16};

llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>(in, ARRAY_SIZE(in));

ASSERT_EXPECTED_SUCCESS(res);

ASSERT_EQ(res->size(), (size_t)3);
ASSERT_EQ(res->at(0), (uint64_t)(in[0] + in[0]));
ASSERT_EQ(res->at(1), (uint64_t)(in[0] + in[1]));
ASSERT_EQ(res->at(2), (uint64_t)(in[1] + in[1]));
}
TEST(CompileAndRunTensorEncrypted, linalg_generic) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
#map0 = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0) -> (0)>
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc:
!HLFHE.eint<7>) -> !HLFHE.eint<7> {
%tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) {
^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7>
%5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7>
linalg.yield %5 : !HLFHE.eint<7>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types
= ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>)
outs(%tacc : tensor<1x!HLFHE.eint<7>>) { ^bb0(%arg2: !HLFHE.eint<7>, %arg3:
i8, %arg4: !HLFHE.eint<7>): // no predecessors
%4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) ->
!HLFHE.eint<7> %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>,
!HLFHE.eint<7>) -> !HLFHE.eint<7> linalg.yield %5 : !HLFHE.eint<7>
} -> tensor<1x!HLFHE.eint<7>>
%c0 = constant 0 : index
%ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints()));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 2;
uint8_t arg0[in_size] = {2, 8};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {6, 8};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
ASSERT_LLVM_ERROR(argument->setArg(2, 0));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 76);
)XXX",
"main", "true");

static uint8_t arg0[] = {2, 8};
static uint8_t arg1[] = {6, 8};

llvm::Expected<uint64_t> res =
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1), 0_u64);

ASSERT_EXPECTED_VALUE(res, 76);
}
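The expected result of 76 is just the multiply-accumulate the `linalg.generic` body performs over the two-element inputs, starting from `%acc = 0`; a quick check of that arithmetic:

```cpp
// arg0 = {2, 8} (encrypted), arg1 = {6, 8} (clear), acc = 0:
// 2 * 6 + 8 * 8 + 0 == 12 + 64 == 76
static_assert(2 * 6 + 8 * 8 + 0 == 76, "dot-product accumulation checked by the test");
```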
TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
%arg1: tensor<4xi8>) -> !HLFHE.eint<7>
{
@@ -226,68 +310,62 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>,
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
// Set arg0, arg1, acc
const size_t in_size = 4;
uint8_t arg0[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size));
uint8_t arg1[in_size] = {0, 1, 2, 3};
ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size));
// Invoke the function
ASSERT_LLVM_ERROR(engine.invoke(*argument));
// Get and assert the result
uint64_t res;
ASSERT_LLVM_ERROR(argument->getResult(0, res));
ASSERT_EQ(res, 14);
)XXX");
static uint8_t arg0[] = {0, 1, 2, 3};
static uint8_t arg1[] = {0, 1, 2, 3};

llvm::Expected<uint64_t> res =
lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1));

ASSERT_EXPECTED_VALUE(res, 14);
}
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {
protected:
mlir::zamalang::CompilerEngine engine;
void compile(std::string mlirStr) { ASSERT_FALSE(engine.compile(mlirStr)); }
void run(std::vector<uint64_t> args, uint64_t expected) {
auto maybeResult = engine.run(args);
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
if (result == expected) {
ASSERT_TRUE(true);
} else {
// TODO: Better way to test the probability of exactness
llvm::errs() << "one fail retry\n";
maybeResult = engine.run(args);
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, expected);
}
}
};
class CompileAndRunWithPrecision : public ::testing::TestWithParam<int> {};
TEST_P(CompileAndRunWithPrecision, identity_func) {
int precision = GetParam();
uint64_t precision = GetParam();
std::ostringstream mlirProgram;
auto sizeOfTLU = 1 << precision;
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
<< ">) -> !HLFHE.eint<" << precision << "> { \n";
mlirProgram << " %tlu = std.constant dense<[0";
for (auto i = 1; i < sizeOfTLU; i++) {
mlirProgram << ", " << i;
}
mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n";
mlirProgram << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
"(!HLFHE.eint<"
<< precision << ">, tensor<" << sizeOfTLU
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n ";
mlirProgram << "return %1: !HLFHE.eint<" << precision << ">\n";
uint64_t sizeOfTLU = 1 << precision;

mlirProgram << "}\n";
llvm::errs() << mlirProgram.str();
compile(mlirProgram.str());
for (auto i = 0; i < sizeOfTLU; i++) {
run({(uint64_t)i}, i);
mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision
<< ">) -> !HLFHE.eint<" << precision << "> { \n"
<< " %tlu = std.constant dense<[0";

for (uint64_t i = 1; i < sizeOfTLU; i++)
mlirProgram << ", " << i;

mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n"
<< " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): "
<< "(!HLFHE.eint<" << precision << ">, tensor<" << sizeOfTLU
<< "xi64>) -> (!HLFHE.eint<" << precision << ">)\n "
<< "return %1: !HLFHE.eint<" << precision << ">\n"
<< "}\n";

mlir::zamalang::JitCompilerEngine::Lambda lambda =
checkedJit(mlirProgram.str());

if (precision == 7) {
// Test fails with a probability of 5% for a precision of 7. The
// probability of the test failing 5 times in a row is .05^5,
// which is less than 1:10,000 and comparable to the probability
// of failure for the other values.
static const int max_tries = 3;

for (uint64_t i = 0; i < sizeOfTLU; i++) {
for (int retry = 0; retry <= max_tries; retry++) {
if (retry == max_tries)
GTEST_FATAL_FAILURE_("Maximum number of tries exceeded");

llvm::Expected<uint64_t> val = lambda(i);
ASSERT_EXPECTED_SUCCESS(val);

if (*val == i)
break;
}
}
} else {
for (uint64_t i = 0; i < sizeOfTLU; i++)
ASSERT_EXPECTED_VALUE(lambda(i), i);
}
}
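The retry loop above exists because a table lookup at precision 7 can occasionally decrypt to the wrong value; if one attempt fails with probability p, only k consecutive failures abort the test, so the residual failure probability is p^k. A small sketch of that arithmetic, using the 5% figure quoted in the comment:

```cpp
#include <cmath>

// Probability that k independent attempts all fail, given a per-attempt
// failure probability p.
static inline double residualFailure(double p, int k) { return std::pow(p, k); }
// residualFailure(0.05, 3) == 1.25e-4  (about 1 in 8,000)
// residualFailure(0.05, 5) ~= 3.1e-7   (about 1 in 3,200,000)
```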
@@ -295,8 +373,7 @@ INSTANTIATE_TEST_SUITE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision,
::testing::Values(1, 2, 3, 4, 5, 6, 7));

TEST(TestHLFHEApplyLookupTable, multiple_precision) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
%tlu_7 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64>
%tlu_3 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64>
@@ -305,45 +382,22 @@ func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> {
%a_plus_b = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<6>, !HLFHE.eint<6>) -> (!HLFHE.eint<6>)
return %a_plus_b: !HLFHE.eint<6>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
uint64_t arg0 = 23;
uint64_t arg1 = 7;
uint64_t expected = 30;
auto maybeResult = engine.run({arg0, arg1});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, expected);
)XXX");

ASSERT_EXPECTED_VALUE(lambda(23_u64, 7_u64), 30);
}
TEST(CompileAndRunTLU, random_func) {
mlir::zamalang::CompilerEngine engine;
auto mlirStr = R"XXX(
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: !HLFHE.eint<6>) -> !HLFHE.eint<6> {
%tlu = std.constant dense<[16, 91, 16, 83, 80, 74, 21, 96, 1, 63, 49, 122, 76, 89, 74, 55, 109, 110, 103, 54, 105, 14, 66, 47, 52, 89, 7, 10, 73, 44, 119, 92, 25, 104, 123, 100, 108, 86, 29, 121, 118, 52, 107, 48, 34, 37, 13, 122, 107, 48, 74, 59, 96, 36, 50, 55, 120, 72, 27, 45, 12, 5, 96, 12]> : tensor<64xi64>
%1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<6>, tensor<64xi64>) -> (!HLFHE.eint<6>)
return %1: !HLFHE.eint<6>
}
)XXX";
ASSERT_FALSE(engine.compile(mlirStr));
// first value
auto maybeResult = engine.run({5});
ASSERT_TRUE((bool)maybeResult);
uint64_t result = maybeResult.get();
ASSERT_EQ(result, 74);
// second value
maybeResult = engine.run({62});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 96);
// edge value low
maybeResult = engine.run({0});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 16);
// edge value high
maybeResult = engine.run({63});
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 12);
)XXX");

ASSERT_EXPECTED_VALUE(lambda(5_u64), 74);
ASSERT_EXPECTED_VALUE(lambda(62_u64), 96);
ASSERT_EXPECTED_VALUE(lambda(0_u64), 16);
ASSERT_EXPECTED_VALUE(lambda(63_u64), 12);
}
@@ -4,13 +4,117 @@
#include <gtest/gtest.h>

#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/JitCompilerEngine.h"

mlir::zamalang::V0FHEConstraint defaultV0Constraints();

#define ASSERT_LLVM_ERROR(err) \
if (err) { \
llvm::errs() << "error: " << err << "\n"; \
llvm::errs() << "error: " << std::move(err) << "\n"; \
ASSERT_TRUE(false); \
}

// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &val) {
if (!((bool)val)) {
llvm::errs() << llvm::toString(std::move(val.takeError()));
return false;
}

return true;
}

// Checks that the value `val` is not in an error state. Returns
// `true` if the test passes, otherwise `false`.
template <typename T>
static bool assert_expected_success(llvm::Expected<T> &&val) {
return assert_expected_success(val);
}

// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state.
#define ASSERT_EXPECTED_SUCCESS(val) \
do { \
if (!assert_expected_success(val)) \
GTEST_FATAL_FAILURE_("Expected<T> contained in error state"); \
} while (0)

// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.
template <typename T, typename V>
static bool assert_expected_value(llvm::Expected<T> &val, const V &exp) {
if (!assert_expected_success(val))
return false;

if (!(val.get() == static_cast<T>(exp))) {
llvm::errs() << "Expected value " << exp << ", but got " << val.get()
<< "\n";
return false;
}

return true;
}

// Checks that the value `val` is not in an error state and is equal
// to the value given in `exp`. Returns `true` if the test passes,
// otherwise `false`.
template <typename T, typename V>
static bool assert_expected_value(llvm::Expected<T> &&val, const V &exp) {
return assert_expected_value(val, exp);
}

// Checks that the value `val` of type `llvm::Expected<T>` is not in
// an error state and is equal to the value of type `T` given in
// `exp`.
#define ASSERT_EXPECTED_VALUE(val, exp) \
do { \
if (!assert_expected_value(val, exp)) { \
GTEST_FATAL_FAILURE_("Expected<T> with wrong value"); \
} \
} while (0)

// Jit-compiles the function specified by `func` from `src` and
// returns the corresponding lambda. Any compilation errors are caught
// and result in abnormal termination.
template <typename F>
mlir::zamalang::JitCompilerEngine::Lambda
internalCheckedJit(F checkfunc, llvm::StringRef src,
llvm::StringRef func = "main",
bool useDefaultFHEConstraints = false) {
mlir::zamalang::JitCompilerEngine engine;

if (useDefaultFHEConstraints)
engine.setFHEConstraints(defaultV0Constraints());

llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(src, func);

checkfunc(lambdaOrErr);

return std::move(*lambdaOrErr);
}

// Shorthands to create integer literals of a specific type
static inline uint8_t operator"" _u8(unsigned long long int v) { return v; }
static inline uint16_t operator"" _u16(unsigned long long int v) { return v; }
static inline uint32_t operator"" _u32(unsigned long long int v) { return v; }
static inline uint64_t operator"" _u64(unsigned long long int v) { return v; }

// Evaluates to the number of elements of a statically initialized
// array
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0]))

// Wrapper around `internalCheckedJit` that causes
// `ASSERT_EXPECTED_SUCCESS` to use the file and line number of the
// caller instead of `internalCheckedJit`.
#define checkedJit(...) \
internalCheckedJit( \
[](llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> &lambda) { \
ASSERT_EXPECTED_SUCCESS(lambda); \
}, \
__VA_ARGS__)

#endif
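Taken together, the helpers in this header keep each test body very compact. A minimal usage sketch that mirrors the `add_u64` test above (assuming this header is included; the sketch itself is not part of the diff):

```cpp
#include "end_to_end_jit_test.h"

TEST(Example, add_two_u64) {
  // checkedJit compiles the MLIR, asserts that compilation succeeded, and
  // returns the callable lambda; `true` requests the default FHE constraints.
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
func @main(%arg0: i64, %arg1: i64) -> i64 {
  %1 = addi %arg0, %arg1 : i64
  return %1: i64
}
)XXX",
                                                                "main", true);

  // The _u64 literal picks the argument type; ASSERT_EXPECTED_VALUE unwraps
  // the llvm::Expected result and compares it to the expected clear value.
  ASSERT_EXPECTED_VALUE(lambda(20_u64, 22_u64), 42_u64);
}
```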
Submodule llvm-project updated: f1e9ecea44...55e76c70a4