From 30374ebb2cd8b2005796c8192b508750ffbfe786 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 17 Sep 2021 10:45:53 +0200 Subject: [PATCH] refactor(compiler): Introduce compilation pipeline with multiple entries / exits This refactoring commit restructures the compilation pipeline of `zamacompiler`, such that it is possible to enter and exit the pipeline at different points, effectively defining the level of abstraction at the input and the required level of abstraction for the output. The entry point is specified using the `--entry-dialect` argument. Valid choices are: `--entry-dialect=hlfhe`: Source contains HLFHE operations `--entry-dialect=midlfhe`: Source contains MidLFHE operations `--entry-dialect=lowlfhe`: Source contains LowLFHE operations `--entry-dialect=std`: Source does not contain any FHE Operations `--entry-dialect=llvm`: Source is in LLVM dialect The exit point is defined by an action, specified using --action. `--action=roundtrip`: Parse the source file to in-memory representation and immediately dump as text without any processing `--action=dump-midlfhe`: Lower source to MidLFHE and dump result as text `--action=dump-lowlfhe`: Lower source to LowLFHE and dump result as text `--action=dump-std`: Lower source to only standard MLIR dialects (i.e., all FHE operations have already been lowered) `--action=dump-llvm-dialect`: Lower source to MLIR's LLVM dialect (i.e., the LLVM dialect, not LLVM IR) `--action=dump-llvm-ir`: Lower source to plain LLVM IR (i.e., not the LLVM dialect, but actual LLVM IR) `--action=dump-optimized-llvm-ir`: Lower source to plain LLVM IR (i.e., not the LLVM dialect, but actual LLVM IR), pass the result through the LLVM optimizer and print the result. `--action=dump-jit-invoke`: Execute the full lowering pipeline to optimized LLVM IR, JIT compile the result, invoke the function specified in `--jit-funcname` with the parameters from `--jit-args` and print the functions return value. --- .../Conversion/Utils/GlobalFHEContext.h | 9 +- .../include/zamalang/Support/CompilerEngine.h | 14 +- .../include/zamalang/Support/CompilerTools.h | 138 ------ compiler/include/zamalang/Support/Jit.h | 95 +++- compiler/include/zamalang/Support/Pipeline.h | 42 ++ compiler/lib/Support/CMakeLists.txt | 3 +- compiler/lib/Support/CompilerEngine.cpp | 33 +- compiler/lib/Support/CompilerTools.cpp | 467 ------------------ compiler/lib/Support/Jit.cpp | 328 ++++++++++++ compiler/lib/Support/Pipeline.cpp | 148 ++++++ compiler/python/CompilerAPIModule.cpp | 1 - compiler/src/main.cpp | 388 ++++++++++----- .../Conversion/HLFHEToMidLFHE/add_eint.mlir | 4 +- .../HLFHEToMidLFHE/add_eint_int.mlir | 2 +- .../HLFHEToMidLFHE/apply_univariate.mlir | 4 +- .../HLFHEToMidLFHE/apply_univariate_cst.mlir | 4 +- .../HLFHEToMidLFHE/linalg_generic.mlir | 29 +- .../HLFHEToMidLFHE/mul_eint_int.mlir | 4 +- .../HLFHEToMidLFHE/sub_int_eint.mlir | 2 +- .../LowLFHEToConcreteCAPI/bootstrap.mlir | 4 +- .../glwe_from_table.mlir | 4 +- .../LowLFHEToConcreteCAPI/keyswitch_lwe.mlir | 4 +- .../Conversion/LowLFHEUnparametrize/func.mlir | 7 - .../unrealized_conversion_cast.mlir | 8 - .../Conversion/MidLFHEToLowLFHE/add_glwe.mlir | 2 +- .../MidLFHEToLowLFHE/add_glwe_int.mlir | 4 +- .../MidLFHEToLowLFHE/apply_lookup_table.mlir | 4 +- .../apply_lookup_table_cst.mlir | 4 +- .../MidLFHEToLowLFHE/mul_glwe_int.mlir | 4 +- .../MidLFHEToLowLFHE/sub_int_glwe.mlir | 4 +- compiler/tests/Dialect/HLFHE/dot.invalid.mlir | 4 +- .../Dialect/HLFHE/eint_error_p_too_big.mlir | 2 +- .../Dialect/HLFHE/eint_error_p_too_small.mlir | 2 +- .../Dialect/HLFHE/op_add_eint_err_inputs.mlir | 4 +- .../Dialect/HLFHE/op_add_eint_err_result.mlir | 4 +- .../HLFHE/op_add_eint_int_err_inputs.mlir | 4 +- .../HLFHE/op_add_eint_int_err_result.mlir | 4 +- .../op_apply_lookup_table_bad_dimension.mlir | 4 +- .../HLFHE/op_mul_eint_int_err_inputs.mlir | 4 +- .../HLFHE/op_mul_eint_int_err_result.mlir | 4 +- .../HLFHE/op_sub_int_eint_err_inputs.mlir | 4 +- .../HLFHE/op_sub_int_eint_err_result.mlir | 4 +- compiler/tests/Dialect/HLFHE/ops.mlir | 2 +- .../Dialect/HLFHE/tensor-ops-to-linalg.mlir | 26 +- compiler/tests/Dialect/HLFHE/types.mlir | 2 +- compiler/tests/Dialect/LowLFHE/ops.mlir | 3 +- compiler/tests/Dialect/LowLFHE/types.mlir | 3 +- .../Dialect/MidLFHE/op_add_glwe.invalid.mlir | 2 +- .../tests/Dialect/MidLFHE/op_add_glwe.mlir | 2 +- .../MidLFHE/op_add_glwe_int.invalid.mlir | 2 +- .../Dialect/MidLFHE/op_add_glwe_int.mlir | 4 +- .../op_apply_lookup_table.invalid.mlir | 2 +- .../MidLFHE/op_apply_lookup_table.mlir | 4 +- .../MidLFHE/op_mul_glwe_int.invalid.mlir | 2 +- .../Dialect/MidLFHE/op_mul_glwe_int.mlir | 4 +- .../MidLFHE/op_sub_int_glwe.invalid.mlir | 2 +- .../Dialect/MidLFHE/op_sub_int_glwe.mlir | 4 +- .../tests/Dialect/MidLFHE/types_glwe.mlir | 4 +- 58 files changed, 1014 insertions(+), 862 deletions(-) delete mode 100644 compiler/include/zamalang/Support/CompilerTools.h create mode 100644 compiler/include/zamalang/Support/Pipeline.h delete mode 100644 compiler/lib/Support/CompilerTools.cpp create mode 100644 compiler/lib/Support/Pipeline.cpp delete mode 100644 compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir delete mode 100644 compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir diff --git a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h index 603fdd2b8..f9c2d4c99 100644 --- a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h +++ b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h @@ -19,7 +19,7 @@ struct V0Parameter { size_t ksLevel; size_t ksLogBase; - V0Parameter() {} + V0Parameter() = delete; V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel, size_t brLogBase, size_t ksLevel, size_t ksLogBase) @@ -31,11 +31,14 @@ struct V0Parameter { }; struct V0FHEContext { + V0FHEContext() = delete; + V0FHEContext(const V0FHEConstraint &constraint, const V0Parameter ¶meter) + : constraint(constraint), parameter(parameter) {} + V0FHEConstraint constraint; V0Parameter parameter; }; - } // namespace zamalang } // namespace mlir -#endif \ No newline at end of file +#endif diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index b380a850b..047a33dc8 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -1,17 +1,7 @@ #ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H #define ZAMALANG_SUPPORT_COMPILER_ENGINE_H -#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" -#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" -#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" -#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" -#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" -#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" -#include "zamalang/Support/CompilerTools.h" -#include -#include -#include -#include +#include "Jit.h" namespace mlir { namespace zamalang { @@ -55,4 +45,4 @@ private: } // namespace zamalang } // namespace mlir -#endif \ No newline at end of file +#endif diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h deleted file mode 100644 index 1017e0906..000000000 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ /dev/null @@ -1,138 +0,0 @@ -#ifndef ZAMALANG_SUPPORT_COMPILERTOOLS_H_ -#define ZAMALANG_SUPPORT_COMPILERTOOLS_H_ - -#include -#include -#include - -#include "zamalang/Support/ClientParameters.h" -#include "zamalang/Support/KeySet.h" -#include "zamalang/Support/V0Parameters.h" - -namespace mlir { -namespace zamalang { - -class CompilerTools { -public: - struct LowerOptions { - llvm::function_ref enablePass; - bool verbose; - - LowerOptions() - : verbose(false), enablePass([](std::string pass) { return true; }){}; - }; - - /// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir - /// lowerable to llvm dialect. - /// The given module MLIR operation would be modified and the constraints set. - static mlir::LogicalResult - lowerHLFHEToMlirStdsDialect(mlir::MLIRContext &context, - mlir::Operation *module, V0FHEContext &fheContext, - LowerOptions options = LowerOptions()); - - /// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR - /// dialects to MLIR LLVM dialect. The given module MLIR operation would be - /// modified. - static mlir::LogicalResult - lowerMlirStdsDialectToMlirLLVMDialect(mlir::MLIRContext &context, - mlir::Operation *module, - LowerOptions options = LowerOptions()); - - static llvm::Expected> - toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, - llvm::function_ref optPipeline); -}; - -/// JITLambda is a tool to JIT compile an mlir module and to invoke a function -/// of the module. -class JITLambda { -public: - class Argument { - public: - Argument(KeySet &keySet); - ~Argument(); - - // Create lambda Argument that use the given KeySet to perform encryption - // and decryption operations. - static llvm::Expected> create(KeySet &keySet); - - // Set a scalar argument at the given pos as a uint64_t. - llvm::Error setArg(size_t pos, uint64_t arg); - - // Set a argument at the given pos as a tensor of int64. - llvm::Error setArg(size_t pos, uint64_t *data, size_t size) { - return setArg(pos, 64, (void *)data, size); - } - - // Set a argument at the given pos as a tensor of int32. - llvm::Error setArg(size_t pos, uint32_t *data, size_t size) { - return setArg(pos, 32, (void *)data, size); - } - - // Set a argument at the given pos as a tensor of int32. - llvm::Error setArg(size_t pos, uint16_t *data, size_t size) { - return setArg(pos, 16, (void *)data, size); - } - - // Set a tensor argument at the given pos as a uint64_t. - llvm::Error setArg(size_t pos, uint8_t *data, size_t size) { - return setArg(pos, 8, (void *)data, size); - } - - // Get the result at the given pos as an uint64_t. - llvm::Error getResult(size_t pos, uint64_t &res); - - // Fill the result. - llvm::Error getResult(size_t pos, uint64_t *res, size_t size); - - private: - llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); - - friend JITLambda; - // Store the pointer on inputs values and outputs values - std::vector rawArg; - // Store the values of inputs - std::vector inputs; - // Store the values of outputs - std::vector outputs; - // Store the input gates description and the offset of the argument. - std::vector> inputGates; - // Store the outputs gates description and the offset of the argument. - std::vector> outputGates; - // Store allocated lwe ciphertexts (for free) - std::vector allocatedCiphertexts; - // Store buffers of ciphertexts - std::vector ciphertextBuffers; - - KeySet &keySet; - }; - JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) - : type(type), name(name){}; - - /// create a JITLambda that point to the function name of the given module. - static llvm::Expected> - create(llvm::StringRef name, mlir::ModuleOp &module, - llvm::function_ref optPipeline); - - /// invokeRaw execute the jit lambda with a list of Argument, the last one is - /// used to store the result of the computation. - /// Example: - /// uin64_t arg0 = 1; - /// uin64_t res; - /// llvm::SmallVector args{&arg1, &res}; - /// lambda.invokeRaw(args); - llvm::Error invokeRaw(llvm::MutableArrayRef args); - - /// invoke the jit lambda with the Argument. - llvm::Error invoke(Argument &args); - -private: - mlir::LLVM::LLVMFunctionType type; - llvm::StringRef name; - std::unique_ptr engine; -}; - -} // namespace zamalang -} // namespace mlir - -#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index c72d67f82..c7358ae77 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -1,11 +1,12 @@ #ifndef COMPILER_JIT_H #define COMPILER_JIT_H -#include "zamalang/Support/CompilerTools.h" - +#include #include #include +#include + namespace mlir { namespace zamalang { mlir::LogicalResult @@ -13,6 +14,96 @@ runJit(mlir::ModuleOp module, llvm::StringRef func, llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, std::function optPipeline, llvm::raw_ostream &os); + +/// JITLambda is a tool to JIT compile an mlir module and to invoke a function +/// of the module. +class JITLambda { +public: + class Argument { + public: + Argument(KeySet &keySet); + ~Argument(); + + // Create lambda Argument that use the given KeySet to perform encryption + // and decryption operations. + static llvm::Expected> create(KeySet &keySet); + + // Set a scalar argument at the given pos as a uint64_t. + llvm::Error setArg(size_t pos, uint64_t arg); + + // Set a argument at the given pos as a tensor of int64. + llvm::Error setArg(size_t pos, uint64_t *data, size_t size) { + return setArg(pos, 64, (void *)data, size); + } + + // Set a argument at the given pos as a tensor of int32. + llvm::Error setArg(size_t pos, uint32_t *data, size_t size) { + return setArg(pos, 32, (void *)data, size); + } + + // Set a argument at the given pos as a tensor of int32. + llvm::Error setArg(size_t pos, uint16_t *data, size_t size) { + return setArg(pos, 16, (void *)data, size); + } + + // Set a tensor argument at the given pos as a uint64_t. + llvm::Error setArg(size_t pos, uint8_t *data, size_t size) { + return setArg(pos, 8, (void *)data, size); + } + + // Get the result at the given pos as an uint64_t. + llvm::Error getResult(size_t pos, uint64_t &res); + + // Fill the result. + llvm::Error getResult(size_t pos, uint64_t *res, size_t size); + + private: + llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); + + friend JITLambda; + // Store the pointer on inputs values and outputs values + std::vector rawArg; + // Store the values of inputs + std::vector inputs; + // Store the values of outputs + std::vector outputs; + // Store the input gates description and the offset of the argument. + std::vector> inputGates; + // Store the outputs gates description and the offset of the argument. + std::vector> outputGates; + // Store allocated lwe ciphertexts (for free) + std::vector allocatedCiphertexts; + // Store buffers of ciphertexts + std::vector ciphertextBuffers; + + KeySet &keySet; + }; + JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) + : type(type), name(name){}; + + /// create a JITLambda that point to the function name of the given module. + static llvm::Expected> + create(llvm::StringRef name, mlir::ModuleOp &module, + llvm::function_ref optPipeline); + + /// invokeRaw execute the jit lambda with a list of Argument, the last one is + /// used to store the result of the computation. + /// Example: + /// uin64_t arg0 = 1; + /// uin64_t res; + /// llvm::SmallVector args{&arg1, &res}; + /// lambda.invokeRaw(args); + llvm::Error invokeRaw(llvm::MutableArrayRef args); + + /// invoke the jit lambda with the Argument. + llvm::Error invoke(Argument &args); + +private: + mlir::LLVM::LLVMFunctionType type; + llvm::StringRef name; + std::unique_ptr engine; +}; + } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h new file mode 100644 index 000000000..d8d82b82f --- /dev/null +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -0,0 +1,42 @@ +#ifndef ZAMALANG_SUPPORT_PIPELINE_H_ +#define ZAMALANG_SUPPORT_PIPELINE_H_ + +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { +namespace pipeline { + +mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, bool verbose); + +mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, + bool parametrize); + +mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module); + +mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, + mlir::ModuleOp &module, bool verbose); + +mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, + llvm::Module &module); + +mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, bool verbose); + +std::unique_ptr +lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context, + llvm::LLVMContext &llvmContext, + mlir::ModuleOp &module); +} // namespace pipeline +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 61280fc1f..d5c319893 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(ZamalangSupport - CompilerTools.cpp + Pipeline.cpp + Jit.cpp CompilerEngine.cpp V0Parameters.cpp V0Curves.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 8f37ff94c..5b99e8dd4 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -1,8 +1,16 @@ -#include "zamalang/Support/CompilerEngine.h" -#include "zamalang/Conversion/Passes.h" +#include +#include +#include +#include #include #include +#include +#include +#include +#include +#include + namespace mlir { namespace zamalang { @@ -29,10 +37,20 @@ llvm::Error CompilerEngine::compile(std::string mlirStr) { return llvm::make_error("mlir parsing failed", llvm::inconvertibleErrorCode()); } - mlir::zamalang::V0FHEContext fheContext; + + mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, + .p = 7}; + const mlir::zamalang::V0Parameter *parameter = + getV0Parameter(defaultGlobalFHECircuitConstraint); + + mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, + *parameter}; + + mlir::ModuleOp module = module_ref.get(); + // Lower to MLIR Std - if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - *context, module_ref.get(), fheContext) + if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext, + false) .failed()) { return llvm::make_error("failed to lower to MLIR Std", llvm::inconvertibleErrorCode()); @@ -53,8 +71,7 @@ llvm::Error CompilerEngine::compile(std::string mlirStr) { keySet = std::move(maybeKeySet.get()); // Lower to MLIR LLVM Dialect - if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - *context, module_ref.get()) + if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false) .failed()) { return llvm::make_error( "failed to lower to LLVM dialect", llvm::inconvertibleErrorCode()); @@ -114,4 +131,4 @@ llvm::Expected CompilerEngine::run(std::vector args) { return res; } } // namespace zamalang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp deleted file mode 100644 index 030812cca..000000000 --- a/compiler/lib/Support/CompilerTools.cpp +++ /dev/null @@ -1,467 +0,0 @@ -#include "mlir/Dialect/Tensor/Transforms/Passes.h" -#include -#include -#include -#include -#include -#include - -#include "zamalang/Conversion/Passes.h" -#include "zamalang/Support/CompilerTools.h" - -namespace mlir { -namespace zamalang { - -// This is temporary while we doesn't yet have the high-level verification pass -V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7}; - -void initLLVMNativeTarget() { - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); -} - -void addFilteredPassToPassManager( - mlir::PassManager &pm, std::unique_ptr pass, - llvm::function_ref enablePass) { - if (!enablePass(pass->getArgument().str())) { - return; - } - if (*pass->getOpName() == "module") { - pm.addPass(std::move(pass)); - } else { - pm.nest(*pass->getOpName()).addPass(std::move(pass)); - } -}; - -mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( - mlir::MLIRContext &context, mlir::Operation *module, - V0FHEContext &fheContext, LowerOptions options) { - mlir::PassManager pm(&context); - if (options.verbose) { - llvm::errs() << "##################################################\n"; - llvm::errs() << "### HLFHEToMlirStdsDialect pipeline\n"; - context.disableMultithreading(); - pm.enableIRPrinting(); - pm.enableStatistics(); - pm.enableTiming(); - pm.enableVerifier(); - } - - fheContext.constraint = defaultGlobalFHECircuitConstraint; - fheContext.parameter = *getV0Parameter(fheContext.constraint); - // Add all passes to lower from HLFHE to LLVM Dialect - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), - options.enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), - options.enablePass); - addFilteredPassToPassManager( - pm, - mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext), - options.enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), - options.enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), - options.enablePass); - - // Run the passes - if (pm.run(module).failed()) { - return mlir::failure(); - } - - return mlir::success(); -} - -mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - mlir::MLIRContext &context, mlir::Operation *module, LowerOptions options) { - mlir::PassManager pm(&context); - if (options.verbose) { - llvm::errs() << "##################################################\n"; - llvm::errs() << "### MlirStdsDialectToMlirLLVMDialect pipeline\n"; - context.disableMultithreading(); - pm.enableIRPrinting(); - pm.enableStatistics(); - pm.enableTiming(); - pm.enableVerifier(); - } - - // Unparametrize LowLFHE - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(), - options.enablePass); - - // Bufferize - addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createLinalgBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createConvertLinalgToLoopsPass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(), - options.enablePass); - - // Convert to MLIR LLVM Dialect - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), - options.enablePass); - - if (pm.run(module).failed()) { - return mlir::failure(); - } - return mlir::success(); -} - -llvm::Expected> CompilerTools::toLLVMModule( - llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, - llvm::function_ref optPipeline) { - - initLLVMNativeTarget(); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); - if (!llvmModule) { - return llvm::make_error( - "failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode()); - } - - if (auto err = optPipeline(llvmModule.get())) { - return llvm::make_error("failed to optimize LLVM IR", - llvm::inconvertibleErrorCode()); - } - - return std::move(llvmModule); -} - -llvm::Expected> -JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, - llvm::function_ref optPipeline) { - - // Looking for the function - auto rangeOps = module.getOps(); - auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) { - return op.getName() == name; - }); - if (funcOp == rangeOps.end()) { - return llvm::make_error( - "cannot find the function to JIT", llvm::inconvertibleErrorCode()); - } - initLLVMNativeTarget(); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // Create an MLIR execution engine. The execution engine eagerly - // JIT-compiles the module. - auto maybeEngine = mlir::ExecutionEngine::create( - module, /*llvmModuleBuilder=*/nullptr, optPipeline); - if (!maybeEngine) { - return llvm::make_error( - "failed to construct the MLIR ExecutionEngine", - llvm::inconvertibleErrorCode()); - } - auto &engine = maybeEngine.get(); - auto lambda = std::make_unique((*funcOp).getType(), name); - lambda->engine = std::move(engine); - - return std::move(lambda); -} - -llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { - size_t nbReturn = 0; - // TODO - This check break with memref as we have 5 returns args. - // if (!this->type.getReturnType().isa()) { - // nbReturn = 1; - // } - // if (this->type.getNumParams() != args.size() - nbReturn) { - // return llvm::make_error( - // "invokeRaw: wrong number of argument", - // llvm::inconvertibleErrorCode()); - // } - if (llvm::find(args, nullptr) != args.end()) { - return llvm::make_error( - "invoke: some arguments are null", llvm::inconvertibleErrorCode()); - } - return this->engine->invokePacked(this->name, args); -} - -llvm::Error JITLambda::invoke(Argument &args) { - return std::move(invokeRaw(args.rawArg)); -} - -JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { - // Setting the inputs - { - auto numInputs = 0; - for (size_t i = 0; i < keySet.numInputs(); i++) { - auto offset = numInputs; - auto gate = keySet.inputGate(i); - inputGates.push_back({gate, offset}); - if (keySet.inputGate(i).shape.size == 0) { - // scalar gate - numInputs = numInputs + 1; - continue; - } - // memref gate, as we follow the standard calling convention - numInputs = numInputs + 5; - } - inputs = std::vector(numInputs); - } - - // Setting the outputs - { - auto numOutputs = 0; - for (size_t i = 0; i < keySet.numOutputs(); i++) { - auto offset = numOutputs; - auto gate = keySet.outputGate(i); - outputGates.push_back({gate, offset}); - if (gate.shape.size == 0) { - // scalar gate - numOutputs = numOutputs + 1; - continue; - } - // memref gate, as we follow the standard calling convention - numOutputs = numOutputs + 5; - } - outputs = std::vector(numOutputs); - } - - // The raw argument contains pointers to inputs and pointers to store the - // results - rawArg = std::vector(inputs.size() + outputs.size(), nullptr); - // Set the pointer on outputs on rawArg - for (auto i = inputs.size(); i < rawArg.size(); i++) { - rawArg[i] = &outputs[i - inputs.size()]; - } - - // Setup runtime context with appropriate keys - keySet.initGlobalRuntimeContext(); -} - -JITLambda::Argument::~Argument() { - int err; - for (auto ct : allocatedCiphertexts) { - free_lwe_ciphertext_u64(&err, ct); - } - for (auto buffer : ciphertextBuffers) { - free(buffer); - } -} - -llvm::Expected> -JITLambda::Argument::create(KeySet &keySet) { - auto args = std::make_unique(keySet); - return std::move(args); -} - -llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { - if (pos >= inputGates.size()) { - return llvm::make_error( - llvm::Twine("argument index out of bound: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - auto gate = inputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - - // Check is the argument is a scalar - if (info.shape.size != 0) { - return llvm::make_error( - llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - - // If argument is not encrypted, just save. - if (!info.encryption.hasValue()) { - inputs[offset] = (void *)arg; - rawArg[offset] = &inputs[offset]; - return llvm::Error::success(); - } - // Else if is encryted, allocate ciphertext and encrypt. - LweCiphertext_u64 *ctArg; - if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { - return std::move(err); - } - allocatedCiphertexts.push_back(ctArg); - if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { - return std::move(err); - } - inputs[offset] = ctArg; - rawArg[offset] = &inputs[offset]; - return llvm::Error::success(); -} - -llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, - size_t size) { - auto gate = inputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - // Check if the width is compatible - // TODO - I found this rules empirically, they are a spec somewhere? - if (info.shape.width <= 8 && width != 8) { - return llvm::make_error( - llvm::Twine("argument width should be 8: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) { - return llvm::make_error( - llvm::Twine("argument width should be 16: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) { - return llvm::make_error( - llvm::Twine("argument width should be 32: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) { - return llvm::make_error( - llvm::Twine("argument width should be 64: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 64) { - return llvm::make_error( - llvm::Twine("argument width not supported: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - // Check the size - if (info.shape.size == 0) { - return llvm::make_error( - llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.size != size) { - return llvm::make_error( - llvm::Twine("vector argument has not the expected size") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - // If argument is not encrypted, just save with the right calling convention. - if (info.encryption.hasValue()) { - // Else if is encrypted - // For moment we support only 8 bits inputs - uint8_t *data8 = (uint8_t *)data; - if (width != 8) { - return llvm::make_error( - llvm::Twine( - "argument width > 8 for encrypted gates are not supported: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - - // Allocate a buffer for ciphertexts. - auto ctBuffer = - (LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *)); - ciphertextBuffers.push_back(ctBuffer); - // Allocate ciphertexts and encrypt - for (auto i = 0; i < size; i++) { - if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { - return std::move(err); - } - allocatedCiphertexts.push_back(ctBuffer[i]); - if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { - return std::move(err); - } - } - // Replace the data by the buffer to ciphertext - data = (void *)ctBuffer; - } - // Set the buffer as the memref calling convention expect. - // allocated - inputs[offset] = (void *)0; // TODO - Better understand how it is used. - rawArg[offset] = &inputs[offset]; - // aligned - inputs[offset + 1] = data; - rawArg[offset + 1] = &inputs[offset + 1]; - // offset - inputs[offset + 2] = (void *)0; - rawArg[offset + 2] = &inputs[offset + 2]; - // size - inputs[offset + 3] = (void *)size; - rawArg[offset + 3] = &inputs[offset + 3]; - // stride - inputs[offset + 4] = (void *)0; - rawArg[offset + 4] = &inputs[offset + 4]; - return llvm::Error::success(); -} - -llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { - auto gate = outputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - - // Check is the argument is a scalar - if (info.shape.size != 0) { - return llvm::make_error( - llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - // If result is not encrypted, just set the result - if (!info.encryption.hasValue()) { - res = (uint64_t)(outputs[offset]); - return llvm::Error::success(); - } - // Else if is encryted, decrypt - LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]); - if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { - return std::move(err); - } - return llvm::Error::success(); -} - -llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, - size_t size) { - auto gate = outputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - - // Check is the argument is a scalar - if (info.shape.size == 0) { - return llvm::make_error( - llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (!info.encryption.hasValue()) { - return llvm::make_error( - "unencrypted result as tensor output NYI", - llvm::inconvertibleErrorCode()); - } - // Get the values as the memref calling convention expect. - void *allocated = outputs[offset]; // TODO - Better understand how it is used. - // aligned - void *aligned = outputs[offset + 1]; - // offset - size_t offset_r = (size_t)outputs[offset + 2]; - // size - size_t size_r = (size_t)outputs[offset + 3]; - // stride - size_t stride = (size_t)outputs[offset + 4]; - // Check the sizes - if (info.shape.size != size || size_r != size) { - return llvm::make_error("output bad result buffer size", - llvm::inconvertibleErrorCode()); - } - // decrypt and fill the result buffer - for (auto i = 0; i < size_r; i++) { - LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i]; - if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) { - return std::move(err); - } - } - return llvm::Error::success(); -} - -} // namespace zamalang -} // namespace mlir diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 8c95d6a1d..779e5bfdb 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -1,6 +1,10 @@ #include #include #include +#include + +#include +#include #include #include @@ -54,5 +58,329 @@ runJit(mlir::ModuleOp module, llvm::StringRef func, return mlir::success(); } +llvm::Expected> +JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, + llvm::function_ref optPipeline) { + + // Looking for the function + auto rangeOps = module.getOps(); + auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) { + return op.getName() == name; + }); + if (funcOp == rangeOps.end()) { + return llvm::make_error( + "cannot find the function to JIT", llvm::inconvertibleErrorCode()); + } + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // Create an MLIR execution engine. The execution engine eagerly + // JIT-compiles the module. + auto maybeEngine = mlir::ExecutionEngine::create( + module, /*llvmModuleBuilder=*/nullptr, optPipeline); + if (!maybeEngine) { + return llvm::make_error( + "failed to construct the MLIR ExecutionEngine", + llvm::inconvertibleErrorCode()); + } + auto &engine = maybeEngine.get(); + auto lambda = std::make_unique((*funcOp).getType(), name); + lambda->engine = std::move(engine); + + return std::move(lambda); +} + +llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { + size_t nbReturn = 0; + // TODO - This check break with memref as we have 5 returns args. + // if (!this->type.getReturnType().isa()) { + // nbReturn = 1; + // } + // if (this->type.getNumParams() != args.size() - nbReturn) { + // return llvm::make_error( + // "invokeRaw: wrong number of argument", + // llvm::inconvertibleErrorCode()); + // } + if (llvm::find(args, nullptr) != args.end()) { + return llvm::make_error( + "invoke: some arguments are null", llvm::inconvertibleErrorCode()); + } + return this->engine->invokePacked(this->name, args); +} + +llvm::Error JITLambda::invoke(Argument &args) { + return std::move(invokeRaw(args.rawArg)); +} + +JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { + // Setting the inputs + { + auto numInputs = 0; + for (size_t i = 0; i < keySet.numInputs(); i++) { + auto offset = numInputs; + auto gate = keySet.inputGate(i); + inputGates.push_back({gate, offset}); + if (keySet.inputGate(i).shape.size == 0) { + // scalar gate + numInputs = numInputs + 1; + continue; + } + // memref gate, as we follow the standard calling convention + numInputs = numInputs + 5; + } + inputs = std::vector(numInputs); + } + + // Setting the outputs + { + auto numOutputs = 0; + for (size_t i = 0; i < keySet.numOutputs(); i++) { + auto offset = numOutputs; + auto gate = keySet.outputGate(i); + outputGates.push_back({gate, offset}); + if (gate.shape.size == 0) { + // scalar gate + numOutputs = numOutputs + 1; + continue; + } + // memref gate, as we follow the standard calling convention + numOutputs = numOutputs + 5; + } + outputs = std::vector(numOutputs); + } + + // The raw argument contains pointers to inputs and pointers to store the + // results + rawArg = std::vector(inputs.size() + outputs.size(), nullptr); + // Set the pointer on outputs on rawArg + for (auto i = inputs.size(); i < rawArg.size(); i++) { + rawArg[i] = &outputs[i - inputs.size()]; + } + + // Setup runtime context with appropriate keys + keySet.initGlobalRuntimeContext(); +} + +JITLambda::Argument::~Argument() { + int err; + for (auto ct : allocatedCiphertexts) { + free_lwe_ciphertext_u64(&err, ct); + } + for (auto buffer : ciphertextBuffers) { + free(buffer); + } +} + +llvm::Expected> +JITLambda::Argument::create(KeySet &keySet) { + auto args = std::make_unique(keySet); + return std::move(args); +} + +llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { + if (pos >= inputGates.size()) { + return llvm::make_error( + llvm::Twine("argument index out of bound: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + auto gate = inputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size != 0) { + return llvm::make_error( + llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + + // If argument is not encrypted, just save. + if (!info.encryption.hasValue()) { + inputs[offset] = (void *)arg; + rawArg[offset] = &inputs[offset]; + return llvm::Error::success(); + } + // Else if is encryted, allocate ciphertext and encrypt. + LweCiphertext_u64 *ctArg; + if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { + return std::move(err); + } + allocatedCiphertexts.push_back(ctArg); + if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { + return std::move(err); + } + inputs[offset] = ctArg; + rawArg[offset] = &inputs[offset]; + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, + size_t size) { + auto gate = inputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + // Check if the width is compatible + // TODO - I found this rules empirically, they are a spec somewhere? + if (info.shape.width <= 8 && width != 8) { + return llvm::make_error( + llvm::Twine("argument width should be 8: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) { + return llvm::make_error( + llvm::Twine("argument width should be 16: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) { + return llvm::make_error( + llvm::Twine("argument width should be 32: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) { + return llvm::make_error( + llvm::Twine("argument width should be 64: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 64) { + return llvm::make_error( + llvm::Twine("argument width not supported: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // Check the size + if (info.shape.size == 0) { + return llvm::make_error( + llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.size != size) { + return llvm::make_error( + llvm::Twine("vector argument has not the expected size") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // If argument is not encrypted, just save with the right calling convention. + if (info.encryption.hasValue()) { + // Else if is encrypted + // For moment we support only 8 bits inputs + uint8_t *data8 = (uint8_t *)data; + if (width != 8) { + return llvm::make_error( + llvm::Twine( + "argument width > 8 for encrypted gates are not supported: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + + // Allocate a buffer for ciphertexts. + auto ctBuffer = + (LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *)); + ciphertextBuffers.push_back(ctBuffer); + // Allocate ciphertexts and encrypt + for (auto i = 0; i < size; i++) { + if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { + return std::move(err); + } + allocatedCiphertexts.push_back(ctBuffer[i]); + if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { + return std::move(err); + } + } + // Replace the data by the buffer to ciphertext + data = (void *)ctBuffer; + } + // Set the buffer as the memref calling convention expect. + // allocated + inputs[offset] = (void *)0; // TODO - Better understand how it is used. + rawArg[offset] = &inputs[offset]; + // aligned + inputs[offset + 1] = data; + rawArg[offset + 1] = &inputs[offset + 1]; + // offset + inputs[offset + 2] = (void *)0; + rawArg[offset + 2] = &inputs[offset + 2]; + // size + inputs[offset + 3] = (void *)size; + rawArg[offset + 3] = &inputs[offset + 3]; + // stride + inputs[offset + 4] = (void *)0; + rawArg[offset + 4] = &inputs[offset + 4]; + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size != 0) { + return llvm::make_error( + llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // If result is not encrypted, just set the result + if (!info.encryption.hasValue()) { + res = (uint64_t)(outputs[offset]); + return llvm::Error::success(); + } + // Else if is encryted, decrypt + LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]); + if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { + return std::move(err); + } + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, + size_t size) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size == 0) { + return llvm::make_error( + llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (!info.encryption.hasValue()) { + return llvm::make_error( + "unencrypted result as tensor output NYI", + llvm::inconvertibleErrorCode()); + } + // Get the values as the memref calling convention expect. + void *allocated = outputs[offset]; // TODO - Better understand how it is used. + // aligned + void *aligned = outputs[offset + 1]; + // offset + size_t offset_r = (size_t)outputs[offset + 2]; + // size + size_t size_r = (size_t)outputs[offset + 3]; + // stride + size_t stride = (size_t)outputs[offset + 4]; + // Check the sizes + if (info.shape.size != size || size_r != size) { + return llvm::make_error("output bad result buffer size", + llvm::inconvertibleErrorCode()); + } + // decrypt and fill the result buffer + for (auto i = 0; i < size_r; i++) { + LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i]; + if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) { + return std::move(err); + } + } + return llvm::Error::success(); +} + } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp new file mode 100644 index 000000000..76097f85d --- /dev/null +++ b/compiler/lib/Support/Pipeline.cpp @@ -0,0 +1,148 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mlir { +namespace zamalang { +namespace pipeline { +static void addPotentiallyNestedPass(mlir::PassManager &pm, + std::unique_ptr pass) { + if (!pass->getOpName() || *pass->getOpName() == "module") { + pm.addPass(std::move(pass)); + } else { + pm.nest(*pass->getOpName()).addPass(std::move(pass)); + } +} + +mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, bool verbose) { + mlir::PassManager pm(&context); + + if (verbose) { + mlir::zamalang::log_verbose() + << "##################################################\n" + << "### HLFHE to MidLFHE pipeline\n"; + + pm.enableIRPrinting(); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } + + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg()); + addPotentiallyNestedPass(pm, + mlir::zamalang::createConvertHLFHEToMidLFHEPass()); + + return pm.run(module.getOperation()); +} + +mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, + bool parametrize) { + mlir::PassManager pm(&context); + + if (parametrize) { + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass( + fheContext)); + } + + addPotentiallyNestedPass(pm, + mlir::zamalang::createConvertMidLFHEToLowLFHEPass()); + + return pm.run(module.getOperation()); +} + +mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module) { + mlir::PassManager pm(&context); + pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()); + return pm.run(module.getOperation()); +} + +mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, + mlir::ModuleOp &module, + bool verbose) { + mlir::PassManager pm(&context); + + if (verbose) { + mlir::zamalang::log_verbose() + << "##################################################\n" + << "### MlirStdsDialectToMlirLLVMDialect pipeline\n"; + context.disableMultithreading(); + pm.enableIRPrinting(); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } + + // Unparametrize LowLFHE + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass()); + + // Bufferize + addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createStdBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass()); + addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass()); + + // Convert to MLIR LLVM Dialect + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()); + + return pm.run(module); +} + +std::unique_ptr +lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context, + llvm::LLVMContext &llvmContext, + mlir::ModuleOp &module) { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + return mlir::translateModuleToLLVMIR(module, llvmContext); +} + +mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, + llvm::Module &module) { + std::function optPipeline = + mlir::makeOptimizingTransformer(3, 0, nullptr); + + if (optPipeline(&module)) + return mlir::failure(); + else + return mlir::success(); +} + +mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, bool verbose) { + if (lowerHLFHEToMidLFHE(context, module, verbose).failed() || + lowerMidLFHEToLowLFHE(context, module, fheContext, true).failed() || + lowerLowLFHEToStd(context, module).failed()) { + return mlir::failure(); + } else { + return mlir::success(); + } +} + +} // namespace pipeline +} // namespace zamalang +} // namespace mlir diff --git a/compiler/python/CompilerAPIModule.cpp b/compiler/python/CompilerAPIModule.cpp index 887e79fe0..62e6203b7 100644 --- a/compiler/python/CompilerAPIModule.cpp +++ b/compiler/python/CompilerAPIModule.cpp @@ -7,7 +7,6 @@ #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" #include "zamalang/Support/CompilerEngine.h" -#include "zamalang/Support/CompilerTools.h" #include #include #include diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index ec15e7b80..517f128ff 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -12,16 +12,32 @@ #include #include +#include "mlir/IR/BuiltinOps.h" #include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/GlobalFHEContext.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" -#include "zamalang/Support/CompilerTools.h" -#include "zamalang/Support/logging.h" #include "zamalang/Support/Jit.h" +#include "zamalang/Support/KeySet.h" +#include "zamalang/Support/Pipeline.h" +#include "zamalang/Support/logging.h" + +enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM }; + +enum Action { + ROUND_TRIP, + DUMP_MIDLFHE, + DUMP_LOWLFHE, + DUMP_STD, + DUMP_LLVM_DIALECT, + DUMP_LLVM_IR, + DUMP_OPTIMIZED_LLVM_IR, + JIT_INVOKE +}; namespace cmdline { @@ -37,14 +53,53 @@ llvm::cl::opt output("o", llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), llvm::cl::init(false)); -llvm::cl::list passes( - "passes", - llvm::cl::desc("Specify the passes to run (use only for compiler tests)"), - llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore); +llvm::cl::opt parametrizeMidLFHE( + "parametrize-midlfhe", + llvm::cl::desc("Perform MidLFHE global parametrization pass"), + llvm::cl::init(true)); -llvm::cl::opt roundTrip("round-trip", - llvm::cl::desc("Just parse and dump"), - llvm::cl::init(false)); +static llvm::cl::opt entryDialect( + "e", "entry-dialect", llvm::cl::desc("Entry dialect"), + llvm::cl::init(EntryDialect::HLFHE), + llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, + llvm::cl::values( + clEnumValN(EntryDialect::HLFHE, "hlfhe", + "Input module is composed of HLFHE operations")), + llvm::cl::values( + clEnumValN(EntryDialect::MIDLFHE, "midlfhe", + "Input module is composed of MidLFHE operations")), + llvm::cl::values( + clEnumValN(EntryDialect::LOWLFHE, "lowlfhe", + "Input module is composed of LowLFHE operations")), + llvm::cl::values( + clEnumValN(EntryDialect::STD, "std", + "Input module is composed of operations from std")), + llvm::cl::values( + clEnumValN(EntryDialect::LLVM, "llvm", + "Input module is composed of operations from llvm"))); + +static llvm::cl::opt action( + "a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired, + llvm::cl::NumOccurrencesFlag::Required, + llvm::cl::values( + clEnumValN(Action::ROUND_TRIP, "roundtrip", + "Parse input module and regenerate textual representation")), + llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", + "Lower to MidLFHE and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe", + "Lower to LowLFHE and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std", + "Lower to std and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect", + "Lower to LLVM dialect and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_LLVM_IR, "dump-llvm-ir", + "Lower to LLVM-IR and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_OPTIMIZED_LLVM_IR, + "dump-optimized-llvm-ir", + "Lower to LLVM-IR, optimize and dump result")), + llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke", + "Lower and JIT-compile input module and invoke " + "function specified with --jit-funcname"))); llvm::cl::opt verifyDiagnostics( "verify-diagnostics", @@ -58,15 +113,7 @@ llvm::cl::opt splitInputFile( "chunk independently"), llvm::cl::init(false)); -llvm::cl::opt generateKeySet( - "generate-keyset", - llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"), - llvm::cl::init(false)); - -llvm::cl::opt runJit("run-jit", llvm::cl::desc("JIT the code and run it"), - llvm::cl::init(false)); - -llvm::cl::opt jitFuncname( +llvm::cl::opt jitFuncName( "jit-funcname", llvm::cl::desc("Name of the function to execute, default 'main'"), llvm::cl::init("main")); @@ -75,73 +122,16 @@ llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); - -llvm::cl::opt toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "), - llvm::cl::init(false)); }; // namespace cmdline std::function defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); -mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { - llvm::LLVMContext context; - auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule( - context, module, defaultOptPipeline); - if (!llvmModule) { - return mlir::failure(); - } - os << **llvmModule; - return mlir::success(); -} +std::unique_ptr +generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, + const std::string &jitFuncName) { + std::unique_ptr keySet; -// Process a single source buffer -// -// If `verifyDiagnostics` is `true`, the procedure only checks if the -// diagnostic messages provided in the source buffer using -// `expected-error` are produced. -// -// If `verifyDiagnostics` is `false`, the procedure checks if the -// parsed module is valid and if all requested transformations -// succeeded. -mlir::LogicalResult -processInputBuffer(mlir::MLIRContext &context, - std::unique_ptr buffer, - llvm::raw_ostream &os, bool verifyDiagnostics) { - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); - - mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, - &context); - auto module = mlir::parseSourceFile(sourceMgr, &context); - - if (verifyDiagnostics) - return sourceMgrHandler.verify(); - - if (!module) - return mlir::failure(); - - if (cmdline::roundTrip) { - module->print(os); - return mlir::success(); - } - - auto enablePass = [](std::string passName) { - return cmdline::passes.size() == 0 || - std::any_of(cmdline::passes.begin(), cmdline::passes.end(), - [&](const std::string &p) { return passName == p; }); - }; - - // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. - mlir::zamalang::CompilerTools::LowerOptions lowerOptions; - lowerOptions.enablePass = enablePass; - lowerOptions.verbose = cmdline::verbose; - - mlir::zamalang::V0FHEContext fheContext; - if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - context, *module, fheContext, lowerOptions) - .failed()) { - return mlir::failure(); - } mlir::zamalang::log_verbose() << "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2 << ", p:" << fheContext.constraint.p << "}\n"; @@ -155,45 +145,196 @@ processInputBuffer(mlir::MLIRContext &context, << ", ksLevel: " << fheContext.parameter.ksLevel << ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n"; - // Generate the keySet - std::unique_ptr keySet; - if (cmdline::generateKeySet || cmdline::runJit) { - // Create the client parameters - auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheContext, cmdline::jitFuncname, *module); - if (auto err = clientParameter.takeError()) { - mlir::zamalang::log_error() - << "cannot generate client parameters: " << err << "\n"; - return mlir::failure(); - } - mlir::zamalang::log_verbose() << "### Generate the key set\n"; - auto maybeKeySet = - mlir::zamalang::KeySet::generate(clientParameter.get(), 0, - 0); // TODO: seed - if (auto err = maybeKeySet.takeError()) { - llvm::errs() << err; - return mlir::failure(); - } - keySet = std::move(maybeKeySet.get()); + // Create the client parameters + auto clientParameter = mlir::zamalang::createClientParametersForV0( + fheContext, jitFuncName, module); + + if (auto err = clientParameter.takeError()) { + mlir::zamalang::log_error() + << "cannot generate client parameters: " << err << "\n"; + return nullptr; } - // Lower to MLIR LLVM Dialect - if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - context, *module, lowerOptions) - .failed()) { + mlir::zamalang::log_verbose() << "### Generate the key set\n"; + + auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, + 0); // TODO: seed + if (auto err = maybeKeySet.takeError()) { + llvm::errs() << err; + return nullptr; + } + + return std::move(maybeKeySet.get()); +} + +// Process a single source buffer +// +// The parameter `entryDialect` must specify the FHE dialect to which +// belong all FHE operations used in the source buffer. The input +// program must only contain FHE operations from that single FHE +// dialect, otherwise processing might fail. +// +// The parameter `action` specifies how the buffer should be processed +// and thus defines the output. +// +// If the specified action involves JIT compilation, `jitFuncName` +// designates the function to JIT compile. This function is invoked +// using the parameters given in `jitArgs`. +// +// The parameter `parametrizeMidLFHE` defines, whether the +// parametrization pass for MidLFHE is executed. If the pair of +// `entryDialect` and `action` does not involve any MidlFHE +// manipulation, this parameter does not have any effect. +// +// If `verifyDiagnostics` is `true`, the procedure only checks if the +// diagnostic messages provided in the source buffer using +// `expected-error` are produced. If `verifyDiagnostics` is `false`, +// the procedure checks if the parsed module is valid and if all +// requested transformations succeeded. +// +// If `verbose` is true, debug messages are displayed throughout the +// compilation process. +// +// Compilation output is written to the stream specified by `os`. +mlir::LogicalResult processInputBuffer( + mlir::MLIRContext &context, std::unique_ptr buffer, + enum EntryDialect entryDialect, enum Action action, + const std::string &jitFuncName, llvm::ArrayRef jitArgs, + bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose, + llvm::raw_ostream &os) { + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); + + mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, + &context); + mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); + + // This is temporary until we have the high-level verification pass + // determining these parameters automatically + mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, + .p = 7}; + + std::unique_ptr keySet = nullptr; + + const mlir::zamalang::V0Parameter *parameter = + getV0Parameter(defaultGlobalFHECircuitConstraint); + + if (!parameter) { + mlir::zamalang::log_error() + << "Could not determine V0 parameters for 2-norm of " + << defaultGlobalFHECircuitConstraint.norm2 << " and p of " + << defaultGlobalFHECircuitConstraint.p << "\n"; + return mlir::failure(); } - if (cmdline::runJit) { - mlir::zamalang::log_verbose() << "### JIT compile & running\n"; - return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname, - cmdline::jitArgs, *keySet, - defaultOptPipeline, os); + mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, + *parameter}; + + if (verbose) + context.disableMultithreading(); + + if (verifyDiagnostics) + return sourceMgrHandler.verify(); + + if (!moduleRef) + return mlir::failure(); + + mlir::ModuleOp module = moduleRef.get(); + + if (action == Action::ROUND_TRIP) { + module->print(os); + return mlir::success(); } - if (cmdline::toLLVM) { - return dumpLLVMIR(module.get(), os); + + // Lowering pipeline. Each stage is represented as a label in the + // switch statement, from the most abstract dialect to the lowest + // level. Every labels acts as an entry point into the pipeline with + // a fallthrough mechanism to the next stage. Actions act as exit + // points from the pipeline. + switch (entryDialect) { + case EntryDialect::HLFHE: + if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) + .failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::MIDLFHE: + if (action == Action::DUMP_MIDLFHE) { + module.print(os); + return mlir::success(); + } + + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( + context, module, fheContext, parametrizeMidlHFE) + .failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::LOWLFHE: + if (action == Action::DUMP_LOWLFHE) { + module.print(os); + return mlir::success(); + } + + if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::STD: + if (action == Action::DUMP_STD) { + module.print(os); + return mlir::success(); + } else if (action == Action::JIT_INVOKE) { + keySet = generateKeySet(module, fheContext, jitFuncName); + } + + if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, + verbose) + .failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::LLVM: { + if (action == Action::DUMP_LLVM_DIALECT) { + module.print(os); + return mlir::success(); + } else if (action == Action::JIT_INVOKE) { + return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet, + defaultOptPipeline, os); + } + + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext, + module); + + if (!llvmModule) { + mlir::zamalang::log_error() + << "Failed to translate LLVM dialect to LLVM IR\n"; + return mlir::failure(); + } + + if (action == Action::DUMP_LLVM_IR) { + llvmModule->dump(); + return mlir::success(); + } + + if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule) + .failed()) { + mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n"; + return mlir::failure(); + } + + if (action == Action::DUMP_OPTIMIZED_LLVM_IR) { + llvmModule->dump(); + return mlir::success(); + } + + break; } - module->print(os); + } + return mlir::success(); } @@ -209,6 +350,16 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { // String for error messages from library functions std::string errorMessage; + if (cmdline::action == Action::JIT_INVOKE && + cmdline::entryDialect != EntryDialect::HLFHE && + cmdline::entryDialect != EntryDialect::MIDLFHE && + cmdline::entryDialect != EntryDialect::LOWLFHE && + cmdline::entryDialect != EntryDialect::STD) { + mlir::zamalang::log_error() + << "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs"; + return mlir::failure(); + } + // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); context.getOrLoadDialect(); @@ -229,7 +380,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return mlir::failure(); } - // Iterate over all inpiut files specified on the command line + // Iterate over all input files specified on the command line for (const auto &fileName : cmdline::inputs) { auto file = mlir::openInputFile(fileName, &errorMessage); @@ -247,14 +398,19 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { std::move(file), [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { - return processInputBuffer(context, std::move(inputBuffer), os, - cmdline::verifyDiagnostics); + return processInputBuffer( + context, std::move(inputBuffer), cmdline::entryDialect, + cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, + cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics, + cmdline::verbose, os); }, output->os()))) return mlir::failure(); } else { - return processInputBuffer(context, std::move(file), output->os(), - cmdline::verifyDiagnostics); + return processInputBuffer( + context, std::move(file), cmdline::entryDialect, cmdline::action, + cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, + cmdline::verifyDiagnostics, cmdline::verbose, output->os()); } } diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir index c94a65a40..bf08e2e90 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { @@ -7,4 +7,4 @@ func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir index 811246be0..49b8063a3 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir index 74c4d531c..62a17b7d9 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}> func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> { @@ -7,4 +7,4 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.e %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir index bcc99851f..23a7b3c28 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { @@ -8,4 +8,4 @@ func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir index 9fba81d0f..c98974a2a 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir @@ -1,19 +1,18 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s - -//CHECK: #map0 = affine_map<(d0) -> (d0)> -//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> -//CHECK-NEXT: module { -//CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { -//CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { -//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors -//CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> -//CHECK-NEXT: return -//CHECK-NEXT: } -//CHECK-NEXT: } +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// CHECK: #map0 = affine_map<(d0) -> (d0)> +// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> +// CHECK-NEXT: module { +// CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { +// CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { +// CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +// CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir index 0297179ca..c83bac0ba 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { @@ -9,4 +9,4 @@ func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %0 = constant 1 : i8 %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir index 0962bf42a..a34343f21 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 87fef60eb..bbd678db6 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) @@ -27,4 +27,4 @@ func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe // CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index ce6e2aff4..a40f5df5a 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list @@ -31,4 +31,4 @@ func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext { // CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext %1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !LowLFHE.glwe_ciphertext return %1: !LowLFHE.glwe_ciphertext -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 09f609291..66ecf0177 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) @@ -26,4 +26,4 @@ func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciph // CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir b/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir deleted file mode 100644 index 852924948..000000000 --- a/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: zamacompiler --passes lowlfhe-unparametrize %s 2>&1| FileCheck %s - -// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_> -func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { - // CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_> - return %arg0: !LowLFHE.lwe_ciphertext<1024,4> -} \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir b/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir deleted file mode 100644 index b48ba3c24..000000000 --- a/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: zamacompiler --passes lowlfhe-unparametrize %s 2>&1| FileCheck %s - -// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_> -func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<_,_> { - // CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_> - %0 = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> - return %0: !LowLFHE.lwe_ciphertext<_,_> -} \ No newline at end of file diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index f84dcbd69..7814e2064 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index db9e32472..ee1d3e53a 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { @@ -19,4 +19,4 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE. // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.add_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index f3cda0d8f..926c788ce 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { @@ -8,4 +8,4 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16x // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index e202ae025..adc811bef 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { @@ -10,4 +10,4 @@ func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.g %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>) return %1: !MidLFHE.glwe<{2048,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index 3950f64e1..c13c73353 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { @@ -19,4 +19,4 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE. // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.mul_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index 5af7c049c..a40db4b79 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { @@ -20,4 +20,4 @@ func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE. // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.sub_int_glwe"(%arg1, %arg0): (i5, !MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir index 2e2f31a65..0e53d8d5a 100644 --- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir +++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s // Incompatible shapes func @dot_incompatible_shapes( @@ -66,4 +66,4 @@ func @dot_incompatible_int( (tensor<4x!HLFHE.eint<2>>, tensor<4xi4>) -> !HLFHE.eint<2> return %ret : !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir index de76e776f..6a6d4f962 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<8>) { diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir index 300d49685..bb43f441c 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<0>) { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir index 01bd02a02..1bb62e224 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir @@ -1,7 +1,7 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<3>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir index 3c9e7b0cb..d43bc7194 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir @@ -1,7 +1,7 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir index 680c79b57..205e7afe1 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i4 %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir index 299d2771a..0a8ae9a8c 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %0 = constant 1 : i2 %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir index ea0c179a8..0a8d9cd48 100644 --- a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir @@ -1,7 +1,7 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument. func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> { %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir index 6ea308618..6a9e6e059 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i4 %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir index bc7ef7637..ee84b2a49 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %0 = constant 1 : i2 %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir index 73c036e5e..deded0859 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1 func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i4 %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i4, !HLFHE.eint<2>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir index 6d2892e8a..207414189 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %0 = constant 1 : i2 %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i2, !HLFHE.eint<2>) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index de2c43060..4ee6f53bc 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @zero() -> !HLFHE.eint<2> func @zero() -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir index f313d7501..1130181ea 100644 --- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -1,22 +1,22 @@ -// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s //CHECK: #map0 = affine_map<(d0) -> (d0)> //CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> //CHECK-NEXT: module { -//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { -//CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2> -//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>> -//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%1 : tensor<1x!HLFHE.eint<2>>) { -//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors -//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> -//CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2> -//CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2> -//CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>> +//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !MidLFHE.glwe<{_,_,_}{2}> { +//CHECK-NEXT: %0 = "MidLFHE.zero"() : () -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%1 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +//CHECK-NEXT: %4 = "MidLFHE.mul_glwe_int"(%arg2, %arg3) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %5 = "MidLFHE.add_glwe"(%4, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %5 : !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> //CHECK-NEXT: %c0 = constant 0 : index -//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<2>> -//CHECK-NEXT: return %3 : !HLFHE.eint<2> +//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: return %3 : !MidLFHE.glwe<{_,_,_}{2}> //CHECK-NEXT: } -//CHECK-NEXT: } +//CHECK-NEXT: } func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/types.mlir b/compiler/tests/Dialect/HLFHE/types.mlir index 97d77ffa8..8e6b6bc85 100644 --- a/compiler/tests/Dialect/HLFHE/types.mlir +++ b/compiler/tests/Dialect/HLFHE/types.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>> func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) { diff --git a/compiler/tests/Dialect/LowLFHE/ops.mlir b/compiler/tests/Dialect/LowLFHE/ops.mlir index be4315c9d..b909b0473 100644 --- a/compiler/tests/Dialect/LowLFHE/ops.mlir +++ b/compiler/tests/Dialect/LowLFHE/ops.mlir @@ -1,5 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s - +// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> { diff --git a/compiler/tests/Dialect/LowLFHE/types.mlir b/compiler/tests/Dialect/LowLFHE/types.mlir index 0b6d37df0..27552cb2a 100644 --- a/compiler/tests/Dialect/LowLFHE/types.mlir +++ b/compiler/tests/Dialect/LowLFHE/types.mlir @@ -1,5 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s - +// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir index 18cb9f48f..55df41028 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter result func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir index eebd51590..3d1f81407 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir index 4ce221611..97ead991e 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir index dae0011a8..ba6a37313 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -9,4 +9,4 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 %0 = constant 1 : i8 %1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir index 04f52dafc..ee8a78f6d 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // Bad dimension of the lookup table func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir index 07de47e6e..928095674 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { @@ -7,4 +7,4 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<12 %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32, outputSizeKS = 600 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>) return %1: !MidLFHE.glwe<{512,10,64}{2}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir index f3fe65d0e..f21873208 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir index 4a15c7a59..ae9daa983 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -9,4 +9,4 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 %0 = constant 1 : i8 %1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir index 29099bf7e..0903aeb00 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir index 32032566f..47fab99a6 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -9,4 +9,4 @@ func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 %0 = constant 1 : i8 %1 = "MidLFHE.sub_int_glwe"(%0, %arg0): (i8, !MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir index eaec6e7dd..b66236c76 100644 --- a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --round-trip 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s // CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -10,4 +10,4 @@ func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64 func @glwe_1(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> { // CHECK-LABEL: return %arg0 : !MidLFHE.glwe<{_,_,_}{7}> return %arg0: !MidLFHE.glwe<{_,_,_}{7}> -} \ No newline at end of file +}