diff --git a/compiler/include/zamalang-c/Support/CompilerEngine.h b/compiler/include/zamalang-c/Support/CompilerEngine.h index d7e4dbd8e..834b30c50 100644 --- a/compiler/include/zamalang-c/Support/CompilerEngine.h +++ b/compiler/include/zamalang-c/Support/CompilerEngine.h @@ -5,15 +5,17 @@ #include "mlir-c/Registration.h" #include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/ExecutionArgument.h" +#include "zamalang/Support/Jit.h" +#include "zamalang/Support/JitCompilerEngine.h" #ifdef __cplusplus extern "C" { #endif -struct compilerEngine { - mlir::zamalang::CompilerEngine *ptr; +struct lambda { + mlir::zamalang::JitCompilerEngine::Lambda *ptr; }; -typedef struct compilerEngine compilerEngine; +typedef struct lambda lambda; struct executionArguments { mlir::zamalang::ExecutionArgument *data; @@ -21,13 +23,12 @@ struct executionArguments { }; typedef struct executionArguments exectuionArguments; -// Compile an MLIR module -MLIR_CAPI_EXPORTED void compilerEngineCompile(compilerEngine engine, - const char *module); +MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda +buildLambda(const char *module, const char *funcName); -// Run the compiled module -MLIR_CAPI_EXPORTED uint64_t compilerEngineRun(compilerEngine e, - executionArguments args); +MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args); + +MLIR_CAPI_EXPORTED std::string roundTrip(const char *module); #ifdef __cplusplus } diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index f7dbc981f..e9c496b1a 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -1,49 +1,138 @@ #ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H #define ZAMALANG_SUPPORT_COMPILER_ENGINE_H -#include "Jit.h" +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir { namespace zamalang { -/// CompilerEngine is an tools that provides tools to implements the compilation -/// flow and manage the compilation flow state. +// Compilation context that acts as the root owner of LLVM and MLIR +// data structures directly and indirectly referenced by artefacts +// produced by the `CompilerEngine`. +class CompilationContext { +public: + CompilationContext(); + ~CompilationContext(); + + mlir::MLIRContext *getMLIRContext(); + llvm::LLVMContext *getLLVMContext(); + + static std::shared_ptr createShared(); + +protected: + mlir::MLIRContext *mlirContext; + llvm::LLVMContext *llvmContext; +}; + class CompilerEngine { public: - CompilerEngine() { - context = new mlir::MLIRContext(); - loadDialects(); - } - ~CompilerEngine() { - if (context != nullptr) - delete context; - } + // Result of an invocation of the `CompilerEngine` with optional + // fields for the results produced by different stages. + class CompilationResult { + public: + CompilationResult(std::shared_ptr compilationContext = + CompilationContext::createShared()) + : compilationContext(compilationContext) {} - // Compile an mlir programs from it's textual representation. - llvm::Error compile( - std::string mlirStr, - llvm::Optional overrideConstraints = {}); + llvm::Optional mlirModuleRef; + llvm::Optional clientParameters; + std::unique_ptr keySet; + std::unique_ptr llvmModule; + llvm::Optional fheContext; - // Build the jit lambda argument. - llvm::Expected> buildArgument(); + protected: + std::shared_ptr compilationContext; + }; - // Call the compiled function with and argument object. - llvm::Error invoke(JITLambda::Argument &arg); + // Specification of the exit stage of the compilation pipeline + enum class Target { + // Only read sources and produce corresponding MLIR module + ROUND_TRIP, - // Call the compiled function with a list of integer arguments. - llvm::Expected run(std::vector args); + // Read sources and exit before any lowering + HLFHE, - // Get a printable representation of the compiled module - std::string getCompiledModule(); + // Read sources and attempt to run the Minimal Arithmetic Noise + // Padding pass + HLFHE_MANP, + + // Read sources and lower all HLFHE operations to MidLFHE + // operations + MIDLFHE, + + // Read sources and lower all HLFHE and MidLFHE operations to LowLFHE + // operations + LOWLFHE, + + // Read sources and lower all HLFHE, MidLFHE and LowLFHE + // operations to canonical MLIR dialects. Cryptographic operations + // are lowered to invocations of the concrete library. + STD, + + // Read sources and lower all HLFHE, MidLFHE and LowLFHE + // operations to operations from the LLVM dialect. Cryptographic + // operations are lowered to invocations of the concrete library. + LLVM, + + // Same as `LLVM`, but lowers to actual LLVM IR instead of the + // LLVM dialect + LLVM_IR, + + // Same as `LLVM_IR`, but invokes the LLVM optimization pipeline + // to produce optimized LLVM IR + OPTIMIZED_LLVM_IR + }; + + CompilerEngine(std::shared_ptr compilationContext) + : overrideMaxEintPrecision(), overrideMaxMANP(), + clientParametersFuncName(), verifyDiagnostics(false), + generateKeySet(false), generateClientParameters(false), + parametrizeMidLFHE(true), compilationContext(compilationContext) {} + + llvm::Expected compile(llvm::StringRef s, Target target); + + llvm::Expected + compile(std::unique_ptr buffer, Target target); + + llvm::Expected compile(llvm::SourceMgr &sm, Target target); + + void setFHEConstraints(const mlir::zamalang::V0FHEConstraint &c); + void setMaxEintPrecision(size_t v); + void setMaxMANP(size_t v); + void setVerifyDiagnostics(bool v); + void setGenerateKeySet(bool v); + void setGenerateClientParameters(bool v); + void setParametrizeMidLFHE(bool v); + void setClientParametersFuncName(const llvm::StringRef &name); + +protected: + llvm::Optional overrideMaxEintPrecision; + llvm::Optional overrideMaxMANP; + llvm::Optional clientParametersFuncName; + bool verifyDiagnostics; + bool generateKeySet; + bool generateClientParameters; + bool parametrizeMidLFHE; + + std::shared_ptr compilationContext; + + // Helper enum identifying an FHE dialect (`HLFHE`, `MIDLFHE`, `LOWLFHE`) + // or indicating that no FHE dialect is used (`NONE`). + enum class FHEDialect { HLFHE, MIDLFHE, LOWLFHE, NONE }; + static FHEDialect detectHighestFHEDialect(mlir::ModuleOp module); private: - // Load the necessary dialects into the engine's context - void loadDialects(); - - mlir::OwningModuleRef module_ref; - mlir::MLIRContext *context; - std::unique_ptr keySet; + llvm::Error lowerParamDependentHalf(Target target, CompilationResult &res); + llvm::Error determineFHEParameters(CompilationResult &res, bool noOverride); }; + } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index 835749b0f..7c22a117c 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -9,11 +9,6 @@ namespace mlir { namespace zamalang { -mlir::LogicalResult -runJit(mlir::ModuleOp module, llvm::StringRef func, - llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, - std::function optPipeline, - llvm::raw_ostream &os); /// JITLambda is a tool to JIT compile an mlir module and to invoke a function /// of the module. diff --git a/compiler/include/zamalang/Support/JitCompilerEngine.h b/compiler/include/zamalang/Support/JitCompilerEngine.h new file mode 100644 index 000000000..8957c0bb3 --- /dev/null +++ b/compiler/include/zamalang/Support/JitCompilerEngine.h @@ -0,0 +1,296 @@ +#ifndef ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H +#define ZAMALANG_SUPPORT_JIT_COMPILER_ENGINE_H + +#include +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +namespace { +// Generic function template as well as specializations of +// `typedResult` must be declared at namespace scope due to return +// type template specialization + +// Helper function for `JitCompilerEngine::Lambda::operator()` +// implementing type-dependent preparation of the result. +template +llvm::Expected typedResult(JITLambda::Argument &arguments); + +// Specialization of `typedResult()` for scalar results, forwarding +// scalar value to caller +template <> +inline llvm::Expected typedResult(JITLambda::Argument &arguments) { + uint64_t res = 0; + + if (auto err = arguments.getResult(0, res)) + return StreamStringError() << "Cannot retrieve result:" << err; + + return res; +} + +// Specialization of `typedResult()` for vector results, initializing +// an `std::vector` of the right size with the results and forwarding +// it to the caller with move semantics. +template <> +inline llvm::Expected> +typedResult(JITLambda::Argument &arguments) { + llvm::Expected n = arguments.getResultVectorSize(0); + + if (auto err = n.takeError()) + return std::move(err); + + std::vector res(*n); + + if (auto err = arguments.getResult(0, res.data(), res.size())) + return StreamStringError() << "Cannot retrieve result:" << err; + + return std::move(res); +} + +// Adaptor class that adds arguments specified as instances of +// `LambdaArgument` to `JitLambda::Argument`. +class JITLambdaArgumentAdaptor { +public: + // Checks if the argument `arg` is an plaintext / encrypted integer + // argument or a plaintext / encrypted tensor argument with a + // backing integer type `IntT` and adds the argument to `jla` at + // position `pos`. + // + // Returns `true` if `arg` has one of the types above and its value + // was successfully added to `jla`, `false` if none of the types + // matches or an error if a type matched, but adding the argument to + // `jla` failed. + template + static inline llvm::Expected + tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { + if (auto ila = arg.dyn_cast>()) { + if (llvm::Error err = jla.setArg(pos, ila->getValue())) + return std::move(err); + else + return true; + } else if (auto tla = arg.dyn_cast< + TensorLambdaArgument>>()) { + llvm::Expected numElements = tla->getNumElements(); + + if (!numElements) + return std::move(numElements.takeError()); + + if (llvm::Error err = jla.setArg(pos, tla->getValue(), *numElements)) + return std::move(err); + else + return true; + } + + return false; + } + + // Recursive case for `tryAddArg(...)` + template + static inline llvm::Expected + tryAddArg(JITLambda::Argument &jla, size_t pos, const LambdaArgument &arg) { + llvm::Expected successOrError = tryAddArg(jla, pos, arg); + + if (!successOrError) + return std::move(successOrError.takeError()); + + if (successOrError.get() == false) + return tryAddArg(jla, pos, arg); + else + return true; + } + + // Attempts to add a single argument `arg` to `jla` at position + // `pos`. Returns an error if either the argument type is + // unsupported or if the argument types is supported, but adding it + // to `jla` failed. + static inline llvm::Error addArgument(JITLambda::Argument &jla, size_t pos, + const LambdaArgument &arg) { + llvm::Expected successOrError = + JITLambdaArgumentAdaptor::tryAddArg(jla, pos, arg); + + if (!successOrError) + return std::move(successOrError.takeError()); + + if (successOrError.get() == false) + return StreamStringError("Unknown argument type"); + else + return llvm::Error::success(); + } +}; +} // namespace + +// A compiler engine that JIT-compiles a source and produces a lambda +// object directly invocable through its call operator. +class JitCompilerEngine : public CompilerEngine { +public: + // Wrapper class around `JITLambda` and `JITLambda::Argument` that + // allows for direct invocation of a compiled function through + // `operator ()`. + class Lambda { + public: + Lambda(Lambda &&other) + : innerLambda(std::move(other.innerLambda)), + keySet(std::move(other.keySet)), + compilationContext(other.compilationContext) {} + + Lambda(std::shared_ptr compilationContext, + std::unique_ptr lambda, std::unique_ptr keySet) + : innerLambda(std::move(lambda)), keySet(std::move(keySet)), + compilationContext(compilationContext) {} + + // Returns the number of arguments required for an invocation of + // the lambda + size_t getNumArguments() { return this->keySet->numInputs(); } + + // Returns the number of results an invocation of the lambda + // produces + size_t getNumResults() { return this->keySet->numOutputs(); } + + // Invocation with an dynamic list of arguments of different + // types, specified as `LambdaArgument`s + template + llvm::Expected + operator()(llvm::ArrayRef lambdaArgs) { + // Create the arguments of the JIT lambda + llvm::Expected> argsOrErr = + mlir::zamalang::JITLambda::Argument::create(*this->keySet.get()); + + if (llvm::Error err = argsOrErr.takeError()) + return StreamStringError("Could not create lambda arguments"); + + // Set the arguments + std::unique_ptr arguments = + std::move(argsOrErr.get()); + + for (size_t i = 0; i < lambdaArgs.size(); i++) { + if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument( + *arguments, i, *lambdaArgs[i])) { + return std::move(err); + } + } + + // Invoke the lambda + if (auto err = this->innerLambda->invoke(*arguments)) + return StreamStringError() << "Cannot invoke lambda:" << err; + + return std::move(typedResult(*arguments)); + } + + // Invocation with an array of arguments of the same type + template + llvm::Expected operator()(const llvm::ArrayRef args) { + // Create the arguments of the JIT lambda + llvm::Expected> argsOrErr = + mlir::zamalang::JITLambda::Argument::create(*this->keySet.get()); + + if (llvm::Error err = argsOrErr.takeError()) + return StreamStringError("Could not create lambda arguments"); + + // Set the arguments + std::unique_ptr arguments = + std::move(argsOrErr.get()); + + for (size_t i = 0; i < args.size(); i++) { + if (auto err = arguments->setArg(i, args[i])) { + return StreamStringError() + << "Cannot push argument " << i << ": " << err; + } + } + + // Invoke the lambda + if (auto err = this->innerLambda->invoke(*arguments)) + return StreamStringError() << "Cannot invoke lambda:" << err; + + return std::move(typedResult(*arguments)); + } + + // Invocation with arguments of different types + template + llvm::Expected operator()(const Ts... ts) { + // Create the arguments of the JIT lambda + llvm::Expected> argsOrErr = + mlir::zamalang::JITLambda::Argument::create(*this->keySet.get()); + + if (llvm::Error err = argsOrErr.takeError()) + return StreamStringError("Could not create lambda arguments"); + + // Set the arguments + std::unique_ptr arguments = + std::move(argsOrErr.get()); + + if (llvm::Error err = this->addArgs<0>(arguments.get(), ts...)) + return std::move(err); + + // Invoke the lambda + if (auto err = this->innerLambda->invoke(*arguments)) + return StreamStringError() << "Cannot invoke lambda:" << err; + + return std::move(typedResult(*arguments)); + } + + protected: + template + inline llvm::Error addArgs(JITLambda::Argument *jitArgs) { + // base case -- nothing to do + return llvm::Error::success(); + } + + // Recursive case for scalars: extract first scalar argument from + // parameter pack and forward rest + template + inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT arg, + Ts... remainder) { + if (auto err = jitArgs->setArg(pos, arg)) { + return StreamStringError() + << "Cannot push scalar argument " << pos << ": " << err; + } + + return this->addArgs(jitArgs, remainder...); + } + + // Recursive case for tensors: extract pointer and size from + // parameter pack and forward rest + template + inline llvm::Error addArgs(JITLambda::Argument *jitArgs, ArgT *arg, + size_t size, Ts... remainder) { + if (auto err = jitArgs->setArg(pos, arg, size)) { + return StreamStringError() + << "Cannot push tensor argument " << pos << ": " << err; + } + + return this->addArgs(jitArgs, remainder...); + } + + std::unique_ptr innerLambda; + std::unique_ptr keySet; + std::shared_ptr compilationContext; + }; + + JitCompilerEngine(std::shared_ptr compilationContext = + CompilationContext::createShared(), + unsigned int optimizationLevel = 3); + + llvm::Expected buildLambda(llvm::StringRef src, + llvm::StringRef funcName = "main"); + + llvm::Expected buildLambda(std::unique_ptr buffer, + llvm::StringRef funcName = "main"); + + llvm::Expected buildLambda(llvm::SourceMgr &sm, + llvm::StringRef funcName = "main"); + +protected: + llvm::Expected findLLVMFuncOp(mlir::ModuleOp module, + llvm::StringRef name); + unsigned int optimizationLevel; +}; + +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/include/zamalang/Support/LambdaArgument.h b/compiler/include/zamalang/Support/LambdaArgument.h new file mode 100644 index 000000000..9d5378377 --- /dev/null +++ b/compiler/include/zamalang/Support/LambdaArgument.h @@ -0,0 +1,157 @@ +#ifndef ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H +#define ZAMALANG_SUPPORT_LAMBDA_ARGUMENT_H + +#include +#include + +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +// Abstract base class for lambda arguments +class LambdaArgument + : public llvm::RTTIExtends { +public: + LambdaArgument(LambdaArgument &) = delete; + + template bool isa() const { return llvm::isa(*this); } + + // Cast functions on constant instances + template const T &cast() const { return llvm::cast(*this); } + template const T *dyn_cast() const { + return llvm::dyn_cast(this); + } + + // Cast functions for mutable instances + template T &cast() { return llvm::cast(*this); } + template T *dyn_cast() { return llvm::dyn_cast(this); } + + static char ID; + +protected: + LambdaArgument(){}; +}; + +// Class for integer arguments. `BackingIntType` is used as the data +// type to hold the argument's value. The precision is the actual +// precision of the value, which might be different from the precision +// of the backing integer type. +template +class IntLambdaArgument + : public llvm::RTTIExtends, + LambdaArgument> { +public: + typedef BackingIntType value_type; + + IntLambdaArgument(BackingIntType value, + unsigned int precision = 8 * sizeof(BackingIntType)) + : precision(precision) { + if (precision < 8 * sizeof(BackingIntType)) { + this->value = value & (1 << (this->precision - 1)); + } else { + this->value = value; + } + } + + unsigned int getPrecision() const { return this->precision; } + BackingIntType getValue() const { return this->value; } + + static char ID; + +protected: + unsigned int precision; + BackingIntType value; +}; + +template +char IntLambdaArgument::ID = 0; + +namespace { +// Calculates `accu *= factor` or returns an error if the result +// would overflow +template +llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) { + static_assert(std::numeric_limits::is_integer && + std::numeric_limits::is_integer && + !std::numeric_limits::is_signed && + !std::numeric_limits::is_signed, + "Only unsigned integers are supported"); + + const AccuT left = std::numeric_limits::max() / accu; + + if (left > factor) { + accu *= factor; + return llvm::Error::success(); + } + + return StreamStringError("Multiplying value ") + << accu << " with " << factor << " would cause an overflow"; +} +} // namespace + +// Class for Tensor arguments. This can either be plaintext tensors +// (for `ScalarArgumentT = IntLambaArgument`) or tensors +// representing encrypted integers (for `ScalarArgumentT = +// EIntLambaArgument`). +template +class TensorLambdaArgument + : public llvm::RTTIExtends, + LambdaArgument> { +public: + typedef ScalarArgumentT scalar_type; + + // Construct tensor argument from the one-dimensional array `value`, + // but interpreting the array's values as a linearized + // multi-dimensional tensor with the sizes of the dimensions + // specified in `dimensions`. + TensorLambdaArgument( + llvm::MutableArrayRef value, + llvm::ArrayRef dimensions) + : value(value), dimensions(dimensions.vec()) {} + + // Construct a one-dimensional tensor argument from the + // array `value`. + TensorLambdaArgument( + llvm::MutableArrayRef value) + : TensorLambdaArgument(value, {(unsigned int)value.size()}) {} + + const std::vector &getDimensions() const { + return this->dimensions; + } + + // Returns the total number of elements in the tensor. If the number + // of elements cannot be represented as a `size_t`, the method + // returns an error. + llvm::Expected getNumElements() const { + size_t accu = 1; + + for (unsigned int dimSize : dimensions) + if (llvm::Error err = safeUnsignedMul(accu, dimSize)) + return std::move(err); + + return accu; + } + + // Returns a bare pointer to the linearized values of the tensor. + typename ScalarArgumentT::value_type *getValue() const { + return this->value.data(); + } + + static char ID; + +protected: + llvm::MutableArrayRef value; + std::vector dimensions; +}; + +template +char TensorLambdaArgument::ID = 0; + +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index a99f1c1a6..65e100e68 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -1,8 +1,9 @@ #include "CompilerAPIModule.h" #include "zamalang-c/Support/CompilerEngine.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEOpsDialect.h.inc" -#include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/ExecutionArgument.h" +#include "zamalang/Support/Jit.h" +#include "zamalang/Support/JitCompilerEngine.h" #include #include #include @@ -14,27 +15,15 @@ #include #include -using mlir::zamalang::CompilerEngine; using mlir::zamalang::ExecutionArgument; +using mlir::zamalang::JitCompilerEngine; /// Populate the compiler API python module. void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { m.doc() = "Zamalang compiler python API"; - m.def("round_trip", [](std::string mlir_input) { - mlir::MLIRContext context; - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - auto module_ref = mlir::parseSourceString(mlir_input, &context); - if (!module_ref) { - throw std::logic_error("mlir parsing failed"); - } - std::string result; - llvm::raw_string_ostream os(result); - module_ref->print(os); - return os.str(); - }); + m.def("round_trip", + [](std::string mlir_input) { return roundTrip(mlir_input.c_str()); }); pybind11::class_>( m, "ExecutionArgument") @@ -45,20 +34,19 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) { .def("is_tensor", &ExecutionArgument::isTensor) .def("is_int", &ExecutionArgument::isInt); - pybind11::class_(m, "CompilerEngine") + pybind11::class_(m, "JitCompilerEngine") .def(pybind11::init()) - .def("run", - [](CompilerEngine &engine, std::vector args) { - // wrap and call CAPI - compilerEngine e{&engine}; - exectuionArguments a{args.data(), args.size()}; - return compilerEngineRun(e, a); - }) - .def("compile_fhe", - [](CompilerEngine &engine, std::string mlir_input) { - // wrap and call CAPI - compilerEngine e{&engine}; - compilerEngineCompile(e, mlir_input.c_str()); - }) - .def("get_compiled_module", &CompilerEngine::getCompiledModule); + .def_static("build_lambda", + [](std::string mlir_input, std::string func_name) { + return buildLambda(mlir_input.c_str(), func_name.c_str()); + }); + + pybind11::class_(m, "Lambda") + .def("invoke", [](JitCompilerEngine::Lambda &py_lambda, + std::vector args) { + // wrap and call CAPI + lambda c_lambda{&py_lambda}; + exectuionArguments a{args.data(), args.size()}; + return invokeLambda(c_lambda, a); + }); } diff --git a/compiler/lib/Bindings/Python/zamalang/compiler.py b/compiler/lib/Bindings/Python/zamalang/compiler.py index 76e372463..130f4275e 100644 --- a/compiler/lib/Bindings/Python/zamalang/compiler.py +++ b/compiler/lib/Bindings/Python/zamalang/compiler.py @@ -1,10 +1,9 @@ """Compiler submodule""" from typing import List, Union -from mlir._mlir_libs._zamalang._compiler import CompilerEngine as _CompilerEngine +from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip - def round_trip(mlir_str: str) -> str: """Parse the MLIR input, then return it back. @@ -49,25 +48,24 @@ def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgume class CompilerEngine: def __init__(self, mlir_str: str = None): - self._engine = _CompilerEngine() + self._engine = _JitCompilerEngine() + self._lambda = None if mlir_str is not None: self.compile_fhe(mlir_str) - def compile_fhe(self, mlir_str: str) -> "CompilerEngine": - """Compile the MLIR input and build a CompilerEngine. + def compile_fhe(self, mlir_str: str, func_name: str = "main"): + """Compile the MLIR input. Args: mlir_str (str): MLIR to compile. + func_name (str): name of the function to set as entrypoint. Raises: TypeError: if the argument is not an str. - - Returns: - CompilerEngine: engine used for execution. """ if not isinstance(mlir_str, str): raise TypeError("input must be an `str`") - return self._engine.compile_fhe(mlir_str) + self._lambda = self._engine.build_lambda(mlir_str, func_name) def run(self, *args: List[Union[int, List[int]]]) -> int: """Run the compiled code. @@ -77,17 +75,12 @@ class CompilerEngine: Raises: TypeError: if execution arguments can't be constructed + RuntimeError: if the engine has not compiled any code yet Returns: int: result of execution. """ + if self._lambda is None: + raise RuntimeError("need to compile an MLIR code first") execution_arguments = [create_execution_argument(arg) for arg in args] - return self._engine.run(execution_arguments) - - def get_compiled_module(self) -> str: - """Compiled module in printable form. - - Returns: - str: Compiled module in printable form. - """ - return self._engine.get_compiled_module() + return self._lambda.invoke(execution_arguments) diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index e48f7ddf2..a8f9dd71e 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -1,62 +1,83 @@ #include "zamalang-c/Support/CompilerEngine.h" #include "zamalang/Support/CompilerEngine.h" #include "zamalang/Support/ExecutionArgument.h" +#include "zamalang/Support/Jit.h" +#include "zamalang/Support/JitCompilerEngine.h" +#include "zamalang/Support/logging.h" -using mlir::zamalang::CompilerEngine; +// using mlir::zamalang::CompilerEngine; using mlir::zamalang::ExecutionArgument; +using mlir::zamalang::JitCompilerEngine; -void compilerEngineCompile(compilerEngine engine, const char *module) { - auto error = engine.ptr->compile(module); - if (error) { - llvm::errs() << "Compilation failed: " << error << "\n"; - llvm::consumeError(std::move(error)); +mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module, + const char *funcName) { + mlir::zamalang::JitCompilerEngine engine; + llvm::Expected lambdaOrErr = + engine.buildLambda(module, funcName); + if (!lambdaOrErr) { + mlir::zamalang::log_error() + << "Compilation failed: " + << llvm::toString(std::move(lambdaOrErr.takeError())) << "\n"; throw std::runtime_error( "failed compiling, see previous logs for more info"); } + return std::move(*lambdaOrErr); } -uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) { - auto args_size = args.size; - auto maybeArgument = engine.ptr->buildArgument(); - if (auto err = maybeArgument.takeError()) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error( - "failed building arguments, see previous logs for more info"); +uint64_t invokeLambda(lambda l, executionArguments args) { + mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr = + (mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr; + + if (args.size != lambda_ptr->getNumArguments()) { + throw std::invalid_argument("wrong number of arguments"); } // Set the integer/tensor arguments - auto arguments = std::move(maybeArgument.get()); - for (auto i = 0; i < args_size; i++) { + std::vector lambdaArgumentsRef; + for (auto i = 0; i < args.size; i++) { if (args.data[i].isInt()) { // integer argument - if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error("failed pushing integer argument, see " - "previous logs for more info"); - } + lambdaArgumentsRef.push_back(new mlir::zamalang::IntLambdaArgument<>( + args.data[i].getIntegerArgument())); } else { // tensor argument - assert(args.data[i].isTensor() && "should be tensor argument"); - if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(), - args.data[i].getTensorSize())) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error("failed pushing tensor argument, see " - "previous logs for more info"); - } + llvm::MutableArrayRef tensor(args.data[i].getTensorArgument(), + args.data[i].getTensorSize()); + lambdaArgumentsRef.push_back( + new mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument>(tensor)); } } - // Invoke the lambda - if (auto err = engine.ptr->invoke(*arguments)) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); - throw std::runtime_error("failed running, see previous logs for more info"); - } - uint64_t result = 0; - if (auto err = arguments->getResult(0, result)) { - llvm::errs() << "Execution failed: " << err << "\n"; - llvm::consumeError(std::move(err)); + // Run lambda + llvm::Expected resOrError = (*lambda_ptr)( + llvm::ArrayRef(lambdaArgumentsRef)); + // Free heap + for (size_t i = 0; i < lambdaArgumentsRef.size(); i++) + delete lambdaArgumentsRef[i]; + + if (!resOrError) { + mlir::zamalang::log_error() + << "Lambda invokation failed: " + << llvm::toString(std::move(resOrError.takeError())) << "\n"; throw std::runtime_error( - "failed getting result, see previous logs for more info"); + "failed invoking lambda, see previous logs for more info"); } - return result; -} \ No newline at end of file + return *resOrError; +} + +std::string roundTrip(const char *module) { + std::shared_ptr ccx = + mlir::zamalang::CompilationContext::createShared(); + mlir::zamalang::JitCompilerEngine ce{ccx}; + + llvm::Expected retOrErr = + ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP); + if (!retOrErr) { + mlir::zamalang::log_error() + << llvm::toString(std::move(retOrErr.takeError())) << "\n"; + throw std::runtime_error( + "mlir parsing failed, see previous logs for more info"); + } + + std::string result; + llvm::raw_string_ostream os(result); + retOrErr->mlirModuleRef->get().print(os); + return os.str(); +} diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 9694989d3..a1fde7a86 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -3,6 +3,8 @@ add_mlir_library(ZamalangSupport Pipeline.cpp Jit.cpp CompilerEngine.cpp + JitCompilerEngine.cpp + LambdaArgument.cpp V0Parameters.cpp V0Curves.cpp ClientParameters.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 5a41c7e46..50d98b721 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -1,3 +1,5 @@ +#include +#include #include #include #include @@ -9,155 +11,419 @@ #include #include #include +#include +#include #include namespace mlir { namespace zamalang { -void CompilerEngine::loadDialects() { - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); +// Creates a new compilation context that can be shared across +// compilation engines and results +std::shared_ptr CompilationContext::createShared() { + return std::make_shared(); } -std::string CompilerEngine::getCompiledModule() { - std::string compiledModule; - llvm::raw_string_ostream os(compiledModule); - module_ref->print(os); - return os.str(); +CompilationContext::CompilationContext() + : mlirContext(nullptr), llvmContext(nullptr) {} + +CompilationContext::~CompilationContext() { + delete this->mlirContext; + delete this->llvmContext; } -llvm::Error CompilerEngine::compile( - std::string mlirStr, - llvm::Optional overrideConstraints) { - module_ref = mlir::parseSourceString(mlirStr, context); - if (!module_ref) { - return llvm::make_error("mlir parsing failed", - llvm::inconvertibleErrorCode()); +// Returns the MLIR context for a compilation context. Creates and +// initializes a new MLIR context if necessary. +mlir::MLIRContext *CompilationContext::getMLIRContext() { + if (this->mlirContext == nullptr) { + this->mlirContext = new mlir::MLIRContext(); + + this->mlirContext->getOrLoadDialect(); + this->mlirContext + ->getOrLoadDialect(); + this->mlirContext + ->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); + this->mlirContext->getOrLoadDialect(); } - mlir::ModuleOp module = module_ref.get(); + return this->mlirContext; +} - llvm::Optional fheConstraintsOpt = - overrideConstraints; +// Returns the LLVM context for a compilation context. Creates and +// initializes a new LLVM context if necessary. +llvm::LLVMContext *CompilationContext::getLLVMContext() { + if (this->llvmContext == nullptr) + this->llvmContext = new llvm::LLVMContext(); - if (!fheConstraintsOpt.hasValue()) { + return this->llvmContext; +} + +// Sets the FHE constraints for the compilation. Overrides any +// automatically detected configuration and prevents the autodetection +// pass from running. +void CompilerEngine::setFHEConstraints( + const mlir::zamalang::V0FHEConstraint &c) { + this->overrideMaxEintPrecision = c.p; + this->overrideMaxMANP = c.norm2; +} + +void CompilerEngine::setVerifyDiagnostics(bool v) { + this->verifyDiagnostics = v; +} + +void CompilerEngine::setGenerateKeySet(bool v) { this->generateKeySet = v; } + +void CompilerEngine::setGenerateClientParameters(bool v) { + this->generateClientParameters = v; +} + +void CompilerEngine::setMaxEintPrecision(size_t v) { + this->overrideMaxEintPrecision = v; +} + +void CompilerEngine::setParametrizeMidLFHE(bool v) { + this->parametrizeMidLFHE = v; +} + +void CompilerEngine::setMaxMANP(size_t v) { this->overrideMaxMANP = v; } + +void CompilerEngine::setClientParametersFuncName(const llvm::StringRef &name) { + this->clientParametersFuncName = name.str(); +} + +// Helper function detecting the FHE dialect with the highest level of +// abstraction used in `module`. If no FHE dialect is used, the +// function returns `CompilerEngine::FHEDialect::NONE`. +CompilerEngine::FHEDialect +CompilerEngine::detectHighestFHEDialect(mlir::ModuleOp module) { + CompilerEngine::FHEDialect highestDialect = CompilerEngine::FHEDialect::NONE; + + mlir::TypeID hlfheID = + mlir::TypeID::get(); + mlir::TypeID midlfheID = + mlir::TypeID::get(); + mlir::TypeID lowlfheID = + mlir::TypeID::get(); + + // Helper lambda updating the currently highest dialect if necessary + // by dialect type ID + auto updateDialectFromDialectID = [&](mlir::TypeID dialectID) { + if (dialectID == hlfheID) { + highestDialect = CompilerEngine::FHEDialect::HLFHE; + return true; + } else if (dialectID == lowlfheID && + highestDialect == CompilerEngine::FHEDialect::NONE) { + highestDialect = CompilerEngine::FHEDialect::LOWLFHE; + } else if (dialectID == midlfheID && + (highestDialect == CompilerEngine::FHEDialect::NONE || + highestDialect == CompilerEngine::FHEDialect::LOWLFHE)) { + highestDialect = CompilerEngine::FHEDialect::MIDLFHE; + } + + return false; + }; + + // Helper lambda updating the currently highest dialect if necessary + // by value type + std::function updateDialectFromType = + [&](mlir::Type ty) -> bool { + if (updateDialectFromDialectID(ty.getDialect().getTypeID())) + return true; + + if (mlir::TensorType tensorTy = ty.dyn_cast_or_null()) + return updateDialectFromType(tensorTy.getElementType()); + + return false; + }; + + module.walk([&](mlir::Operation *op) { + // Check operation itself + if (updateDialectFromDialectID(op->getDialect()->getTypeID())) + return mlir::WalkResult::interrupt(); + + // Check types of operands + for (mlir::Value operand : op->getOperands()) { + if (updateDialectFromType(operand.getType())) + return mlir::WalkResult::interrupt(); + } + + // Check types of results + for (mlir::Value res : op->getResults()) { + if (updateDialectFromType(res.getType())) { + return mlir::WalkResult::interrupt(); + } + } + + return mlir::WalkResult::advance(); + }); + + return highestDialect; +} + +// Sets the FHE parameters of `res` either through autodetection or +// fixed constraints provided in +// `CompilerEngine::overrideMaxEintPrecision` and +// `CompilerEngine::overrideMaxMANP`. +// +// Autodetected values can be partially or fully overridden through +// `CompilerEngine::overrideMaxEintPrecision` and +// `CompilerEngine::overrideMaxMANP`. +// +// If `noOverrideAutodetected` is true, autodetected values are not +// overriden and used directly for `res`. +// +// Return an error if autodetection fails. +llvm::Error +CompilerEngine::determineFHEParameters(CompilationResult &res, + bool noOverrideAutodetected) { + mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + mlir::ModuleOp module = res.mlirModuleRef->get(); + llvm::Optional fheConstraints; + + // Determine FHE constraints either through autodetection or through + // overridden values + if (this->overrideMaxEintPrecision.hasValue() && + this->overrideMaxMANP.hasValue() && !noOverrideAutodetected) { + fheConstraints.emplace(mlir::zamalang::V0FHEConstraint{ + this->overrideMaxMANP.getValue(), + this->overrideMaxEintPrecision.getValue()}); + + } else { llvm::Expected> fheConstraintsOrErr = - mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context, + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(mlirContext, module); if (auto err = fheConstraintsOrErr.takeError()) return std::move(err); if (!fheConstraintsOrErr.get().hasValue()) { - return llvm::make_error( - "Could not determine maximum required precision for encrypted " - "integers " - "and maximum value for the Minimal Arithmetic Noise Padding", - llvm::inconvertibleErrorCode()); + return StreamStringError("Could not determine maximum required precision " + "for encrypted integers and maximum value for " + "the Minimal Arithmetic Noise Padding"); } - fheConstraintsOpt = fheConstraintsOrErr.get(); + if (noOverrideAutodetected) + return llvm::Error::success(); + + fheConstraints = fheConstraintsOrErr.get(); + + // Override individual values if requested + if (this->overrideMaxEintPrecision.hasValue()) + fheConstraints->p = this->overrideMaxEintPrecision.getValue(); + + if (this->overrideMaxMANP.hasValue()) + fheConstraints->norm2 = this->overrideMaxMANP.getValue(); } - mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue(); - const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); + const mlir::zamalang::V0Parameter *fheParams = + getV0Parameter(fheConstraints.getValue()); - if (!parameter) { - std::string buffer; - llvm::raw_string_ostream strs(buffer); - strs << "Could not determine V0 parameters for 2-norm of " - << fheConstraints.norm2 << " and p of " << fheConstraints.p; - - return llvm::make_error(strs.str(), - llvm::inconvertibleErrorCode()); + if (!fheParams) { + return StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << fheConstraints->norm2 << " and p of " << fheConstraints->p; } - mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter}; + res.fheContext.emplace( + mlir::zamalang::V0FHEContext{*fheConstraints, *fheParams}); - // Lower to MLIR Std - if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext, - false) - .failed()) { - return llvm::make_error("failed to lower to MLIR Std", - llvm::inconvertibleErrorCode()); - } - // Create the client parameters - auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheContext, "main", module_ref.get()); - if (auto err = clientParameter.takeError()) { - return std::move(err); - } - auto maybeKeySet = - mlir::zamalang::KeySet::generate(clientParameter.get(), 0, 0); - if (auto err = maybeKeySet.takeError()) { - return std::move(err); - } - keySet = std::move(maybeKeySet.get()); - - // Lower to MLIR LLVM Dialect - if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false) - .failed()) { - return llvm::make_error( - "failed to lower to LLVM dialect", llvm::inconvertibleErrorCode()); - } return llvm::Error::success(); } -llvm::Expected> -CompilerEngine::buildArgument() { - if (keySet.get() == nullptr) { - return llvm::make_error( - "CompilerEngine::buildArgument: invalid engine state, the keySet has " - "not been generated", - llvm::inconvertibleErrorCode()); - } - return JITLambda::Argument::create(*keySet); -} +// Performs all lowering from HLFHE to the FHE dialect with the lwoest +// level of abstraction that requires FHE parameters. +// +// Returns an error if any of the lowerings fails. +llvm::Error CompilerEngine::lowerParamDependentHalf(Target target, + CompilationResult &res) { + mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + mlir::ModuleOp module = res.mlirModuleRef->get(); -llvm::Error CompilerEngine::invoke(JITLambda::Argument &arg) { - // Create the JIT lambda - auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); - auto module = module_ref.get(); - auto maybeLambda = - mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline); - if (auto err = maybeLambda.takeError()) { - return std::move(err); + // HLFHE -> MidLFHE + if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(mlirContext, module, false) + .failed()) { + return StreamStringError("Lowering from HLFHE to MidLFHE failed"); } - // Invoke the lambda - if (auto err = maybeLambda.get()->invoke(arg)) { - return std::move(err); + + if (target == Target::MIDLFHE) + return llvm::Error::success(); + + // MidLFHE -> LowLFHE + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( + mlirContext, module, *res.fheContext, this->parametrizeMidLFHE) + .failed()) { + return StreamStringError("Lowering from MidLFHE to LowLFHE failed"); } + return llvm::Error::success(); } -llvm::Expected CompilerEngine::run(std::vector args) { - // Build the argument of the JIT lambda. - auto maybeArgument = buildArgument(); - if (auto err = maybeArgument.takeError()) { - return std::move(err); +// Compile the sources managed by the source manager `sm` to the +// target dialect `target`. If successful, the result can be retrieved +// using `getModule()` and `getLLVMModule()`, respectively depending +// on the target dialect. +llvm::Expected +CompilerEngine::compile(llvm::SourceMgr &sm, Target target) { + CompilationResult res(this->compilationContext); + + mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + + mlir::SourceMgrDiagnosticVerifierHandler smHandler(sm, &mlirContext); + mlirContext.printOpOnDiagnostic(false); + + mlir::OwningModuleRef mlirModuleRef = + mlir::parseSourceFile(sm, &mlirContext); + + if (this->verifyDiagnostics) { + if (smHandler.verify().failed()) + return StreamStringError("Verification of diagnostics failed"); + else + return res; } - // Set the integer arguments - auto arguments = std::move(maybeArgument.get()); - for (auto i = 0; i < args.size(); i++) { - if (auto err = arguments->setArg(i, args[i])) { + + if (!mlirModuleRef) + return StreamStringError("Could not parse source"); + + res.mlirModuleRef = std::move(mlirModuleRef); + mlir::ModuleOp module = res.mlirModuleRef->get(); + + if (target == Target::HLFHE || target == Target::ROUND_TRIP) + return res; + + // Detect highest FHE dialect and check if FHE parameter + // autodetection / lowering of parameter-dependent dialects can be + // skipped + FHEDialect highestFHEDialect = this->detectHighestFHEDialect(module); + + if (highestFHEDialect == FHEDialect::HLFHE || + highestFHEDialect == FHEDialect::MIDLFHE || + this->generateClientParameters) { + bool noOverrideAutoDetected = (target == Target::HLFHE_MANP); + if (auto err = this->determineFHEParameters(res, noOverrideAutoDetected)) return std::move(err); + } + + // return early if only the MANP pass was requested + if (target == Target::HLFHE_MANP) + return res; + + if (highestFHEDialect == FHEDialect::HLFHE || + highestFHEDialect == FHEDialect::MIDLFHE) { + if (llvm::Error err = this->lowerParamDependentHalf(target, res)) + return std::move(err); + } + + if (target == Target::HLFHE_MANP || target == Target::MIDLFHE || + target == Target::LOWLFHE) + return res; + + // LowLFHE -> Canonical dialects + if (mlir::zamalang::pipeline::lowerLowLFHEToStd(mlirContext, module) + .failed()) { + return StreamStringError( + "Lowering from LowLFHE to canonical MLIR dialects failed"); + } + + if (target == Target::STD) + return res; + + // Generate client parameters if requested + if (this->generateClientParameters) { + if (!this->clientParametersFuncName.hasValue()) { + return StreamStringError( + "Generation of client parameters requested, but no function name " + "specified"); } + + llvm::Expected clientParametersOrErr = + mlir::zamalang::createClientParametersForV0( + *res.fheContext, *this->clientParametersFuncName, module); + + if (llvm::Error err = clientParametersOrErr.takeError()) + return std::move(err); + + res.clientParameters = clientParametersOrErr.get(); } - // Invoke the lambda - if (auto err = invoke(*arguments)) { - return std::move(err); + + // Generate Key set if requested + if (this->generateKeySet) { + if (!res.clientParameters.hasValue()) { + return StreamStringError("Generation of keyset requested without request " + "for generation of client parameters"); + } + + llvm::Expected> keySetOrErr = + mlir::zamalang::KeySet::generate(*res.clientParameters, 0, 0); + + if (auto err = keySetOrErr.takeError()) + return std::move(err); + + res.keySet = std::move(*keySetOrErr); } - uint64_t res = 0; - if (auto err = arguments->getResult(0, res)) { - return std::move(err); + + // MLIR canonical dialects -> LLVM Dialect + if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(mlirContext, module, + false) + .failed()) { + return StreamStringError("Failed to lower to LLVM dialect"); } + + if (target == Target::LLVM) + return res; + + // Lowering to actual LLVM IR (i.e., not the LLVM dialect) + llvm::LLVMContext &llvmContext = *this->compilationContext->getLLVMContext(); + + res.llvmModule = mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR( + mlirContext, llvmContext, module); + + if (!res.llvmModule) + return StreamStringError("Failed to convert from LLVM dialect to LLVM IR"); + + if (target == Target::LLVM_IR) + return res; + + if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *res.llvmModule) + .failed()) { + return StreamStringError("Failed to optimize LLVM IR"); + } + + if (target == Target::OPTIMIZED_LLVM_IR) + return res; + return res; +} // namespace zamalang + +// Compile the source `s` to the target dialect `target`. If successful, the +// result can be retrieved using `getModule()` and `getLLVMModule()`, +// respectively depending on the target dialect. +llvm::Expected +CompilerEngine::compile(llvm::StringRef s, Target target) { + std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); + llvm::Expected res = this->compile(std::move(mb), target); + + return std::move(res); } + +// Compile the contained in `buffer` to the target dialect +// `target`. If successful, the result can be retrieved using +// `getModule()` and `getLLVMModule()`, respectively depending on the +// target dialect. +llvm::Expected +CompilerEngine::compile(std::unique_ptr buffer, + Target target) { + llvm::SourceMgr sm; + + sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); + + llvm::Expected res = this->compile(sm, target); + + return std::move(res); +} + } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 85ab88aaf..0b100e03e 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -1,3 +1,4 @@ +#include "llvm/Support/Error.h" #include #include #include @@ -12,56 +13,6 @@ namespace mlir { namespace zamalang { -// JIT-compiles `module` invokes `func` with the arguments passed in -// `jitArguments` and `keySet` -mlir::LogicalResult -runJit(mlir::ModuleOp module, llvm::StringRef func, - llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, - std::function optPipeline, - llvm::raw_ostream &os) { - // Create the JIT lambda - auto maybeLambda = - mlir::zamalang::JITLambda::create(func, module, optPipeline); - if (!maybeLambda) { - return mlir::failure(); - } - auto lambda = std::move(maybeLambda.get()); - - // Create the arguments of the JIT lambda - auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet); - if (auto err = maybeArguments.takeError()) { - ::mlir::zamalang::log_error() - << "Cannot create lambda arguments: " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - - // Set the arguments - auto arguments = std::move(maybeArguments.get()); - for (size_t i = 0; i < funcArgs.size(); i++) { - if (auto err = arguments->setArg(i, funcArgs[i])) { - ::mlir::zamalang::log_error() - << "Cannot push argument " << i << ": " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - } - // Invoke the lambda - if (auto err = lambda->invoke(*arguments)) { - ::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - uint64_t res = 0; - if (auto err = arguments->getResult(0, res)) { - ::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n"; - llvm::consumeError(std::move(err)); - return mlir::failure(); - } - llvm::errs() << res << "\n"; - return mlir::success(); -} - llvm::Expected> JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, llvm::function_ref optPipeline) { diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp new file mode 100644 index 000000000..05359ac4f --- /dev/null +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -0,0 +1,105 @@ +#include "llvm/Support/Error.h" +#include +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { + +JitCompilerEngine::JitCompilerEngine( + std::shared_ptr compilationContext, + unsigned int optimizationLevel) + : CompilerEngine(compilationContext), optimizationLevel(optimizationLevel) { +} + +// Returns the `LLVMFuncOp` operation in the compiled module with the +// specified name. If no LLVMFuncOp with that name exists or if there +// was no prior call to `compile()` resulting in an MLIR module in the +// LLVM dialect, an error is returned. +llvm::Expected +JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) { + auto funcOps = module.getOps(); + auto funcOp = llvm::find_if( + funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; }); + + if (funcOp == funcOps.end()) { + return StreamStringError() + << "Module does not contain function named '" << name.str() << "'"; + } + + return *funcOp; +} + +// Build a lambda from the function with the name given in +// `funcName` from the sources in `buffer`. +llvm::Expected +JitCompilerEngine::buildLambda(std::unique_ptr buffer, + llvm::StringRef funcName) { + llvm::SourceMgr sm; + + sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); + + llvm::Expected res = + this->buildLambda(sm, funcName); + + return std::move(res); +} + +// Build a lambda from the function with the name given in `funcName` +// from the source string `s`. +llvm::Expected +JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) { + std::unique_ptr mb = llvm::MemoryBuffer::getMemBuffer(s); + llvm::Expected res = + this->buildLambda(std::move(mb), funcName); + + return std::move(res); +} + +// Build a lambda from the function with the name given in +// `funcName` from the sources managed by the source manager `sm`. +llvm::Expected +JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) { + MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); + + this->setGenerateKeySet(true); + this->setGenerateClientParameters(true); + this->setClientParametersFuncName(funcName); + + // First, compile to LLVM Dialect + llvm::Expected compResOrErr = + this->compile(sm, Target::LLVM_IR); + + if (!compResOrErr) + return std::move(compResOrErr.takeError()); + + mlir::ModuleOp module = compResOrErr->mlirModuleRef->get(); + + // Locate function to JIT-compile + llvm::Expected funcOrError = + this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName); + + if (!funcOrError) + return std::move(funcOrError.takeError()); + + // Prepare LLVM infrastructure for JIT compilation + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerLLVMDialectTranslation(mlirContext); + + std::function optPipeline = + mlir::makeOptimizingTransformer(3, 0, nullptr); + + llvm::Expected> lambdaOrErr = + mlir::zamalang::JITLambda::create(funcName, module, optPipeline); + + if (!lambdaOrErr) + return std::move(lambdaOrErr.takeError()); + + return Lambda{this->compilationContext, std::move(lambdaOrErr.get()), + std::move(compResOrErr->keySet)}; +} +} // namespace zamalang +} // namespace mlir diff --git a/compiler/lib/Support/LambdaArgument.cpp b/compiler/lib/Support/LambdaArgument.cpp new file mode 100644 index 000000000..a693c0177 --- /dev/null +++ b/compiler/lib/Support/LambdaArgument.cpp @@ -0,0 +1,7 @@ +#include + +namespace mlir { +namespace zamalang { +char LambdaArgument::ID = 0; +} // namespace zamalang +} // namespace mlir diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index afc0abf6a..1e81f2803 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -22,15 +23,15 @@ #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" -#include "zamalang/Support/Jit.h" +#include "zamalang/Support/Error.h" +#include "zamalang/Support/JitCompilerEngine.h" #include "zamalang/Support/KeySet.h" #include "zamalang/Support/Pipeline.h" #include "zamalang/Support/logging.h" -enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM }; - enum Action { ROUND_TRIP, + DUMP_HLFHE, DUMP_HLFHE_MANP, DUMP_MIDLFHE, DUMP_LOWLFHE, @@ -80,26 +81,6 @@ llvm::cl::opt parametrizeMidLFHE( llvm::cl::desc("Perform MidLFHE global parametrization pass"), llvm::cl::init(true)); -static llvm::cl::opt entryDialect( - "e", "entry-dialect", llvm::cl::desc("Entry dialect"), - llvm::cl::init(EntryDialect::HLFHE), - llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, - llvm::cl::values( - clEnumValN(EntryDialect::HLFHE, "hlfhe", - "Input module is composed of HLFHE operations")), - llvm::cl::values( - clEnumValN(EntryDialect::MIDLFHE, "midlfhe", - "Input module is composed of MidLFHE operations")), - llvm::cl::values( - clEnumValN(EntryDialect::LOWLFHE, "lowlfhe", - "Input module is composed of LowLFHE operations")), - llvm::cl::values( - clEnumValN(EntryDialect::STD, "std", - "Input module is composed of operations from std")), - llvm::cl::values( - clEnumValN(EntryDialect::LLVM, "llvm", - "Input module is composed of operations from llvm"))); - static llvm::cl::opt action( "a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, @@ -109,6 +90,8 @@ static llvm::cl::opt action( llvm::cl::values(clEnumValN(Action::DUMP_HLFHE_MANP, "dump-hlfhe-manp", "Dump HLFHE module after running the Minimal " "Arithmetic Noise Padding pass")), + llvm::cl::values(clEnumValN(Action::DUMP_HLFHE, "dump-hlfhe", + "Dump HLFHE module")), llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", "Lower to MidLFHE and dump result")), llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe", @@ -158,50 +141,7 @@ llvm::cl::opt, false, OptionalSizeTParser> assumeMaxMANP( llvm::cl::desc( "Assume a maximum for the Minimum Arithmetic Noise Padding")); -}; // namespace cmdline - -std::function defaultOptPipeline = - mlir::makeOptimizingTransformer(3, 0, nullptr); - -std::unique_ptr -generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, - const std::string &jitFuncName) { - std::unique_ptr keySet; - - mlir::zamalang::log_verbose() - << "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2 - << ", p:" << fheContext.constraint.p << "}\n"; - mlir::zamalang::log_verbose() - << "### FHE parameters for the atomic pattern: {k: " - << fheContext.parameter.k - << ", polynomialSize: " << fheContext.parameter.polynomialSize - << ", nSmall: " << fheContext.parameter.nSmall - << ", brLevel: " << fheContext.parameter.brLevel - << ", brLogBase: " << fheContext.parameter.brLogBase - << ", ksLevel: " << fheContext.parameter.ksLevel - << ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n"; - - // Create the client parameters - auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheContext, jitFuncName, module); - - if (auto err = clientParameter.takeError()) { - mlir::zamalang::log_error() - << "cannot generate client parameters: " << err << "\n"; - return nullptr; - } - - mlir::zamalang::log_verbose() << "### Generate the key set\n"; - - auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, - 0); // TODO: seed - if (auto err = maybeKeySet.takeError()) { - llvm::errs() << err; - return nullptr; - } - - return std::move(maybeKeySet.get()); -} +} // namespace cmdline llvm::Expected buildFHEContext( llvm::Optional autoFHEConstraints, @@ -209,65 +149,48 @@ llvm::Expected buildFHEContext( llvm::Optional overrideMaxMANP) { if (!autoFHEConstraints.hasValue() && (!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) { - return llvm::make_error( + return mlir::zamalang::StreamStringError( "Maximum encrypted integer precision and maximum for the Minimal" "Arithmetic Noise Passing are required, but were neither specified" - "explicitly nor determined automatically", - llvm::inconvertibleErrorCode()); + "explicitly nor determined automatically"); } mlir::zamalang::V0FHEConstraint fheConstraints{ - .norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() - : autoFHEConstraints.getValue().norm2, - .p = overrideMaxEintPrecision.hasValue() - ? overrideMaxEintPrecision.getValue() - : autoFHEConstraints.getValue().p}; + overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() + : autoFHEConstraints.getValue().norm2, + overrideMaxEintPrecision.hasValue() ? overrideMaxEintPrecision.getValue() + : autoFHEConstraints.getValue().p}; const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); if (!parameter) { - std::string buffer; - llvm::raw_string_ostream strs(buffer); - strs << "Could not determine V0 parameters for 2-norm of " - << fheConstraints.norm2 << " and p of " << fheConstraints.p; - - return llvm::make_error(strs.str(), - llvm::inconvertibleErrorCode()); + return mlir::zamalang::StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << fheConstraints.norm2 << " and p of " << fheConstraints.p; } return mlir::zamalang::V0FHEContext{fheConstraints, *parameter}; } -mlir::LogicalResult buildAssignFHEContext( - llvm::Optional &fheContext, - llvm::Optional autoFHEConstraints, - llvm::Optional overrideMaxEintPrecision, - llvm::Optional overrideMaxMANP) { +namespace llvm { +// This needs to be wrapped into the llvm namespace for proper +// operator lookup +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const llvm::ArrayRef arr) { + os << "("; + for (size_t i = 0; i < arr.size(); i++) { + os << arr[i]; - if (fheContext.hasValue()) - return mlir::success(); - - llvm::Expected fheContextOrErr = - buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision, - overrideMaxMANP); - - if (auto err = fheContextOrErr.takeError()) { - mlir::zamalang::log_error() << err; - return mlir::failure(); + if (i != arr.size() - 1) + os << ", "; } - fheContext.emplace(fheContextOrErr.get()); - - return mlir::success(); + return os; } +} // namespace llvm // Process a single source buffer // -// The parameter `entryDialect` must specify the FHE dialect to which -// belong all FHE operations used in the source buffer. The input -// program must only contain FHE operations from that single FHE -// dialect, otherwise processing might fail. -// // The parameter `action` specifies how the buffer should be processed // and thus defines the output. // @@ -276,15 +199,14 @@ mlir::LogicalResult buildAssignFHEContext( // using the parameters given in `jitArgs`. // // The parameter `parametrizeMidLFHE` defines, whether the -// parametrization pass for MidLFHE is executed. If the pair of -// `entryDialect` and `action` does not involve any MidlFHE -// manipulation, this parameter does not have any effect. +// parametrization pass for MidLFHE is executed. If the `action` does +// not involve any MidlFHE manipulation, this parameter does not have +// any effect. // // The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if // set, override the values for the maximum required precision of // encrypted integers and the maximum value for the Minimum Arithmetic -// Noise Padding otherwise determined automatically if the entry -// dialect is HLFHE.. +// Noise Padding otherwise determined automatically. // // If `verifyDiagnostics` is `true`, the procedure only checks if the // diagnostic messages provided in the source buffer using @@ -292,164 +214,103 @@ mlir::LogicalResult buildAssignFHEContext( // the procedure checks if the parsed module is valid and if all // requested transformations succeeded. // -// If `verbose` is true, debug messages are displayed throughout the -// compilation process. -// // Compilation output is written to the stream specified by `os`. -mlir::LogicalResult processInputBuffer( - mlir::MLIRContext &context, std::unique_ptr buffer, - enum EntryDialect entryDialect, enum Action action, - const std::string &jitFuncName, llvm::ArrayRef jitArgs, - bool parametrizeMidlHFE, llvm::Optional overrideMaxEintPrecision, - llvm::Optional overrideMaxMANP, bool verifyDiagnostics, - bool verbose, llvm::raw_ostream &os) { - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); +mlir::LogicalResult +processInputBuffer(std::unique_ptr buffer, + enum Action action, const std::string &jitFuncName, + llvm::ArrayRef jitArgs, bool parametrizeMidlHFE, + llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP, + bool verifyDiagnostics, llvm::raw_ostream &os) { + std::shared_ptr ccx = + mlir::zamalang::CompilationContext::createShared(); - mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, - &context); - mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); + mlir::zamalang::JitCompilerEngine ce{ccx}; - llvm::Optional fheConstraints; - llvm::Optional fheContext; + ce.setVerifyDiagnostics(verifyDiagnostics); + ce.setParametrizeMidLFHE(parametrizeMidlHFE); - std::unique_ptr keySet = nullptr; + if (overrideMaxEintPrecision.hasValue()) + ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue()); - if (verbose) - context.disableMultithreading(); + if (overrideMaxMANP.hasValue()) + ce.setMaxMANP(overrideMaxMANP.getValue()); - if (verifyDiagnostics) - return sourceMgrHandler.verify(); + if (action == Action::JIT_INVOKE) { + llvm::Expected lambdaOrErr = + ce.buildLambda(std::move(buffer), jitFuncName); - if (!moduleRef) - return mlir::failure(); - - mlir::ModuleOp module = moduleRef.get(); - - if (action == Action::ROUND_TRIP) { - module->print(os); - return mlir::success(); - } - - // Lowering pipeline. Each stage is represented as a label in the - // switch statement, from the most abstract dialect to the lowest - // level. Every labels acts as an entry point into the pipeline with - // a fallthrough mechanism to the next stage. Actions act as exit - // points from the pipeline. - switch (entryDialect) { - case EntryDialect::HLFHE: - if (action == Action::DUMP_HLFHE_MANP) { - if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) - .failed()) { - return mlir::failure(); - } - - module.print(os); - return mlir::success(); - } else { - llvm::Expected> - fheConstraintsOrErr = - mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context, - module); - if (auto err = fheConstraintsOrErr.takeError()) { - mlir::zamalang::log_error() << err; - return mlir::failure(); - } else { - fheConstraints = fheConstraintsOrErr.get(); - } - } - - if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) - .failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::MIDLFHE: - if (action == Action::DUMP_MIDLFHE) { - module.print(os); - return mlir::success(); - } - - if (buildAssignFHEContext(fheContext, fheConstraints, - overrideMaxEintPrecision, overrideMaxMANP) - .failed()) { - return mlir::failure(); - } - - if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( - context, module, fheContext.getValue(), parametrizeMidlHFE) - .failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::LOWLFHE: - if (action == Action::DUMP_LOWLFHE) { - module.print(os); - return mlir::success(); - } - - if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::STD: - if (action == Action::DUMP_STD) { - module.print(os); - return mlir::success(); - } else if (action == Action::JIT_INVOKE) { - if (buildAssignFHEContext(fheContext, fheConstraints, - overrideMaxEintPrecision, overrideMaxMANP) - .failed()) { - return mlir::failure(); - } - - keySet = generateKeySet(module, fheContext.getValue(), jitFuncName); - } - - if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, - verbose) - .failed()) - return mlir::failure(); - - // fallthrough - case EntryDialect::LLVM: { - if (action == Action::DUMP_LLVM_DIALECT) { - module.print(os); - return mlir::success(); - } else if (action == Action::JIT_INVOKE) { - return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet, - defaultOptPipeline, os); - } - - llvm::LLVMContext llvmContext; - std::unique_ptr llvmModule = - mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext, - module); - - if (!llvmModule) { + if (!lambdaOrErr) { mlir::zamalang::log_error() - << "Failed to translate LLVM dialect to LLVM IR\n"; + << "Failed to JIT-compile " << jitFuncName << ": " + << llvm::toString(std::move(lambdaOrErr.takeError())); return mlir::failure(); } - if (action == Action::DUMP_LLVM_IR) { - llvmModule->dump(); - return mlir::success(); - } + llvm::Expected resOrErr = (*lambdaOrErr)(jitArgs); - if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule) - .failed()) { - mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n"; + if (!resOrErr) { + mlir::zamalang::log_error() + << "Failed to JIT-invoke " << jitFuncName << " with arguments " + << jitArgs << ": " << llvm::toString(std::move(resOrErr.takeError())); return mlir::failure(); } - if (action == Action::DUMP_OPTIMIZED_LLVM_IR) { - llvmModule->dump(); - return mlir::success(); + os << *resOrErr << "\n"; + } else { + enum mlir::zamalang::CompilerEngine::Target target; + + switch (action) { + case Action::ROUND_TRIP: + target = mlir::zamalang::CompilerEngine::Target::ROUND_TRIP; + break; + case Action::DUMP_HLFHE: + target = mlir::zamalang::CompilerEngine::Target::HLFHE; + break; + case Action::DUMP_HLFHE_MANP: + target = mlir::zamalang::CompilerEngine::Target::HLFHE_MANP; + break; + case Action::DUMP_MIDLFHE: + target = mlir::zamalang::CompilerEngine::Target::MIDLFHE; + break; + case Action::DUMP_LOWLFHE: + target = mlir::zamalang::CompilerEngine::Target::LOWLFHE; + break; + case Action::DUMP_STD: + target = mlir::zamalang::CompilerEngine::Target::STD; + break; + case Action::DUMP_LLVM_DIALECT: + target = mlir::zamalang::CompilerEngine::Target::LLVM; + break; + case Action::DUMP_LLVM_IR: + target = mlir::zamalang::CompilerEngine::Target::LLVM_IR; + break; + case Action::DUMP_OPTIMIZED_LLVM_IR: + target = mlir::zamalang::CompilerEngine::Target::OPTIMIZED_LLVM_IR; + break; + case JIT_INVOKE: + // Case just here to satisfy the compiler; already handled above + break; } - break; - } + llvm::Expected retOrErr = + ce.compile(std::move(buffer), target); + + if (!retOrErr) { + mlir::zamalang::log_error() + << llvm::toString(std::move(retOrErr.takeError())) << "\n"; + + return mlir::failure(); + } + + if (verifyDiagnostics) { + return mlir::success(); + } else if (action == Action::DUMP_LLVM_IR || + action == Action::DUMP_OPTIMIZED_LLVM_IR) { + retOrErr->llvmModule->print(os, nullptr); + } else { + retOrErr->mlirModuleRef->get().print(os); + } } return mlir::success(); @@ -459,44 +320,11 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { // Parse command line arguments llvm::cl::ParseCommandLineOptions(argc, argv); - // Initialize the MLIR context - mlir::MLIRContext context; - mlir::zamalang::setupLogging(cmdline::verbose); // String for error messages from library functions std::string errorMessage; - if (cmdline::action == Action::DUMP_HLFHE_MANP && - cmdline::entryDialect != EntryDialect::HLFHE) { - mlir::zamalang::log_error() - << "Can only invoke Minimal Arithmetic Noise pass on HLFHE programs"; - return mlir::failure(); - } - - if (cmdline::action == Action::JIT_INVOKE && - cmdline::entryDialect != EntryDialect::HLFHE && - cmdline::entryDialect != EntryDialect::MIDLFHE && - cmdline::entryDialect != EntryDialect::LOWLFHE && - cmdline::entryDialect != EntryDialect::STD) { - mlir::zamalang::log_error() - << "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs"; - return mlir::failure(); - } - - // Load our Dialect in this MLIR Context. - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - context.getOrLoadDialect(); - - if (cmdline::verifyDiagnostics) - context.printOpOnDiagnostic(false); - auto output = mlir::openOutputFile(cmdline::output, &errorMessage); if (!output) { @@ -523,20 +351,20 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { return processInputBuffer( - context, std::move(inputBuffer), cmdline::entryDialect, - cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, + std::move(inputBuffer), cmdline::action, + cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, cmdline::verbose, os); + cmdline::verifyDiagnostics, os); }, output->os()))) return mlir::failure(); } else { return processInputBuffer( - context, std::move(file), cmdline::entryDialect, cmdline::action, - cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, + std::move(file), cmdline::action, cmdline::jitFuncName, + cmdline::jitArgs, cmdline::parametrizeMidLFHE, cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, - cmdline::verifyDiagnostics, cmdline::verbose, output->os()); + cmdline::verifyDiagnostics, output->os()); } } diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir index bf08e2e90..fc460a435 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir index 49b8063a3..224270914 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir index 2b7a9a761..846572c00 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi64>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir index 23a7b3c28..9163d28f0 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir index c98974a2a..6947a3942 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe --assume-max-manp=10 --assume-max-eint-precision=2 2>&1| FileCheck %s // CHECK: #map0 = affine_map<(d0) -> (d0)> // CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir index c83bac0ba..ff156d615 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir index a34343f21..f0da29950 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 23d8585ae..94892649d 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index a40f5df5a..13a8c7214 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 93a60d527..0e6ff2534 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index a3ec7b838..497ce0cd8 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index 42b77bcf4..a0c63723d 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 20ae4d5f0..5f917cd3b 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi64>) -> !MidLFHE.glwe<{1024,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index 201a53669..e8bc3ad06 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index 24fc2ff96..17a86d8d5 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index 1359aaa1e..f352e0f99 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir index 302a75e0e..acbb19672 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --entry-dialect=hlfhe --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s +// RUN: zamacompiler --split-input-file --action=dump-hlfhe-manp %s 2>&1 | FileCheck %s func @single_zero() -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir index 0e53d8d5a..939c52c71 100644 --- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir +++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // Incompatible shapes func @dot_incompatible_shapes( diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir index 6a6d4f962..800e8df69 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<8>) { diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir index bb43f441c..7e543efaa 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<0>) { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir index 1bb62e224..39b97ed31 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir index d43bc7194..5608ffdd0 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir index 205e7afe1..9e91eb794 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir index 0a8ae9a8c..79d014609 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir index 0a8d9cd48..d05921ccf 100644 --- a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument. func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir index 6a9e6e059..45b847b8e 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir index ee84b2a49..50f288ba1 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir index deded0859..5bf4c57af 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1 func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir index 207414189..3aa512584 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index f7653afb7..44fefdf8e 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @zero() -> !HLFHE.eint<2> func @zero() -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir index 1130181ea..573d03ac9 100644 --- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s //CHECK: #map0 = affine_map<(d0) -> (d0)> //CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Dialect/HLFHE/types.mlir b/compiler/tests/Dialect/HLFHE/types.mlir index 8e6b6bc85..2a9ad463c 100644 --- a/compiler/tests/Dialect/HLFHE/types.mlir +++ b/compiler/tests/Dialect/HLFHE/types.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>> func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) { diff --git a/compiler/tests/Dialect/LowLFHE/ops.mlir b/compiler/tests/Dialect/LowLFHE/ops.mlir index b909b0473..fc80bebb5 100644 --- a/compiler/tests/Dialect/LowLFHE/ops.mlir +++ b/compiler/tests/Dialect/LowLFHE/ops.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> { diff --git a/compiler/tests/Dialect/LowLFHE/types.mlir b/compiler/tests/Dialect/LowLFHE/types.mlir index 27552cb2a..07cf87134 100644 --- a/compiler/tests/Dialect/LowLFHE/types.mlir +++ b/compiler/tests/Dialect/LowLFHE/types.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir index 55df41028..46318b97a 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter result func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir index 3d1f81407..970c22942 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir index 97ead991e..cf743fb65 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir index ba6a37313..3b7d1d510 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir index ee8a78f6d..86bb8e44b 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // Bad dimension of the lookup table func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir index bfb720643..e42502600 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi64>) -> !MidLFHE.glwe<{512,10,64}{2}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir index f21873208..4c3a2d238 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir index ae9daa983..3c4a01957 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir index 0903aeb00..debf54b4d 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --action=roundtrip %s // GLWE p parameter func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir index 47fab99a6..47b850daa 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s +// RUN: zamacompiler --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir index b66236c76..974fb3cc6 100644 --- a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s +// RUN: zamacompiler %s --action=roundtrip 2>&1| FileCheck %s // CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/python/test_compiler_engine.py b/compiler/tests/python/test_compiler_engine.py index 6e8bfb33a..54d7cdac3 100644 --- a/compiler/tests/python/test_compiler_engine.py +++ b/compiler/tests/python/test_compiler_engine.py @@ -56,7 +56,7 @@ def test_compile_and_run(mlir_input, args, expected_result): def test_compile_and_run_invalid_arg_number(mlir_input, args): engine = CompilerEngine() engine.compile_fhe(mlir_input) - with pytest.raises(RuntimeError, match=r"failed pushing integer argument"): + with pytest.raises(ValueError, match=r"wrong number of arguments"): engine.run(*args) diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index aa7cbd7fb..91365ec64 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -1,8 +1,11 @@ +#include #include +#include #include "zamalang/Support/CompilerEngine.h" +#include "zamalang/Support/JitCompilerEngine.h" -mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7}; +mlir::zamalang::V0FHEConstraint defaultV0Constraints = {10, 7}; #define ASSERT_LLVM_ERROR(err) \ if (err) { \ @@ -10,384 +13,405 @@ mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7}; ASSERT_TRUE(false); \ } +// Checks that the value `val` is not in an error state. Returns +// `true` if the test passes, otherwise `false`. +template +static bool assert_expected_success(llvm::Expected &val) { + if (!((bool)val)) { + llvm::errs() << llvm::toString(std::move(val.takeError())); + return false; + } + + return true; +} + +// Checks that the value `val` is not in an error state. Returns +// `true` if the test passes, otherwise `false`. +template +static bool assert_expected_success(llvm::Expected &&val) { + return assert_expected_success(val); +} + +// Checks that the value `val` of type `llvm::Expected` is not in +// an error state. +#define ASSERT_EXPECTED_SUCCESS(val) \ + do { \ + if (!assert_expected_success(val)) \ + GTEST_FATAL_FAILURE_("Expected contained in error state"); \ + } while (0) + +// Checks that the value `val` is not in an error state and is equal +// to the value given in `exp`. Returns `true` if the test passes, +// otherwise `false`. +template +static bool assert_expected_value(llvm::Expected &val, const V &exp) { + if (!assert_expected_success(val)) + return false; + + if (!(val.get() == static_cast(exp))) { + llvm::errs() << "Expected value " << exp << ", but got " << val.get() + << "\n"; + return false; + } + + return true; +} + +// Checks that the value `val` is not in an error state and is equal +// to the value given in `exp`. Returns `true` if the test passes, +// otherwise `false`. +template +static bool assert_expected_value(llvm::Expected &&val, const V &exp) { + return assert_expected_value(val, exp); +} + +// Checks that the value `val` of type `llvm::Expected` is not in +// an error state and is equal to the value of type `T` given in +// `exp`. +#define ASSERT_EXPECTED_VALUE(val, exp) \ + do { \ + if (!assert_expected_value(val, exp)) { \ + GTEST_FATAL_FAILURE_("Expected with wrong value"); \ + } \ + } while (0) + +// Jit-compiles the function specified by `func` from `src` and +// returns the corresponding lambda. Any compilation errors are caught +// and reult in abnormal termination. +template +mlir::zamalang::JitCompilerEngine::Lambda +internalCheckedJit(F checkfunc, llvm::StringRef src, + llvm::StringRef func = "main", + bool useDefaultFHEConstraints = false) { + mlir::zamalang::JitCompilerEngine engine; + + if (useDefaultFHEConstraints) + engine.setFHEConstraints(defaultV0Constraints); + + llvm::Expected lambdaOrErr = + engine.buildLambda(src, func); + + checkfunc(lambdaOrErr); + + return std::move(*lambdaOrErr); +} + +// Shorthands to create integer literals of a specific type +uint8_t operator"" _u8(unsigned long long int v) { return v; } +uint16_t operator"" _u16(unsigned long long int v) { return v; } +uint32_t operator"" _u32(unsigned long long int v) { return v; } +uint64_t operator"" _u64(unsigned long long int v) { return v; } + +// Evaluates to the number of elements of a statically initialized +// array +#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(arr[0])) + +// Wrapper around `internalCheckedJit` that causes +// `ASSERT_EXPECTED_SUCCESS` to use the file and line number of the +// caller instead of `internalCheckedJit`. +#define checkedJit(...) \ + internalCheckedJit( \ + [](llvm::Expected &lambda) { \ + ASSERT_EXPECTED_SUCCESS(lambda); \ + }, \ + __VA_ARGS__) + TEST(CompileAndRunHLFHE, add_eint) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> } -)XXX"; - ASSERT_FALSE(engine.compile(mlirStr)); - auto maybeResult = engine.run({1, 2}); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - ASSERT_EQ(result, 3); +)XXX"); + + ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), 3); + ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), 9); + ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), 2); +} + +// Same as CompileAndRunHLFHE::add_eint above, but using +// `LambdaArgument` instances +TEST(CompileAndRunHLFHE, add_eint_lambda_argument) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { + %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) + return %1: !HLFHE.eint<7> +} +)XXX"); + + mlir::zamalang::IntLambdaArgument<> ila1(1); + mlir::zamalang::IntLambdaArgument<> ila2(2); + mlir::zamalang::IntLambdaArgument<> ila7(7); + mlir::zamalang::IntLambdaArgument<> ila9(9); + + ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila2}), 3); + ASSERT_EXPECTED_VALUE(lambda({&ila7, &ila9}), 16); + ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila7}), 8); + ASSERT_EXPECTED_VALUE(lambda({&ila1, &ila9}), 10); + ASSERT_EXPECTED_VALUE(lambda({&ila2, &ila7}), 9); +} + +TEST(CompileAndRunHLFHE, add_u64) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%arg0: i64, %arg1: i64) -> i64 { + %1 = addi %arg0, %arg1 : i64 + return %1: i64 +} +)XXX", + "main", true); + + ASSERT_EXPECTED_VALUE(lambda(1_u64, 2_u64), (uint64_t)3); + ASSERT_EXPECTED_VALUE(lambda(4_u64, 5_u64), (uint64_t)9); + ASSERT_EXPECTED_VALUE(lambda(1_u64, 1_u64), (uint64_t)2); } TEST(CompileAndRunTensorStd, extract_64) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi64>, %i: index) -> i64{ %c = tensor.extract %t[%i] : tensor<10xi64> return %c : i64 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF, - 0, - 8978, - 2587490, - 90, - 197864, - 698735, - 72132, - 87474, - 42}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint64_t t_arg[] = {0xFFFFFFFFFFFFFFFF, + 0, + 8978, + 2587490, + 90, + 197864, + 698735, + 72132, + 87474, + 42}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_32) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi32>, %i: index) -> i32{ %c = tensor.extract %t[%i] : tensor<10xi32> return %c : i32 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90, - 197864, 698735, 72132, 87474, 42}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); +)XXX", + "main", "true"); + static uint32_t t_arg[] = {0xFFFFFFFF, 0, 8978, 2587490, 90, + 197864, 698735, 72132, 87474, 42}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); +} + +// Same as `CompileAndRunTensorStd::extract_32` above, but using +// `LambdaArgument` instances +TEST(CompileAndRunTensorStd, extract_32_lambda_argument) { + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<10xi32>, %i: index) -> i32{ + %c = tensor.extract %t[%i] : tensor<10xi32> + return %c : i32 +} +)XXX", + "main", "true"); + static std::vector t_arg{0xFFFFFFFF, 0, 8978, 2587490, 90, + 197864, 698735, 72132, 87474, 42}; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + tla(t_arg); + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) { + mlir::zamalang::IntLambdaArgument idx(i); + ASSERT_EXPECTED_VALUE(lambda({&tla, &idx}), t_arg[i]); } } TEST(CompileAndRunTensorStd, extract_16) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi16>, %i: index) -> i16{ %c = tensor.extract %t[%i] : tensor<10xi16> return %c : i16 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227, - 63269, 36435, 52380, 7401, 13313}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + uint16_t t_arg[] = {0xFFFF, 0, 59589, 47826, 16227, + 63269, 36435, 52380, 7401, 13313}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_8) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi8>, %i: index) -> i8{ %c = tensor.extract %t[%i] : tensor<10xi8> return %c : i8 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint8_t t_arg[] = {0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi5>, %i: index) -> i5{ %c = tensor.extract %t[%i] : tensor<10xi5> return %c : i5 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorStd, extract_1) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10xi1>, %i: index) -> i1{ %c = tensor.extract %t[%i] : tensor<10xi1> return %c : i1 } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX", + "main", "true"); + + static uint8_t t_arg[] = {0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorEncrypted, extract_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{ %c = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> return %c : !HLFHE.eint<5> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - for (size_t i = 0; i < size; i++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i]); - } +)XXX"); + + static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i), t_arg[i]); } TEST(CompileAndRunTensorEncrypted, extract_twice_and_add_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( -func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5>{ + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> +!HLFHE.eint<5>{ %ti = tensor.extract %t[%i] : tensor<10x!HLFHE.eint<5>> %tj = tensor.extract %t[%j] : tensor<10x!HLFHE.eint<5>> - %c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> !HLFHE.eint<5> - return %c : !HLFHE.eint<5> + %c = "HLFHE.add_eint"(%ti, %tj) : (!HLFHE.eint<5>, !HLFHE.eint<5>) -> + !HLFHE.eint<5> return %c : !HLFHE.eint<5> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - for (size_t i = 0; i < size; i++) { - for (size_t j = 0; j < size; j++) { - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Set the %i argument - ASSERT_LLVM_ERROR(argument->setArg(1, i)); - // Set the %j argument - ASSERT_LLVM_ERROR(argument->setArg(2, j)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, t_arg[i] + t_arg[j]); - } - } +)XXX"); + + static uint8_t t_arg[] = {3, 0, 7, 12, 14, 6, 5, 4, 1, 2}; + + for (size_t i = 0; i < ARRAY_SIZE(t_arg); i++) + for (size_t j = 0; j < ARRAY_SIZE(t_arg); j++) + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg), i, j), + t_arg[i] + t_arg[j]); } TEST(CompileAndRunTensorEncrypted, dim_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{ %c0 = constant 0 : index %c = tensor.dim %t, %c0 : tensor<10x!HLFHE.eint<5>> return %c : index } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - const size_t size = 10; - uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, t_arg, size)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res = 0; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, size); +)XXX"); + + static uint8_t t_arg[] = {32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; + ASSERT_EXPECTED_VALUE(lambda(t_arg, ARRAY_SIZE(t_arg)), ARRAY_SIZE(t_arg)); } TEST(CompileAndRunTensorEncrypted, from_elements_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> { %t = tensor.from_elements %0 : tensor<1x!HLFHE.eint<5>> return %t: tensor<1x!HLFHE.eint<5>> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the %t argument - ASSERT_LLVM_ERROR(argument->setArg(0, 10)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - size_t size_res = 1; - uint64_t t_res[size_res]; - ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res)); - ASSERT_EQ(t_res[0], 10); +)XXX"); + + llvm::Expected> res = + lambda.operator()>(10_u64); + + ASSERT_EXPECTED_SUCCESS(res); + ASSERT_EQ(res->size(), (size_t)1); + ASSERT_EQ(res->at(0), 10_u64); } TEST(CompileAndRunTensorEncrypted, in_out_tensor_with_op_5) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> { %c_0 = constant 0 : index %c_1 = constant 1 : index %a = tensor.extract %in[%c_0] : tensor<2x!HLFHE.eint<5>> %b = tensor.extract %in[%c_1] : tensor<2x!HLFHE.eint<5>> - %aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) - %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) - %bplusb = "HLFHE.add_eint"(%b, %b): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) - %out = tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>> + %aplusa = "HLFHE.add_eint"(%a, %a): (!HLFHE.eint<5>, !HLFHE.eint<5>) -> + (!HLFHE.eint<5>) %aplusb = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<5>, + !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %bplusb = "HLFHE.add_eint"(%b, %b): + (!HLFHE.eint<5>, !HLFHE.eint<5>) -> (!HLFHE.eint<5>) %out = + tensor.from_elements %aplusa, %aplusb, %bplusb : tensor<3x!HLFHE.eint<5>> return %out: tensor<3x!HLFHE.eint<5>> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set the argument - const size_t in_size = 2; - uint8_t in[in_size] = {2, 16}; - ASSERT_LLVM_ERROR(argument->setArg(0, in, in_size)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - const size_t size_res = 3; - uint64_t t_res[size_res]; - ASSERT_LLVM_ERROR(argument->getResult(0, t_res, size_res)); - ASSERT_EQ(t_res[0], in[0] + in[0]); - ASSERT_EQ(t_res[1], in[0] + in[1]); - ASSERT_EQ(t_res[2], in[1] + in[1]); +)XXX"); + + static uint8_t in[] = {2, 16}; + + llvm::Expected> res = + lambda.operator()>(in, ARRAY_SIZE(in)); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (size_t)3); + ASSERT_EQ(res->at(0), (uint64_t)(in[0] + in[0])); + ASSERT_EQ(res->at(1), (uint64_t)(in[0] + in[1])); + ASSERT_EQ(res->at(2), (uint64_t)(in[1] + in[1])); } TEST(CompileAndRunTensorEncrypted, linalg_generic) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> (0)> -func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.eint<7>) -> !HLFHE.eint<7> { +func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: +!HLFHE.eint<7>) -> !HLFHE.eint<7> { %tacc = tensor.from_elements %acc : tensor<1x!HLFHE.eint<7>> - %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) outs(%tacc : tensor<1x!HLFHE.eint<7>>) { - ^bb0(%arg2: !HLFHE.eint<7>, %arg3: i8, %arg4: !HLFHE.eint<7>): // no predecessors - %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> !HLFHE.eint<7> - %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, !HLFHE.eint<7>) -> !HLFHE.eint<7> - linalg.yield %5 : !HLFHE.eint<7> + %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types + = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<7>>, tensor<2xi8>) + outs(%tacc : tensor<1x!HLFHE.eint<7>>) { ^bb0(%arg2: !HLFHE.eint<7>, %arg3: + i8, %arg4: !HLFHE.eint<7>): // no predecessors + %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<7>, i8) -> + !HLFHE.eint<7> %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<7>, + !HLFHE.eint<7>) -> !HLFHE.eint<7> linalg.yield %5 : !HLFHE.eint<7> } -> tensor<1x!HLFHE.eint<7>> %c0 = constant 0 : index %ret = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<7>> return %ret : !HLFHE.eint<7> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set arg0, arg1, acc - const size_t in_size = 2; - uint8_t arg0[in_size] = {2, 8}; - ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size)); - uint8_t arg1[in_size] = {6, 8}; - ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size)); - ASSERT_LLVM_ERROR(argument->setArg(2, 0)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, 76); +)XXX", + "main", "true"); + + static uint8_t arg0[] = {2, 8}; + static uint8_t arg1[] = {6, 8}; + + llvm::Expected res = + lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1), 0_u64); + + ASSERT_EXPECTED_VALUE(res, 76); } TEST(CompileAndRunTensorEncrypted, dot_eint_int_7) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7> { @@ -395,77 +419,70 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>, (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7> return %ret : !HLFHE.eint<7> } -)XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); - auto maybeArgument = engine.buildArgument(); - ASSERT_LLVM_ERROR(maybeArgument.takeError()); - auto argument = std::move(maybeArgument.get()); - // Set arg0, arg1, acc - const size_t in_size = 4; - uint8_t arg0[in_size] = {0, 1, 2, 3}; - ASSERT_LLVM_ERROR(argument->setArg(0, arg0, in_size)); - uint8_t arg1[in_size] = {0, 1, 2, 3}; - ASSERT_LLVM_ERROR(argument->setArg(1, arg1, in_size)); - // Invoke the function - ASSERT_LLVM_ERROR(engine.invoke(*argument)); - // Get and assert the result - uint64_t res; - ASSERT_LLVM_ERROR(argument->getResult(0, res)); - ASSERT_EQ(res, 14); +)XXX"); + static uint8_t arg0[] = {0, 1, 2, 3}; + static uint8_t arg1[] = {0, 1, 2, 3}; + + llvm::Expected res = + lambda(arg0, ARRAY_SIZE(arg0), arg1, ARRAY_SIZE(arg1)); + + ASSERT_EXPECTED_VALUE(res, 14); } -class CompileAndRunWithPrecision : public ::testing::TestWithParam { -protected: - mlir::zamalang::CompilerEngine engine; - void compile(std::string mlirStr) { ASSERT_FALSE(engine.compile(mlirStr)); } - void run(std::vector args, uint64_t expected) { - auto maybeResult = engine.run(args); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - if (result == expected) { - ASSERT_TRUE(true); - } else { - // TODO: Better way to test the probability of exactness - llvm::errs() << "one fail retry\n"; - maybeResult = engine.run(args); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, expected); - } - } -}; +class CompileAndRunWithPrecision : public ::testing::TestWithParam {}; TEST_P(CompileAndRunWithPrecision, identity_func) { - int precision = GetParam(); + uint64_t precision = GetParam(); std::ostringstream mlirProgram; - auto sizeOfTLU = 1 << precision; - mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision - << ">) -> !HLFHE.eint<" << precision << "> { \n"; - mlirProgram << " %tlu = std.constant dense<[0"; - for (auto i = 1; i < sizeOfTLU; i++) { - mlirProgram << ", " << i; - } - mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n"; - mlirProgram << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): " - "(!HLFHE.eint<" - << precision << ">, tensor<" << sizeOfTLU - << "xi64>) -> (!HLFHE.eint<" << precision << ">)\n "; - mlirProgram << "return %1: !HLFHE.eint<" << precision << ">\n"; + uint64_t sizeOfTLU = 1 << precision; - mlirProgram << "}\n"; - llvm::errs() << mlirProgram.str(); - compile(mlirProgram.str()); - for (auto i = 0; i < sizeOfTLU; i++) { - run({(uint64_t)i}, i); + mlirProgram << "func @main(%arg0: !HLFHE.eint<" << precision + << ">) -> !HLFHE.eint<" << precision << "> { \n" + << " %tlu = std.constant dense<[0"; + + for (uint64_t i = 1; i < sizeOfTLU; i++) + mlirProgram << ", " << i; + + mlirProgram << "]> : tensor<" << sizeOfTLU << "xi64>\n" + << " %1 = \"HLFHE.apply_lookup_table\"(%arg0, %tlu): " + << "(!HLFHE.eint<" << precision << ">, tensor<" << sizeOfTLU + << "xi64>) -> (!HLFHE.eint<" << precision << ">)\n " + << "return %1: !HLFHE.eint<" << precision << ">\n" + << "}\n"; + + mlir::zamalang::JitCompilerEngine::Lambda lambda = + checkedJit(mlirProgram.str()); + + if (precision == 7) { + // Test fails with a probability of 5% for a precision of 7. The + // probability of the test failing 5 times in a row is .05^5, + // which is less than 1:10,000 and comparable to the probability + // of failure for the other values. + static const int max_tries = 3; + + for (uint64_t i = 0; i < sizeOfTLU; i++) { + for (int retry = 0; retry <= max_tries; retry++) { + if (retry == max_tries) + GTEST_FATAL_FAILURE_("Maximum number of tries exceeded"); + + llvm::Expected val = lambda(i); + ASSERT_EXPECTED_SUCCESS(val); + + if (*val == i) + break; + } + } + } else { + for (uint64_t i = 0; i < sizeOfTLU; i++) + ASSERT_EXPECTED_VALUE(lambda(i), i); } } -INSTANTIATE_TEST_CASE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision, - ::testing::Values(1, 2, 3, 4, 5, 6, 7)); +INSTANTIATE_TEST_SUITE_P(TestHLFHEApplyLookupTable, CompileAndRunWithPrecision, + ::testing::Values(1, 2, 3, 4, 5, 6, 7)); TEST(TestHLFHEApplyLookupTable, multiple_precision) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> { %tlu_7 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]> : tensor<64xi64> %tlu_3 = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : tensor<8xi64> @@ -474,45 +491,22 @@ func @main(%arg0: !HLFHE.eint<6>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<6> { %a_plus_b = "HLFHE.add_eint"(%a, %b): (!HLFHE.eint<6>, !HLFHE.eint<6>) -> (!HLFHE.eint<6>) return %a_plus_b: !HLFHE.eint<6> } -)XXX"; - ASSERT_FALSE(engine.compile(mlirStr)); - uint64_t arg0 = 23; - uint64_t arg1 = 7; - uint64_t expected = 30; - auto maybeResult = engine.run({arg0, arg1}); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - ASSERT_EQ(result, expected); +)XXX"); + + ASSERT_EXPECTED_VALUE(lambda(23_u64, 7_u64), 30); } TEST(CompileAndRunTLU, random_func) { - mlir::zamalang::CompilerEngine engine; - auto mlirStr = R"XXX( + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%arg0: !HLFHE.eint<6>) -> !HLFHE.eint<6> { %tlu = std.constant dense<[16, 91, 16, 83, 80, 74, 21, 96, 1, 63, 49, 122, 76, 89, 74, 55, 109, 110, 103, 54, 105, 14, 66, 47, 52, 89, 7, 10, 73, 44, 119, 92, 25, 104, 123, 100, 108, 86, 29, 121, 118, 52, 107, 48, 34, 37, 13, 122, 107, 48, 74, 59, 96, 36, 50, 55, 120, 72, 27, 45, 12, 5, 96, 12]> : tensor<64xi64> %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<6>, tensor<64xi64>) -> (!HLFHE.eint<6>) return %1: !HLFHE.eint<6> } -)XXX"; - ASSERT_FALSE(engine.compile(mlirStr)); - // first value - auto maybeResult = engine.run({5}); - ASSERT_TRUE((bool)maybeResult); - uint64_t result = maybeResult.get(); - ASSERT_EQ(result, 74); - // second value - maybeResult = engine.run({62}); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, 96); - // edge value low - maybeResult = engine.run({0}); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, 16); - // edge value high - maybeResult = engine.run({63}); - ASSERT_TRUE((bool)maybeResult); - result = maybeResult.get(); - ASSERT_EQ(result, 12); +)XXX"); + + ASSERT_EXPECTED_VALUE(lambda(5_u64), 74); + ASSERT_EXPECTED_VALUE(lambda(62_u64), 96); + ASSERT_EXPECTED_VALUE(lambda(0_u64), 16); + ASSERT_EXPECTED_VALUE(lambda(63_u64), 12); }