diff --git a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h index 603fdd2b8..f9c2d4c99 100644 --- a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h +++ b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h @@ -19,7 +19,7 @@ struct V0Parameter { size_t ksLevel; size_t ksLogBase; - V0Parameter() {} + V0Parameter() = delete; V0Parameter(size_t k, size_t polynomialSize, size_t nSmall, size_t brLevel, size_t brLogBase, size_t ksLevel, size_t ksLogBase) @@ -31,11 +31,14 @@ struct V0Parameter { }; struct V0FHEContext { + V0FHEContext() = delete; + V0FHEContext(const V0FHEConstraint &constraint, const V0Parameter ¶meter) + : constraint(constraint), parameter(parameter) {} + V0FHEConstraint constraint; V0Parameter parameter; }; - } // namespace zamalang } // namespace mlir -#endif \ No newline at end of file +#endif diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index b380a850b..047a33dc8 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -1,17 +1,7 @@ #ifndef ZAMALANG_SUPPORT_COMPILER_ENGINE_H #define ZAMALANG_SUPPORT_COMPILER_ENGINE_H -#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" -#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" -#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" -#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" -#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" -#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" -#include "zamalang/Support/CompilerTools.h" -#include -#include -#include -#include +#include "Jit.h" namespace mlir { namespace zamalang { @@ -55,4 +45,4 @@ private: } // namespace zamalang } // namespace mlir -#endif \ No newline at end of file +#endif diff --git a/compiler/include/zamalang/Support/CompilerTools.h b/compiler/include/zamalang/Support/CompilerTools.h deleted file mode 100644 index 1017e0906..000000000 --- a/compiler/include/zamalang/Support/CompilerTools.h +++ /dev/null @@ -1,138 +0,0 @@ -#ifndef ZAMALANG_SUPPORT_COMPILERTOOLS_H_ -#define ZAMALANG_SUPPORT_COMPILERTOOLS_H_ - -#include -#include -#include - -#include "zamalang/Support/ClientParameters.h" -#include "zamalang/Support/KeySet.h" -#include "zamalang/Support/V0Parameters.h" - -namespace mlir { -namespace zamalang { - -class CompilerTools { -public: - struct LowerOptions { - llvm::function_ref enablePass; - bool verbose; - - LowerOptions() - : verbose(false), enablePass([](std::string pass) { return true; }){}; - }; - - /// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir - /// lowerable to llvm dialect. - /// The given module MLIR operation would be modified and the constraints set. - static mlir::LogicalResult - lowerHLFHEToMlirStdsDialect(mlir::MLIRContext &context, - mlir::Operation *module, V0FHEContext &fheContext, - LowerOptions options = LowerOptions()); - - /// lowerMlirStdsDialectToMlirLLVMDialect run all passes to lower MLIR - /// dialects to MLIR LLVM dialect. The given module MLIR operation would be - /// modified. - static mlir::LogicalResult - lowerMlirStdsDialectToMlirLLVMDialect(mlir::MLIRContext &context, - mlir::Operation *module, - LowerOptions options = LowerOptions()); - - static llvm::Expected> - toLLVMModule(llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, - llvm::function_ref optPipeline); -}; - -/// JITLambda is a tool to JIT compile an mlir module and to invoke a function -/// of the module. -class JITLambda { -public: - class Argument { - public: - Argument(KeySet &keySet); - ~Argument(); - - // Create lambda Argument that use the given KeySet to perform encryption - // and decryption operations. - static llvm::Expected> create(KeySet &keySet); - - // Set a scalar argument at the given pos as a uint64_t. - llvm::Error setArg(size_t pos, uint64_t arg); - - // Set a argument at the given pos as a tensor of int64. - llvm::Error setArg(size_t pos, uint64_t *data, size_t size) { - return setArg(pos, 64, (void *)data, size); - } - - // Set a argument at the given pos as a tensor of int32. - llvm::Error setArg(size_t pos, uint32_t *data, size_t size) { - return setArg(pos, 32, (void *)data, size); - } - - // Set a argument at the given pos as a tensor of int32. - llvm::Error setArg(size_t pos, uint16_t *data, size_t size) { - return setArg(pos, 16, (void *)data, size); - } - - // Set a tensor argument at the given pos as a uint64_t. - llvm::Error setArg(size_t pos, uint8_t *data, size_t size) { - return setArg(pos, 8, (void *)data, size); - } - - // Get the result at the given pos as an uint64_t. - llvm::Error getResult(size_t pos, uint64_t &res); - - // Fill the result. - llvm::Error getResult(size_t pos, uint64_t *res, size_t size); - - private: - llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); - - friend JITLambda; - // Store the pointer on inputs values and outputs values - std::vector rawArg; - // Store the values of inputs - std::vector inputs; - // Store the values of outputs - std::vector outputs; - // Store the input gates description and the offset of the argument. - std::vector> inputGates; - // Store the outputs gates description and the offset of the argument. - std::vector> outputGates; - // Store allocated lwe ciphertexts (for free) - std::vector allocatedCiphertexts; - // Store buffers of ciphertexts - std::vector ciphertextBuffers; - - KeySet &keySet; - }; - JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) - : type(type), name(name){}; - - /// create a JITLambda that point to the function name of the given module. - static llvm::Expected> - create(llvm::StringRef name, mlir::ModuleOp &module, - llvm::function_ref optPipeline); - - /// invokeRaw execute the jit lambda with a list of Argument, the last one is - /// used to store the result of the computation. - /// Example: - /// uin64_t arg0 = 1; - /// uin64_t res; - /// llvm::SmallVector args{&arg1, &res}; - /// lambda.invokeRaw(args); - llvm::Error invokeRaw(llvm::MutableArrayRef args); - - /// invoke the jit lambda with the Argument. - llvm::Error invoke(Argument &args); - -private: - mlir::LLVM::LLVMFunctionType type; - llvm::StringRef name; - std::unique_ptr engine; -}; - -} // namespace zamalang -} // namespace mlir - -#endif \ No newline at end of file diff --git a/compiler/include/zamalang/Support/Jit.h b/compiler/include/zamalang/Support/Jit.h index c72d67f82..c7358ae77 100644 --- a/compiler/include/zamalang/Support/Jit.h +++ b/compiler/include/zamalang/Support/Jit.h @@ -1,11 +1,12 @@ #ifndef COMPILER_JIT_H #define COMPILER_JIT_H -#include "zamalang/Support/CompilerTools.h" - +#include #include #include +#include + namespace mlir { namespace zamalang { mlir::LogicalResult @@ -13,6 +14,96 @@ runJit(mlir::ModuleOp module, llvm::StringRef func, llvm::ArrayRef funcArgs, mlir::zamalang::KeySet &keySet, std::function optPipeline, llvm::raw_ostream &os); + +/// JITLambda is a tool to JIT compile an mlir module and to invoke a function +/// of the module. +class JITLambda { +public: + class Argument { + public: + Argument(KeySet &keySet); + ~Argument(); + + // Create lambda Argument that use the given KeySet to perform encryption + // and decryption operations. + static llvm::Expected> create(KeySet &keySet); + + // Set a scalar argument at the given pos as a uint64_t. + llvm::Error setArg(size_t pos, uint64_t arg); + + // Set a argument at the given pos as a tensor of int64. + llvm::Error setArg(size_t pos, uint64_t *data, size_t size) { + return setArg(pos, 64, (void *)data, size); + } + + // Set a argument at the given pos as a tensor of int32. + llvm::Error setArg(size_t pos, uint32_t *data, size_t size) { + return setArg(pos, 32, (void *)data, size); + } + + // Set a argument at the given pos as a tensor of int32. + llvm::Error setArg(size_t pos, uint16_t *data, size_t size) { + return setArg(pos, 16, (void *)data, size); + } + + // Set a tensor argument at the given pos as a uint64_t. + llvm::Error setArg(size_t pos, uint8_t *data, size_t size) { + return setArg(pos, 8, (void *)data, size); + } + + // Get the result at the given pos as an uint64_t. + llvm::Error getResult(size_t pos, uint64_t &res); + + // Fill the result. + llvm::Error getResult(size_t pos, uint64_t *res, size_t size); + + private: + llvm::Error setArg(size_t pos, size_t width, void *data, size_t size); + + friend JITLambda; + // Store the pointer on inputs values and outputs values + std::vector rawArg; + // Store the values of inputs + std::vector inputs; + // Store the values of outputs + std::vector outputs; + // Store the input gates description and the offset of the argument. + std::vector> inputGates; + // Store the outputs gates description and the offset of the argument. + std::vector> outputGates; + // Store allocated lwe ciphertexts (for free) + std::vector allocatedCiphertexts; + // Store buffers of ciphertexts + std::vector ciphertextBuffers; + + KeySet &keySet; + }; + JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name) + : type(type), name(name){}; + + /// create a JITLambda that point to the function name of the given module. + static llvm::Expected> + create(llvm::StringRef name, mlir::ModuleOp &module, + llvm::function_ref optPipeline); + + /// invokeRaw execute the jit lambda with a list of Argument, the last one is + /// used to store the result of the computation. + /// Example: + /// uin64_t arg0 = 1; + /// uin64_t res; + /// llvm::SmallVector args{&arg1, &res}; + /// lambda.invokeRaw(args); + llvm::Error invokeRaw(llvm::MutableArrayRef args); + + /// invoke the jit lambda with the Argument. + llvm::Error invoke(Argument &args); + +private: + mlir::LLVM::LLVMFunctionType type; + llvm::StringRef name; + std::unique_ptr engine; +}; + } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h new file mode 100644 index 000000000..d8d82b82f --- /dev/null +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -0,0 +1,42 @@ +#ifndef ZAMALANG_SUPPORT_PIPELINE_H_ +#define ZAMALANG_SUPPORT_PIPELINE_H_ + +#include +#include +#include +#include + +namespace mlir { +namespace zamalang { +namespace pipeline { + +mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, bool verbose); + +mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, + bool parametrize); + +mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module); + +mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, + mlir::ModuleOp &module, bool verbose); + +mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, + llvm::Module &module); + +mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, bool verbose); + +std::unique_ptr +lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context, + llvm::LLVMContext &llvmContext, + mlir::ModuleOp &module); +} // namespace pipeline +} // namespace zamalang +} // namespace mlir + +#endif diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 61280fc1f..d5c319893 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(ZamalangSupport - CompilerTools.cpp + Pipeline.cpp + Jit.cpp CompilerEngine.cpp V0Parameters.cpp V0Curves.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 8f37ff94c..5b99e8dd4 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -1,8 +1,16 @@ -#include "zamalang/Support/CompilerEngine.h" -#include "zamalang/Conversion/Passes.h" +#include +#include +#include +#include #include #include +#include +#include +#include +#include +#include + namespace mlir { namespace zamalang { @@ -29,10 +37,20 @@ llvm::Error CompilerEngine::compile(std::string mlirStr) { return llvm::make_error("mlir parsing failed", llvm::inconvertibleErrorCode()); } - mlir::zamalang::V0FHEContext fheContext; + + mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, + .p = 7}; + const mlir::zamalang::V0Parameter *parameter = + getV0Parameter(defaultGlobalFHECircuitConstraint); + + mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, + *parameter}; + + mlir::ModuleOp module = module_ref.get(); + // Lower to MLIR Std - if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - *context, module_ref.get(), fheContext) + if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext, + false) .failed()) { return llvm::make_error("failed to lower to MLIR Std", llvm::inconvertibleErrorCode()); @@ -53,8 +71,7 @@ llvm::Error CompilerEngine::compile(std::string mlirStr) { keySet = std::move(maybeKeySet.get()); // Lower to MLIR LLVM Dialect - if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - *context, module_ref.get()) + if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(*context, module, false) .failed()) { return llvm::make_error( "failed to lower to LLVM dialect", llvm::inconvertibleErrorCode()); @@ -114,4 +131,4 @@ llvm::Expected CompilerEngine::run(std::vector args) { return res; } } // namespace zamalang -} // namespace mlir \ No newline at end of file +} // namespace mlir diff --git a/compiler/lib/Support/CompilerTools.cpp b/compiler/lib/Support/CompilerTools.cpp deleted file mode 100644 index 030812cca..000000000 --- a/compiler/lib/Support/CompilerTools.cpp +++ /dev/null @@ -1,467 +0,0 @@ -#include "mlir/Dialect/Tensor/Transforms/Passes.h" -#include -#include -#include -#include -#include -#include - -#include "zamalang/Conversion/Passes.h" -#include "zamalang/Support/CompilerTools.h" - -namespace mlir { -namespace zamalang { - -// This is temporary while we doesn't yet have the high-level verification pass -V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, .p = 7}; - -void initLLVMNativeTarget() { - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); -} - -void addFilteredPassToPassManager( - mlir::PassManager &pm, std::unique_ptr pass, - llvm::function_ref enablePass) { - if (!enablePass(pass->getArgument().str())) { - return; - } - if (*pass->getOpName() == "module") { - pm.addPass(std::move(pass)); - } else { - pm.nest(*pass->getOpName()).addPass(std::move(pass)); - } -}; - -mlir::LogicalResult CompilerTools::lowerHLFHEToMlirStdsDialect( - mlir::MLIRContext &context, mlir::Operation *module, - V0FHEContext &fheContext, LowerOptions options) { - mlir::PassManager pm(&context); - if (options.verbose) { - llvm::errs() << "##################################################\n"; - llvm::errs() << "### HLFHEToMlirStdsDialect pipeline\n"; - context.disableMultithreading(); - pm.enableIRPrinting(); - pm.enableStatistics(); - pm.enableTiming(); - pm.enableVerifier(); - } - - fheContext.constraint = defaultGlobalFHECircuitConstraint; - fheContext.parameter = *getV0Parameter(fheContext.constraint); - // Add all passes to lower from HLFHE to LLVM Dialect - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), - options.enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), - options.enablePass); - addFilteredPassToPassManager( - pm, - mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass(fheContext), - options.enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMidLFHEToLowLFHEPass(), - options.enablePass); - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass(), - options.enablePass); - - // Run the passes - if (pm.run(module).failed()) { - return mlir::failure(); - } - - return mlir::success(); -} - -mlir::LogicalResult CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - mlir::MLIRContext &context, mlir::Operation *module, LowerOptions options) { - mlir::PassManager pm(&context); - if (options.verbose) { - llvm::errs() << "##################################################\n"; - llvm::errs() << "### MlirStdsDialectToMlirLLVMDialect pipeline\n"; - context.disableMultithreading(); - pm.enableIRPrinting(); - pm.enableStatistics(); - pm.enableTiming(); - pm.enableVerifier(); - } - - // Unparametrize LowLFHE - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass(), - options.enablePass); - - // Bufferize - addFilteredPassToPassManager(pm, mlir::createTensorConstantBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createStdBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createTensorBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createLinalgBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createConvertLinalgToLoopsPass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createFuncBufferizePass(), - options.enablePass); - addFilteredPassToPassManager(pm, mlir::createFinalizingBufferizePass(), - options.enablePass); - - // Convert to MLIR LLVM Dialect - addFilteredPassToPassManager( - pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(), - options.enablePass); - - if (pm.run(module).failed()) { - return mlir::failure(); - } - return mlir::success(); -} - -llvm::Expected> CompilerTools::toLLVMModule( - llvm::LLVMContext &llvmContext, mlir::ModuleOp &module, - llvm::function_ref optPipeline) { - - initLLVMNativeTarget(); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); - if (!llvmModule) { - return llvm::make_error( - "failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode()); - } - - if (auto err = optPipeline(llvmModule.get())) { - return llvm::make_error("failed to optimize LLVM IR", - llvm::inconvertibleErrorCode()); - } - - return std::move(llvmModule); -} - -llvm::Expected> -JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, - llvm::function_ref optPipeline) { - - // Looking for the function - auto rangeOps = module.getOps(); - auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) { - return op.getName() == name; - }); - if (funcOp == rangeOps.end()) { - return llvm::make_error( - "cannot find the function to JIT", llvm::inconvertibleErrorCode()); - } - initLLVMNativeTarget(); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // Create an MLIR execution engine. The execution engine eagerly - // JIT-compiles the module. - auto maybeEngine = mlir::ExecutionEngine::create( - module, /*llvmModuleBuilder=*/nullptr, optPipeline); - if (!maybeEngine) { - return llvm::make_error( - "failed to construct the MLIR ExecutionEngine", - llvm::inconvertibleErrorCode()); - } - auto &engine = maybeEngine.get(); - auto lambda = std::make_unique((*funcOp).getType(), name); - lambda->engine = std::move(engine); - - return std::move(lambda); -} - -llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { - size_t nbReturn = 0; - // TODO - This check break with memref as we have 5 returns args. - // if (!this->type.getReturnType().isa()) { - // nbReturn = 1; - // } - // if (this->type.getNumParams() != args.size() - nbReturn) { - // return llvm::make_error( - // "invokeRaw: wrong number of argument", - // llvm::inconvertibleErrorCode()); - // } - if (llvm::find(args, nullptr) != args.end()) { - return llvm::make_error( - "invoke: some arguments are null", llvm::inconvertibleErrorCode()); - } - return this->engine->invokePacked(this->name, args); -} - -llvm::Error JITLambda::invoke(Argument &args) { - return std::move(invokeRaw(args.rawArg)); -} - -JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { - // Setting the inputs - { - auto numInputs = 0; - for (size_t i = 0; i < keySet.numInputs(); i++) { - auto offset = numInputs; - auto gate = keySet.inputGate(i); - inputGates.push_back({gate, offset}); - if (keySet.inputGate(i).shape.size == 0) { - // scalar gate - numInputs = numInputs + 1; - continue; - } - // memref gate, as we follow the standard calling convention - numInputs = numInputs + 5; - } - inputs = std::vector(numInputs); - } - - // Setting the outputs - { - auto numOutputs = 0; - for (size_t i = 0; i < keySet.numOutputs(); i++) { - auto offset = numOutputs; - auto gate = keySet.outputGate(i); - outputGates.push_back({gate, offset}); - if (gate.shape.size == 0) { - // scalar gate - numOutputs = numOutputs + 1; - continue; - } - // memref gate, as we follow the standard calling convention - numOutputs = numOutputs + 5; - } - outputs = std::vector(numOutputs); - } - - // The raw argument contains pointers to inputs and pointers to store the - // results - rawArg = std::vector(inputs.size() + outputs.size(), nullptr); - // Set the pointer on outputs on rawArg - for (auto i = inputs.size(); i < rawArg.size(); i++) { - rawArg[i] = &outputs[i - inputs.size()]; - } - - // Setup runtime context with appropriate keys - keySet.initGlobalRuntimeContext(); -} - -JITLambda::Argument::~Argument() { - int err; - for (auto ct : allocatedCiphertexts) { - free_lwe_ciphertext_u64(&err, ct); - } - for (auto buffer : ciphertextBuffers) { - free(buffer); - } -} - -llvm::Expected> -JITLambda::Argument::create(KeySet &keySet) { - auto args = std::make_unique(keySet); - return std::move(args); -} - -llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { - if (pos >= inputGates.size()) { - return llvm::make_error( - llvm::Twine("argument index out of bound: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - auto gate = inputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - - // Check is the argument is a scalar - if (info.shape.size != 0) { - return llvm::make_error( - llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - - // If argument is not encrypted, just save. - if (!info.encryption.hasValue()) { - inputs[offset] = (void *)arg; - rawArg[offset] = &inputs[offset]; - return llvm::Error::success(); - } - // Else if is encryted, allocate ciphertext and encrypt. - LweCiphertext_u64 *ctArg; - if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { - return std::move(err); - } - allocatedCiphertexts.push_back(ctArg); - if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { - return std::move(err); - } - inputs[offset] = ctArg; - rawArg[offset] = &inputs[offset]; - return llvm::Error::success(); -} - -llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, - size_t size) { - auto gate = inputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - // Check if the width is compatible - // TODO - I found this rules empirically, they are a spec somewhere? - if (info.shape.width <= 8 && width != 8) { - return llvm::make_error( - llvm::Twine("argument width should be 8: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) { - return llvm::make_error( - llvm::Twine("argument width should be 16: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) { - return llvm::make_error( - llvm::Twine("argument width should be 32: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) { - return llvm::make_error( - llvm::Twine("argument width should be 64: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.width > 64) { - return llvm::make_error( - llvm::Twine("argument width not supported: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - // Check the size - if (info.shape.size == 0) { - return llvm::make_error( - llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (info.shape.size != size) { - return llvm::make_error( - llvm::Twine("vector argument has not the expected size") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - // If argument is not encrypted, just save with the right calling convention. - if (info.encryption.hasValue()) { - // Else if is encrypted - // For moment we support only 8 bits inputs - uint8_t *data8 = (uint8_t *)data; - if (width != 8) { - return llvm::make_error( - llvm::Twine( - "argument width > 8 for encrypted gates are not supported: pos=") - .concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - - // Allocate a buffer for ciphertexts. - auto ctBuffer = - (LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *)); - ciphertextBuffers.push_back(ctBuffer); - // Allocate ciphertexts and encrypt - for (auto i = 0; i < size; i++) { - if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { - return std::move(err); - } - allocatedCiphertexts.push_back(ctBuffer[i]); - if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { - return std::move(err); - } - } - // Replace the data by the buffer to ciphertext - data = (void *)ctBuffer; - } - // Set the buffer as the memref calling convention expect. - // allocated - inputs[offset] = (void *)0; // TODO - Better understand how it is used. - rawArg[offset] = &inputs[offset]; - // aligned - inputs[offset + 1] = data; - rawArg[offset + 1] = &inputs[offset + 1]; - // offset - inputs[offset + 2] = (void *)0; - rawArg[offset + 2] = &inputs[offset + 2]; - // size - inputs[offset + 3] = (void *)size; - rawArg[offset + 3] = &inputs[offset + 3]; - // stride - inputs[offset + 4] = (void *)0; - rawArg[offset + 4] = &inputs[offset + 4]; - return llvm::Error::success(); -} - -llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { - auto gate = outputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - - // Check is the argument is a scalar - if (info.shape.size != 0) { - return llvm::make_error( - llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - // If result is not encrypted, just set the result - if (!info.encryption.hasValue()) { - res = (uint64_t)(outputs[offset]); - return llvm::Error::success(); - } - // Else if is encryted, decrypt - LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]); - if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { - return std::move(err); - } - return llvm::Error::success(); -} - -llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, - size_t size) { - auto gate = outputGates[pos]; - auto info = std::get<0>(gate); - auto offset = std::get<1>(gate); - - // Check is the argument is a scalar - if (info.shape.size == 0) { - return llvm::make_error( - llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)), - llvm::inconvertibleErrorCode()); - } - if (!info.encryption.hasValue()) { - return llvm::make_error( - "unencrypted result as tensor output NYI", - llvm::inconvertibleErrorCode()); - } - // Get the values as the memref calling convention expect. - void *allocated = outputs[offset]; // TODO - Better understand how it is used. - // aligned - void *aligned = outputs[offset + 1]; - // offset - size_t offset_r = (size_t)outputs[offset + 2]; - // size - size_t size_r = (size_t)outputs[offset + 3]; - // stride - size_t stride = (size_t)outputs[offset + 4]; - // Check the sizes - if (info.shape.size != size || size_r != size) { - return llvm::make_error("output bad result buffer size", - llvm::inconvertibleErrorCode()); - } - // decrypt and fill the result buffer - for (auto i = 0; i < size_r; i++) { - LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i]; - if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) { - return std::move(err); - } - } - return llvm::Error::success(); -} - -} // namespace zamalang -} // namespace mlir diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 8c95d6a1d..779e5bfdb 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -1,6 +1,10 @@ #include #include #include +#include + +#include +#include #include #include @@ -54,5 +58,329 @@ runJit(mlir::ModuleOp module, llvm::StringRef func, return mlir::success(); } +llvm::Expected> +JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module, + llvm::function_ref optPipeline) { + + // Looking for the function + auto rangeOps = module.getOps(); + auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) { + return op.getName() == name; + }); + if (funcOp == rangeOps.end()) { + return llvm::make_error( + "cannot find the function to JIT", llvm::inconvertibleErrorCode()); + } + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerLLVMDialectTranslation(*module->getContext()); + + // Create an MLIR execution engine. The execution engine eagerly + // JIT-compiles the module. + auto maybeEngine = mlir::ExecutionEngine::create( + module, /*llvmModuleBuilder=*/nullptr, optPipeline); + if (!maybeEngine) { + return llvm::make_error( + "failed to construct the MLIR ExecutionEngine", + llvm::inconvertibleErrorCode()); + } + auto &engine = maybeEngine.get(); + auto lambda = std::make_unique((*funcOp).getType(), name); + lambda->engine = std::move(engine); + + return std::move(lambda); +} + +llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef args) { + size_t nbReturn = 0; + // TODO - This check break with memref as we have 5 returns args. + // if (!this->type.getReturnType().isa()) { + // nbReturn = 1; + // } + // if (this->type.getNumParams() != args.size() - nbReturn) { + // return llvm::make_error( + // "invokeRaw: wrong number of argument", + // llvm::inconvertibleErrorCode()); + // } + if (llvm::find(args, nullptr) != args.end()) { + return llvm::make_error( + "invoke: some arguments are null", llvm::inconvertibleErrorCode()); + } + return this->engine->invokePacked(this->name, args); +} + +llvm::Error JITLambda::invoke(Argument &args) { + return std::move(invokeRaw(args.rawArg)); +} + +JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) { + // Setting the inputs + { + auto numInputs = 0; + for (size_t i = 0; i < keySet.numInputs(); i++) { + auto offset = numInputs; + auto gate = keySet.inputGate(i); + inputGates.push_back({gate, offset}); + if (keySet.inputGate(i).shape.size == 0) { + // scalar gate + numInputs = numInputs + 1; + continue; + } + // memref gate, as we follow the standard calling convention + numInputs = numInputs + 5; + } + inputs = std::vector(numInputs); + } + + // Setting the outputs + { + auto numOutputs = 0; + for (size_t i = 0; i < keySet.numOutputs(); i++) { + auto offset = numOutputs; + auto gate = keySet.outputGate(i); + outputGates.push_back({gate, offset}); + if (gate.shape.size == 0) { + // scalar gate + numOutputs = numOutputs + 1; + continue; + } + // memref gate, as we follow the standard calling convention + numOutputs = numOutputs + 5; + } + outputs = std::vector(numOutputs); + } + + // The raw argument contains pointers to inputs and pointers to store the + // results + rawArg = std::vector(inputs.size() + outputs.size(), nullptr); + // Set the pointer on outputs on rawArg + for (auto i = inputs.size(); i < rawArg.size(); i++) { + rawArg[i] = &outputs[i - inputs.size()]; + } + + // Setup runtime context with appropriate keys + keySet.initGlobalRuntimeContext(); +} + +JITLambda::Argument::~Argument() { + int err; + for (auto ct : allocatedCiphertexts) { + free_lwe_ciphertext_u64(&err, ct); + } + for (auto buffer : ciphertextBuffers) { + free(buffer); + } +} + +llvm::Expected> +JITLambda::Argument::create(KeySet &keySet) { + auto args = std::make_unique(keySet); + return std::move(args); +} + +llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) { + if (pos >= inputGates.size()) { + return llvm::make_error( + llvm::Twine("argument index out of bound: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + auto gate = inputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size != 0) { + return llvm::make_error( + llvm::Twine("argument is not a scalar: pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + + // If argument is not encrypted, just save. + if (!info.encryption.hasValue()) { + inputs[offset] = (void *)arg; + rawArg[offset] = &inputs[offset]; + return llvm::Error::success(); + } + // Else if is encryted, allocate ciphertext and encrypt. + LweCiphertext_u64 *ctArg; + if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) { + return std::move(err); + } + allocatedCiphertexts.push_back(ctArg); + if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) { + return std::move(err); + } + inputs[offset] = ctArg; + rawArg[offset] = &inputs[offset]; + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width, void *data, + size_t size) { + auto gate = inputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + // Check if the width is compatible + // TODO - I found this rules empirically, they are a spec somewhere? + if (info.shape.width <= 8 && width != 8) { + return llvm::make_error( + llvm::Twine("argument width should be 8: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 8 && info.shape.width <= 16 && width != 16) { + return llvm::make_error( + llvm::Twine("argument width should be 16: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 16 && info.shape.width <= 32 && width != 32) { + return llvm::make_error( + llvm::Twine("argument width should be 32: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 32 && info.shape.width <= 64 && width != 64) { + return llvm::make_error( + llvm::Twine("argument width should be 64: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.width > 64) { + return llvm::make_error( + llvm::Twine("argument width not supported: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // Check the size + if (info.shape.size == 0) { + return llvm::make_error( + llvm::Twine("argument is not a vector: pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (info.shape.size != size) { + return llvm::make_error( + llvm::Twine("vector argument has not the expected size") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // If argument is not encrypted, just save with the right calling convention. + if (info.encryption.hasValue()) { + // Else if is encrypted + // For moment we support only 8 bits inputs + uint8_t *data8 = (uint8_t *)data; + if (width != 8) { + return llvm::make_error( + llvm::Twine( + "argument width > 8 for encrypted gates are not supported: pos=") + .concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + + // Allocate a buffer for ciphertexts. + auto ctBuffer = + (LweCiphertext_u64 **)malloc(size * sizeof(LweCiphertext_u64 *)); + ciphertextBuffers.push_back(ctBuffer); + // Allocate ciphertexts and encrypt + for (auto i = 0; i < size; i++) { + if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { + return std::move(err); + } + allocatedCiphertexts.push_back(ctBuffer[i]); + if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { + return std::move(err); + } + } + // Replace the data by the buffer to ciphertext + data = (void *)ctBuffer; + } + // Set the buffer as the memref calling convention expect. + // allocated + inputs[offset] = (void *)0; // TODO - Better understand how it is used. + rawArg[offset] = &inputs[offset]; + // aligned + inputs[offset + 1] = data; + rawArg[offset + 1] = &inputs[offset + 1]; + // offset + inputs[offset + 2] = (void *)0; + rawArg[offset + 2] = &inputs[offset + 2]; + // size + inputs[offset + 3] = (void *)size; + rawArg[offset + 3] = &inputs[offset + 3]; + // stride + inputs[offset + 4] = (void *)0; + rawArg[offset + 4] = &inputs[offset + 4]; + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size != 0) { + return llvm::make_error( + llvm::Twine("output is not a scalar, pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + // If result is not encrypted, just set the result + if (!info.encryption.hasValue()) { + res = (uint64_t)(outputs[offset]); + return llvm::Error::success(); + } + // Else if is encryted, decrypt + LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]); + if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) { + return std::move(err); + } + return llvm::Error::success(); +} + +llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res, + size_t size) { + auto gate = outputGates[pos]; + auto info = std::get<0>(gate); + auto offset = std::get<1>(gate); + + // Check is the argument is a scalar + if (info.shape.size == 0) { + return llvm::make_error( + llvm::Twine("output is not a tensor, pos=").concat(llvm::Twine(pos)), + llvm::inconvertibleErrorCode()); + } + if (!info.encryption.hasValue()) { + return llvm::make_error( + "unencrypted result as tensor output NYI", + llvm::inconvertibleErrorCode()); + } + // Get the values as the memref calling convention expect. + void *allocated = outputs[offset]; // TODO - Better understand how it is used. + // aligned + void *aligned = outputs[offset + 1]; + // offset + size_t offset_r = (size_t)outputs[offset + 2]; + // size + size_t size_r = (size_t)outputs[offset + 3]; + // stride + size_t stride = (size_t)outputs[offset + 4]; + // Check the sizes + if (info.shape.size != size || size_r != size) { + return llvm::make_error("output bad result buffer size", + llvm::inconvertibleErrorCode()); + } + // decrypt and fill the result buffer + for (auto i = 0; i < size_r; i++) { + LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)(aligned))[i]; + if (auto err = this->keySet.decrypt_lwe(pos, ct, res[i])) { + return std::move(err); + } + } + return llvm::Error::success(); +} + } // namespace zamalang } // namespace mlir diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp new file mode 100644 index 000000000..76097f85d --- /dev/null +++ b/compiler/lib/Support/Pipeline.cpp @@ -0,0 +1,148 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mlir { +namespace zamalang { +namespace pipeline { +static void addPotentiallyNestedPass(mlir::PassManager &pm, + std::unique_ptr pass) { + if (!pass->getOpName() || *pass->getOpName() == "module") { + pm.addPass(std::move(pass)); + } else { + pm.nest(*pass->getOpName()).addPass(std::move(pass)); + } +} + +mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, bool verbose) { + mlir::PassManager pm(&context); + + if (verbose) { + mlir::zamalang::log_verbose() + << "##################################################\n" + << "### HLFHE to MidLFHE pipeline\n"; + + pm.enableIRPrinting(); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } + + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg()); + addPotentiallyNestedPass(pm, + mlir::zamalang::createConvertHLFHEToMidLFHEPass()); + + return pm.run(module.getOperation()); +} + +mlir::LogicalResult lowerMidLFHEToLowLFHE(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, + bool parametrize) { + mlir::PassManager pm(&context); + + if (parametrize) { + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertMidLFHEGlobalParametrizationPass( + fheContext)); + } + + addPotentiallyNestedPass(pm, + mlir::zamalang::createConvertMidLFHEToLowLFHEPass()); + + return pm.run(module.getOperation()); +} + +mlir::LogicalResult lowerLowLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module) { + mlir::PassManager pm(&context); + pm.addPass(mlir::zamalang::createConvertLowLFHEToConcreteCAPIPass()); + return pm.run(module.getOperation()); +} + +mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, + mlir::ModuleOp &module, + bool verbose) { + mlir::PassManager pm(&context); + + if (verbose) { + mlir::zamalang::log_verbose() + << "##################################################\n" + << "### MlirStdsDialectToMlirLLVMDialect pipeline\n"; + context.disableMultithreading(); + pm.enableIRPrinting(); + pm.enableStatistics(); + pm.enableTiming(); + pm.enableVerifier(); + } + + // Unparametrize LowLFHE + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertLowLFHEUnparametrizePass()); + + // Bufferize + addPotentiallyNestedPass(pm, mlir::createTensorConstantBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createStdBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createTensorBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createLinalgBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createConvertLinalgToLoopsPass()); + addPotentiallyNestedPass(pm, mlir::createFuncBufferizePass()); + addPotentiallyNestedPass(pm, mlir::createFinalizingBufferizePass()); + + // Convert to MLIR LLVM Dialect + addPotentiallyNestedPass( + pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()); + + return pm.run(module); +} + +std::unique_ptr +lowerLLVMDialectToLLVMIR(mlir::MLIRContext &context, + llvm::LLVMContext &llvmContext, + mlir::ModuleOp &module) { + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + mlir::registerLLVMDialectTranslation(*module->getContext()); + + return mlir::translateModuleToLLVMIR(module, llvmContext); +} + +mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext, + llvm::Module &module) { + std::function optPipeline = + mlir::makeOptimizingTransformer(3, 0, nullptr); + + if (optPipeline(&module)) + return mlir::failure(); + else + return mlir::success(); +} + +mlir::LogicalResult lowerHLFHEToStd(mlir::MLIRContext &context, + mlir::ModuleOp &module, + V0FHEContext &fheContext, bool verbose) { + if (lowerHLFHEToMidLFHE(context, module, verbose).failed() || + lowerMidLFHEToLowLFHE(context, module, fheContext, true).failed() || + lowerLowLFHEToStd(context, module).failed()) { + return mlir::failure(); + } else { + return mlir::success(); + } +} + +} // namespace pipeline +} // namespace zamalang +} // namespace mlir diff --git a/compiler/python/CompilerAPIModule.cpp b/compiler/python/CompilerAPIModule.cpp index 887e79fe0..62e6203b7 100644 --- a/compiler/python/CompilerAPIModule.cpp +++ b/compiler/python/CompilerAPIModule.cpp @@ -7,7 +7,6 @@ #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" #include "zamalang/Support/CompilerEngine.h" -#include "zamalang/Support/CompilerTools.h" #include #include #include diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index ec15e7b80..517f128ff 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -12,16 +12,32 @@ #include #include +#include "mlir/IR/BuiltinOps.h" #include "zamalang/Conversion/Passes.h" +#include "zamalang/Conversion/Utils/GlobalFHEContext.h" #include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h" #include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h" #include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" -#include "zamalang/Support/CompilerTools.h" -#include "zamalang/Support/logging.h" #include "zamalang/Support/Jit.h" +#include "zamalang/Support/KeySet.h" +#include "zamalang/Support/Pipeline.h" +#include "zamalang/Support/logging.h" + +enum EntryDialect { HLFHE, MIDLFHE, LOWLFHE, STD, LLVM }; + +enum Action { + ROUND_TRIP, + DUMP_MIDLFHE, + DUMP_LOWLFHE, + DUMP_STD, + DUMP_LLVM_DIALECT, + DUMP_LLVM_IR, + DUMP_OPTIMIZED_LLVM_IR, + JIT_INVOKE +}; namespace cmdline { @@ -37,14 +53,53 @@ llvm::cl::opt output("o", llvm::cl::opt verbose("verbose", llvm::cl::desc("verbose logs"), llvm::cl::init(false)); -llvm::cl::list passes( - "passes", - llvm::cl::desc("Specify the passes to run (use only for compiler tests)"), - llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore); +llvm::cl::opt parametrizeMidLFHE( + "parametrize-midlfhe", + llvm::cl::desc("Perform MidLFHE global parametrization pass"), + llvm::cl::init(true)); -llvm::cl::opt roundTrip("round-trip", - llvm::cl::desc("Just parse and dump"), - llvm::cl::init(false)); +static llvm::cl::opt entryDialect( + "e", "entry-dialect", llvm::cl::desc("Entry dialect"), + llvm::cl::init(EntryDialect::HLFHE), + llvm::cl::ValueRequired, llvm::cl::NumOccurrencesFlag::Required, + llvm::cl::values( + clEnumValN(EntryDialect::HLFHE, "hlfhe", + "Input module is composed of HLFHE operations")), + llvm::cl::values( + clEnumValN(EntryDialect::MIDLFHE, "midlfhe", + "Input module is composed of MidLFHE operations")), + llvm::cl::values( + clEnumValN(EntryDialect::LOWLFHE, "lowlfhe", + "Input module is composed of LowLFHE operations")), + llvm::cl::values( + clEnumValN(EntryDialect::STD, "std", + "Input module is composed of operations from std")), + llvm::cl::values( + clEnumValN(EntryDialect::LLVM, "llvm", + "Input module is composed of operations from llvm"))); + +static llvm::cl::opt action( + "a", "action", llvm::cl::desc("output mode"), llvm::cl::ValueRequired, + llvm::cl::NumOccurrencesFlag::Required, + llvm::cl::values( + clEnumValN(Action::ROUND_TRIP, "roundtrip", + "Parse input module and regenerate textual representation")), + llvm::cl::values(clEnumValN(Action::DUMP_MIDLFHE, "dump-midlfhe", + "Lower to MidLFHE and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_LOWLFHE, "dump-lowlfhe", + "Lower to LowLFHE and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_STD, "dump-std", + "Lower to std and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_LLVM_DIALECT, "dump-llvm-dialect", + "Lower to LLVM dialect and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_LLVM_IR, "dump-llvm-ir", + "Lower to LLVM-IR and dump result")), + llvm::cl::values(clEnumValN(Action::DUMP_OPTIMIZED_LLVM_IR, + "dump-optimized-llvm-ir", + "Lower to LLVM-IR, optimize and dump result")), + llvm::cl::values(clEnumValN(Action::JIT_INVOKE, "jit-invoke", + "Lower and JIT-compile input module and invoke " + "function specified with --jit-funcname"))); llvm::cl::opt verifyDiagnostics( "verify-diagnostics", @@ -58,15 +113,7 @@ llvm::cl::opt splitInputFile( "chunk independently"), llvm::cl::init(false)); -llvm::cl::opt generateKeySet( - "generate-keyset", - llvm::cl::desc("[tmp] Generate a key set for the compiled fhe circuit"), - llvm::cl::init(false)); - -llvm::cl::opt runJit("run-jit", llvm::cl::desc("JIT the code and run it"), - llvm::cl::init(false)); - -llvm::cl::opt jitFuncname( +llvm::cl::opt jitFuncName( "jit-funcname", llvm::cl::desc("Name of the function to execute, default 'main'"), llvm::cl::init("main")); @@ -75,73 +122,16 @@ llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); - -llvm::cl::opt toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "), - llvm::cl::init(false)); }; // namespace cmdline std::function defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); -mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) { - llvm::LLVMContext context; - auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule( - context, module, defaultOptPipeline); - if (!llvmModule) { - return mlir::failure(); - } - os << **llvmModule; - return mlir::success(); -} +std::unique_ptr +generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, + const std::string &jitFuncName) { + std::unique_ptr keySet; -// Process a single source buffer -// -// If `verifyDiagnostics` is `true`, the procedure only checks if the -// diagnostic messages provided in the source buffer using -// `expected-error` are produced. -// -// If `verifyDiagnostics` is `false`, the procedure checks if the -// parsed module is valid and if all requested transformations -// succeeded. -mlir::LogicalResult -processInputBuffer(mlir::MLIRContext &context, - std::unique_ptr buffer, - llvm::raw_ostream &os, bool verifyDiagnostics) { - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); - - mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, - &context); - auto module = mlir::parseSourceFile(sourceMgr, &context); - - if (verifyDiagnostics) - return sourceMgrHandler.verify(); - - if (!module) - return mlir::failure(); - - if (cmdline::roundTrip) { - module->print(os); - return mlir::success(); - } - - auto enablePass = [](std::string passName) { - return cmdline::passes.size() == 0 || - std::any_of(cmdline::passes.begin(), cmdline::passes.end(), - [&](const std::string &p) { return passName == p; }); - }; - - // Lower to MLIR Stds Dialects and compute the constraint on the FHE Circuit. - mlir::zamalang::CompilerTools::LowerOptions lowerOptions; - lowerOptions.enablePass = enablePass; - lowerOptions.verbose = cmdline::verbose; - - mlir::zamalang::V0FHEContext fheContext; - if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirStdsDialect( - context, *module, fheContext, lowerOptions) - .failed()) { - return mlir::failure(); - } mlir::zamalang::log_verbose() << "### Global FHE constraint: {norm2:" << fheContext.constraint.norm2 << ", p:" << fheContext.constraint.p << "}\n"; @@ -155,45 +145,196 @@ processInputBuffer(mlir::MLIRContext &context, << ", ksLevel: " << fheContext.parameter.ksLevel << ", ksLogBase: " << fheContext.parameter.ksLogBase << "}\n"; - // Generate the keySet - std::unique_ptr keySet; - if (cmdline::generateKeySet || cmdline::runJit) { - // Create the client parameters - auto clientParameter = mlir::zamalang::createClientParametersForV0( - fheContext, cmdline::jitFuncname, *module); - if (auto err = clientParameter.takeError()) { - mlir::zamalang::log_error() - << "cannot generate client parameters: " << err << "\n"; - return mlir::failure(); - } - mlir::zamalang::log_verbose() << "### Generate the key set\n"; - auto maybeKeySet = - mlir::zamalang::KeySet::generate(clientParameter.get(), 0, - 0); // TODO: seed - if (auto err = maybeKeySet.takeError()) { - llvm::errs() << err; - return mlir::failure(); - } - keySet = std::move(maybeKeySet.get()); + // Create the client parameters + auto clientParameter = mlir::zamalang::createClientParametersForV0( + fheContext, jitFuncName, module); + + if (auto err = clientParameter.takeError()) { + mlir::zamalang::log_error() + << "cannot generate client parameters: " << err << "\n"; + return nullptr; } - // Lower to MLIR LLVM Dialect - if (mlir::zamalang::CompilerTools::lowerMlirStdsDialectToMlirLLVMDialect( - context, *module, lowerOptions) - .failed()) { + mlir::zamalang::log_verbose() << "### Generate the key set\n"; + + auto maybeKeySet = mlir::zamalang::KeySet::generate(clientParameter.get(), 0, + 0); // TODO: seed + if (auto err = maybeKeySet.takeError()) { + llvm::errs() << err; + return nullptr; + } + + return std::move(maybeKeySet.get()); +} + +// Process a single source buffer +// +// The parameter `entryDialect` must specify the FHE dialect to which +// belong all FHE operations used in the source buffer. The input +// program must only contain FHE operations from that single FHE +// dialect, otherwise processing might fail. +// +// The parameter `action` specifies how the buffer should be processed +// and thus defines the output. +// +// If the specified action involves JIT compilation, `jitFuncName` +// designates the function to JIT compile. This function is invoked +// using the parameters given in `jitArgs`. +// +// The parameter `parametrizeMidLFHE` defines, whether the +// parametrization pass for MidLFHE is executed. If the pair of +// `entryDialect` and `action` does not involve any MidlFHE +// manipulation, this parameter does not have any effect. +// +// If `verifyDiagnostics` is `true`, the procedure only checks if the +// diagnostic messages provided in the source buffer using +// `expected-error` are produced. If `verifyDiagnostics` is `false`, +// the procedure checks if the parsed module is valid and if all +// requested transformations succeeded. +// +// If `verbose` is true, debug messages are displayed throughout the +// compilation process. +// +// Compilation output is written to the stream specified by `os`. +mlir::LogicalResult processInputBuffer( + mlir::MLIRContext &context, std::unique_ptr buffer, + enum EntryDialect entryDialect, enum Action action, + const std::string &jitFuncName, llvm::ArrayRef jitArgs, + bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose, + llvm::raw_ostream &os) { + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); + + mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, + &context); + mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); + + // This is temporary until we have the high-level verification pass + // determining these parameters automatically + mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, + .p = 7}; + + std::unique_ptr keySet = nullptr; + + const mlir::zamalang::V0Parameter *parameter = + getV0Parameter(defaultGlobalFHECircuitConstraint); + + if (!parameter) { + mlir::zamalang::log_error() + << "Could not determine V0 parameters for 2-norm of " + << defaultGlobalFHECircuitConstraint.norm2 << " and p of " + << defaultGlobalFHECircuitConstraint.p << "\n"; + return mlir::failure(); } - if (cmdline::runJit) { - mlir::zamalang::log_verbose() << "### JIT compile & running\n"; - return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname, - cmdline::jitArgs, *keySet, - defaultOptPipeline, os); + mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, + *parameter}; + + if (verbose) + context.disableMultithreading(); + + if (verifyDiagnostics) + return sourceMgrHandler.verify(); + + if (!moduleRef) + return mlir::failure(); + + mlir::ModuleOp module = moduleRef.get(); + + if (action == Action::ROUND_TRIP) { + module->print(os); + return mlir::success(); } - if (cmdline::toLLVM) { - return dumpLLVMIR(module.get(), os); + + // Lowering pipeline. Each stage is represented as a label in the + // switch statement, from the most abstract dialect to the lowest + // level. Every labels acts as an entry point into the pipeline with + // a fallthrough mechanism to the next stage. Actions act as exit + // points from the pipeline. + switch (entryDialect) { + case EntryDialect::HLFHE: + if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) + .failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::MIDLFHE: + if (action == Action::DUMP_MIDLFHE) { + module.print(os); + return mlir::success(); + } + + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( + context, module, fheContext, parametrizeMidlHFE) + .failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::LOWLFHE: + if (action == Action::DUMP_LOWLFHE) { + module.print(os); + return mlir::success(); + } + + if (mlir::zamalang::pipeline::lowerLowLFHEToStd(context, module).failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::STD: + if (action == Action::DUMP_STD) { + module.print(os); + return mlir::success(); + } else if (action == Action::JIT_INVOKE) { + keySet = generateKeySet(module, fheContext, jitFuncName); + } + + if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, + verbose) + .failed()) + return mlir::failure(); + + // fallthrough + case EntryDialect::LLVM: { + if (action == Action::DUMP_LLVM_DIALECT) { + module.print(os); + return mlir::success(); + } else if (action == Action::JIT_INVOKE) { + return mlir::zamalang::runJit(module, jitFuncName, jitArgs, *keySet, + defaultOptPipeline, os); + } + + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + mlir::zamalang::pipeline::lowerLLVMDialectToLLVMIR(context, llvmContext, + module); + + if (!llvmModule) { + mlir::zamalang::log_error() + << "Failed to translate LLVM dialect to LLVM IR\n"; + return mlir::failure(); + } + + if (action == Action::DUMP_LLVM_IR) { + llvmModule->dump(); + return mlir::success(); + } + + if (mlir::zamalang::pipeline::optimizeLLVMModule(llvmContext, *llvmModule) + .failed()) { + mlir::zamalang::log_error() << "Failed to optimize LLVM IR\n"; + return mlir::failure(); + } + + if (action == Action::DUMP_OPTIMIZED_LLVM_IR) { + llvmModule->dump(); + return mlir::success(); + } + + break; } - module->print(os); + } + return mlir::success(); } @@ -209,6 +350,16 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { // String for error messages from library functions std::string errorMessage; + if (cmdline::action == Action::JIT_INVOKE && + cmdline::entryDialect != EntryDialect::HLFHE && + cmdline::entryDialect != EntryDialect::MIDLFHE && + cmdline::entryDialect != EntryDialect::LOWLFHE && + cmdline::entryDialect != EntryDialect::STD) { + mlir::zamalang::log_error() + << "Can only JIT invoke HLFHE / MidLFHE / LowLFHE / STD programs"; + return mlir::failure(); + } + // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); context.getOrLoadDialect(); @@ -229,7 +380,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return mlir::failure(); } - // Iterate over all inpiut files specified on the command line + // Iterate over all input files specified on the command line for (const auto &fileName : cmdline::inputs) { auto file = mlir::openInputFile(fileName, &errorMessage); @@ -247,14 +398,19 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { std::move(file), [&](std::unique_ptr inputBuffer, llvm::raw_ostream &os) { - return processInputBuffer(context, std::move(inputBuffer), os, - cmdline::verifyDiagnostics); + return processInputBuffer( + context, std::move(inputBuffer), cmdline::entryDialect, + cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, + cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics, + cmdline::verbose, os); }, output->os()))) return mlir::failure(); } else { - return processInputBuffer(context, std::move(file), output->os(), - cmdline::verifyDiagnostics); + return processInputBuffer( + context, std::move(file), cmdline::entryDialect, cmdline::action, + cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, + cmdline::verifyDiagnostics, cmdline::verbose, output->os()); } } diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir index c94a65a40..bf08e2e90 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>, %arg1: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { @@ -7,4 +7,4 @@ func @add_eint(%arg0: !HLFHE.eint<7>, %arg1: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<7>, !HLFHE.eint<7>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir index 811246be0..49b8063a3 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/add_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @add_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @add_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir index 74c4d531c..62a17b7d9 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{_,_,_}{2}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{_,_,_}{2}> func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.eint<2> { @@ -7,4 +7,4 @@ func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<4xi2>) -> !HLFHE.e %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<4xi2>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir index bcc99851f..23a7b3c28 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/apply_univariate_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { @@ -8,4 +8,4 @@ func @apply_lookup_table_cst(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<7>, tensor<128xi64>) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir index 9fba81d0f..c98974a2a 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/linalg_generic.mlir @@ -1,19 +1,18 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s - -//CHECK: #map0 = affine_map<(d0) -> (d0)> -//CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> -//CHECK-NEXT: module { -//CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { -//CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { -//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors -//CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}> -//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> -//CHECK-NEXT: return -//CHECK-NEXT: } -//CHECK-NEXT: } +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s +// CHECK: #map0 = affine_map<(d0) -> (d0)> +// CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> +// CHECK-NEXT: module { +// CHECK-NEXT: func @linalg_generic(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>, %arg2: tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { +// CHECK-NEXT: %0 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%arg2 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { +// CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +// CHECK-NEXT: %1 = "MidLFHE.mul_glwe_int"(%arg3, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: %2 = "MidLFHE.add_glwe"(%1, %arg5) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: linalg.yield %2 : !MidLFHE.glwe<{_,_,_}{2}> +// CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } #map0 = affine_map<(d0) -> (d0)> #map1 = affine_map<(d0) -> (0)> diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir index 0297179ca..c83bac0ba 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/mul_eint_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @mul_eint_int(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { @@ -9,4 +9,4 @@ func @mul_eint_int(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { %0 = constant 1 : i8 %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>) return %1: !HLFHE.eint<7> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir index 0962bf42a..a34343f21 100644 --- a/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir +++ b/compiler/tests/Conversion/HLFHEToMidLFHE/sub_int_eint.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --passes hlfhe-to-midlfhe 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_eint(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> func @sub_int_eint(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir index 87fef60eb..bbd678db6 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/bootstrap.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) @@ -27,4 +27,4 @@ func @bootstrap_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: !LowLFHE.glwe // CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, k = 1 : i32, level = 3 : i32, polynomialSize = 1024 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>, !LowLFHE.glwe_ciphertext) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir index ce6e2aff4..a40f5df5a 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/glwe_from_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @runtime_foreign_plaintext_list_u64(index, tensor<16xi64>, i64, i32) -> !LowLFHE.foreign_plaintext_list @@ -31,4 +31,4 @@ func @glwe_from_table(%arg0: tensor<16xi64>) -> !LowLFHE.glwe_ciphertext { // CHECK-NEXT: return %[[V1]] : !LowLFHE.glwe_ciphertext %1 = "LowLFHE.glwe_from_table"(%arg0) {k = 1 : i32, p = 4 : i32, polynomialSize = 1024 : i32} : (tensor<16xi64>) -> !LowLFHE.glwe_ciphertext return %1: !LowLFHE.glwe_ciphertext -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir index 09f609291..66ecf0177 100644 --- a/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir +++ b/compiler/tests/Conversion/LowLFHEToConcreteCAPI/keyswitch_lwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes lowlfhe-to-concrete-c-api %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=lowlfhe --action=dump-std %s 2>&1| FileCheck %s // CHECK-LABEL: module // CHECK-NEXT: func private @add_plaintext_list_glwe_ciphertext_u64(index, !LowLFHE.glwe_ciphertext, !LowLFHE.glwe_ciphertext, !LowLFHE.plaintext_list) @@ -26,4 +26,4 @@ func @keyswitch_lwe(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciph // CHECK-NEXT: return %[[RES]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "LowLFHE.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweSize = 1 : i32, level = 3 : i32, outputLweSize = 1 : i32} : (!LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> return %1: !LowLFHE.lwe_ciphertext<1024,4> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir b/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir deleted file mode 100644 index 852924948..000000000 --- a/compiler/tests/Conversion/LowLFHEUnparametrize/func.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: zamacompiler --passes lowlfhe-unparametrize %s 2>&1| FileCheck %s - -// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_> -func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<1024,4> { - // CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_> - return %arg0: !LowLFHE.lwe_ciphertext<1024,4> -} \ No newline at end of file diff --git a/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir b/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir deleted file mode 100644 index b48ba3c24..000000000 --- a/compiler/tests/Conversion/LowLFHEUnparametrize/unrealized_conversion_cast.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: zamacompiler --passes lowlfhe-unparametrize %s 2>&1| FileCheck %s - -// CHECK-LABEL: func @main(%arg0: !LowLFHE.lwe_ciphertext<_,_>) -> !LowLFHE.lwe_ciphertext<_,_> -func @main(%arg0: !LowLFHE.lwe_ciphertext<1024,4>) -> !LowLFHE.lwe_ciphertext<_,_> { - // CHECK-NEXT: return %arg0 : !LowLFHE.lwe_ciphertext<_,_> - %0 = unrealized_conversion_cast %arg0 : !LowLFHE.lwe_ciphertext<1024,4> to !LowLFHE.lwe_ciphertext<_,_> - return %0: !LowLFHE.lwe_ciphertext<_,_> -} \ No newline at end of file diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index f84dcbd69..7814e2064 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index db9e32472..ee1d3e53a 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { @@ -19,4 +19,4 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE. // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.add_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index f3cda0d8f..926c788ce 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { @@ -8,4 +8,4 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16x // CHECK-NEXT: return %[[V3]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1){k=1:i32, polynomialSize=1024:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{1024,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index e202ae025..adc811bef 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { @@ -10,4 +10,4 @@ func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.g %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi4> %1 = "MidLFHE.apply_lookup_table"(%arg0, %tlu){k=1:i32, polynomialSize=2048:i32, levelKS=3:i32, baseLogKS=2:i32, levelBS=5:i32, baseLogBS=4:i32, outputSizeKS=600:i32}: (!MidLFHE.glwe<{2048,1,64}{4}>, tensor<16xi4>) -> (!MidLFHE.glwe<{2048,1,64}{4}>) return %1: !MidLFHE.glwe<{2048,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index 3950f64e1..c13c73353 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { @@ -19,4 +19,4 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE. // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.mul_glwe_int"(%arg0, %arg1): (!MidLFHE.glwe<{1024,1,64}{4}>, i5) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index 5af7c049c..a40db4b79 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --passes midlfhe-to-lowlfhe %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { @@ -20,4 +20,4 @@ func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: i5) -> !MidLFHE. // CHECK-NEXT: return %[[V2]] : !LowLFHE.lwe_ciphertext<1024,4> %1 = "MidLFHE.sub_int_glwe"(%arg1, %arg0): (i5, !MidLFHE.glwe<{1024,1,64}{4}>) -> (!MidLFHE.glwe<{1024,1,64}{4}>) return %1: !MidLFHE.glwe<{1024,1,64}{4}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir index 2e2f31a65..0e53d8d5a 100644 --- a/compiler/tests/Dialect/HLFHE/dot.invalid.mlir +++ b/compiler/tests/Dialect/HLFHE/dot.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=hlfhe --action=roundtrip %s // Incompatible shapes func @dot_incompatible_shapes( @@ -66,4 +66,4 @@ func @dot_incompatible_int( (tensor<4x!HLFHE.eint<2>>, tensor<4xi4>) -> !HLFHE.eint<2> return %ret : !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir index de76e776f..6a6d4f962 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_big.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<8>) { diff --git a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir index 300d49685..bb43f441c 100644 --- a/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir +++ b/compiler/tests/Dialect/HLFHE/eint_error_p_too_small.mlir @@ -1,4 +1,4 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: eint support only precision in ]0;7] func @test(%arg0: !HLFHE.eint<0>) { diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir index 01bd02a02..1bb62e224 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_inputs.mlir @@ -1,7 +1,7 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<3>) -> !HLFHE.eint<2> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<3>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir index 3c9e7b0cb..d43bc7194 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_err_result.mlir @@ -1,7 +1,7 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint' op should have the width of encrypted inputs and result equals func @add_eint(%arg0: !HLFHE.eint<2>, %arg1: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %1 = "HLFHE.add_eint"(%arg0, %arg1): (!HLFHE.eint<2>, !HLFHE.eint<2>) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir index 680c79b57..205e7afe1 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_inputs.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i4 %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir index 299d2771a..0a8ae9a8c 100644 --- a/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_add_eint_int_err_result.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.add_eint_int' op should have the width of encrypted inputs and result equals func @add_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %0 = constant 1 : i2 %1 = "HLFHE.add_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir index ea0c179a8..0a8d9cd48 100644 --- a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir @@ -1,7 +1,7 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument. func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> { %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir index 6ea308618..6a9e6e059 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_inputs.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1 func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i4 %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i4) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir index bc7ef7637..ee84b2a49 100644 --- a/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_mul_eint_int_err_result.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.mul_eint_int' op should have the width of encrypted inputs and result equals func @mul_eint_int(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %0 = constant 1 : i2 %1 = "HLFHE.mul_eint_int"(%arg0, %0): (!HLFHE.eint<2>, i2) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir index 73c036e5e..deded0859 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_inputs.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1 func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %0 = constant 1 : i4 %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i4, !HLFHE.eint<2>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir index 6d2892e8a..207414189 100644 --- a/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir +++ b/compiler/tests/Dialect/HLFHE/op_sub_int_eint_err_result.mlir @@ -1,8 +1,8 @@ -// RUN: not zamacompiler %s 2>&1| FileCheck %s +// RUN: not zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: error: 'HLFHE.sub_int_eint' op should have the width of encrypted inputs and result equals func @sub_int_eint(%arg0: !HLFHE.eint<2>) -> !HLFHE.eint<3> { %0 = constant 1 : i2 %1 = "HLFHE.sub_int_eint"(%0, %arg0): (i2, !HLFHE.eint<2>) -> (!HLFHE.eint<3>) return %1: !HLFHE.eint<3> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/HLFHE/ops.mlir b/compiler/tests/Dialect/HLFHE/ops.mlir index de2c43060..4ee6f53bc 100644 --- a/compiler/tests/Dialect/HLFHE/ops.mlir +++ b/compiler/tests/Dialect/HLFHE/ops.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @zero() -> !HLFHE.eint<2> func @zero() -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir index f313d7501..1130181ea 100644 --- a/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/tensor-ops-to-linalg.mlir @@ -1,22 +1,22 @@ -// RUN: zamacompiler %s --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s +// RUN: zamacompiler %s --entry-dialect=hlfhe --action=dump-midlfhe 2>&1 | FileCheck %s //CHECK: #map0 = affine_map<(d0) -> (d0)> //CHECK-NEXT: #map1 = affine_map<(d0) -> (0)> //CHECK-NEXT: module { -//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { -//CHECK-NEXT: %0 = "HLFHE.zero"() : () -> !HLFHE.eint<2> -//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!HLFHE.eint<2>> -//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!HLFHE.eint<2>>, tensor<2xi3>) outs(%1 : tensor<1x!HLFHE.eint<2>>) { -//CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors -//CHECK-NEXT: %4 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> -//CHECK-NEXT: %5 = "HLFHE.add_eint"(%4, %arg4) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2> -//CHECK-NEXT: linalg.yield %5 : !HLFHE.eint<2> -//CHECK-NEXT: } -> tensor<1x!HLFHE.eint<2>> +//CHECK-NEXT: func @dot_eint_int(%arg0: tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<2xi3>) -> !MidLFHE.glwe<{_,_,_}{2}> { +//CHECK-NEXT: %0 = "MidLFHE.zero"() : () -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %1 = tensor.from_elements %0 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map0, #map0, #map1], iterator_types = ["reduction"]} ins(%arg0, %arg1 : tensor<2x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2xi3>) outs(%1 : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg2: !MidLFHE.glwe<{_,_,_}{2}>, %arg3: i3, %arg4: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +//CHECK-NEXT: %4 = "MidLFHE.mul_glwe_int"(%arg2, %arg3) : (!MidLFHE.glwe<{_,_,_}{2}>, i3) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %5 = "MidLFHE.add_glwe"(%4, %arg4) : (!MidLFHE.glwe<{_,_,_}{2}>, !MidLFHE.glwe<{_,_,_}{2}>) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %5 : !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } -> tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> //CHECK-NEXT: %c0 = constant 0 : index -//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!HLFHE.eint<2>> -//CHECK-NEXT: return %3 : !HLFHE.eint<2> +//CHECK-NEXT: %3 = tensor.extract %2[%c0] : tensor<1x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: return %3 : !MidLFHE.glwe<{_,_,_}{2}> //CHECK-NEXT: } -//CHECK-NEXT: } +//CHECK-NEXT: } func @dot_eint_int(%arg0: tensor<2x!HLFHE.eint<2>>, %arg1: tensor<2xi3>) -> !HLFHE.eint<2> { diff --git a/compiler/tests/Dialect/HLFHE/types.mlir b/compiler/tests/Dialect/HLFHE/types.mlir index 97d77ffa8..8e6b6bc85 100644 --- a/compiler/tests/Dialect/HLFHE/types.mlir +++ b/compiler/tests/Dialect/HLFHE/types.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=hlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>> func @memref_arg(%arg0: memref<2x!HLFHE.eint<7>>) { diff --git a/compiler/tests/Dialect/LowLFHE/ops.mlir b/compiler/tests/Dialect/LowLFHE/ops.mlir index be4315c9d..b909b0473 100644 --- a/compiler/tests/Dialect/LowLFHE/ops.mlir +++ b/compiler/tests/Dialect/LowLFHE/ops.mlir @@ -1,5 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s - +// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_lwe_ciphertexts(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> { diff --git a/compiler/tests/Dialect/LowLFHE/types.mlir b/compiler/tests/Dialect/LowLFHE/types.mlir index 0b6d37df0..27552cb2a 100644 --- a/compiler/tests/Dialect/LowLFHE/types.mlir +++ b/compiler/tests/Dialect/LowLFHE/types.mlir @@ -1,5 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s - +// RUN: zamacompiler --entry-dialect=lowlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen func @type_enc_rand_gen(%arg0: !LowLFHE.enc_rand_gen) -> !LowLFHE.enc_rand_gen { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir index 18cb9f48f..55df41028 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter result func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir index eebd51590..3d1f81407 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir index 4ce221611..97ead991e 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir index dae0011a8..ba6a37313 100644 --- a/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -9,4 +9,4 @@ func @add_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 %0 = constant 1 : i8 %1 = "MidLFHE.add_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir index 04f52dafc..ee8a78f6d 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // Bad dimension of the lookup table func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir index 07de47e6e..928095674 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<128xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { @@ -7,4 +7,4 @@ func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<12 %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {k = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32, outputSizeKS = 600 : i32} : (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<128xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>) return %1: !MidLFHE.glwe<{512,10,64}{2}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir index f3fe65d0e..f21873208 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir index 4a15c7a59..ae9daa983 100644 --- a/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -9,4 +9,4 @@ func @mul_glwe_int(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 %0 = constant 1 : i8 %1 = "MidLFHE.mul_glwe_int"(%arg0, %0): (!MidLFHE.glwe<{1024,12,64}{7}>, i8) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir index 29099bf7e..0903aeb00 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.invalid.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --split-input-file --verify-diagnostics %s +// RUN: zamacompiler --split-input-file --verify-diagnostics --entry-dialect=midlfhe --action=roundtrip %s // GLWE p parameter func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{6}> { diff --git a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir index 32032566f..47fab99a6 100644 --- a/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --round-trip %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=roundtrip %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -9,4 +9,4 @@ func @sub_int_glwe(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024 %0 = constant 1 : i8 %1 = "MidLFHE.sub_int_glwe"(%0, %arg0): (i8, !MidLFHE.glwe<{1024,12,64}{7}>) -> (!MidLFHE.glwe<{1024,12,64}{7}>) return %1: !MidLFHE.glwe<{1024,12,64}{7}> -} \ No newline at end of file +} diff --git a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir index eaec6e7dd..b66236c76 100644 --- a/compiler/tests/Dialect/MidLFHE/types_glwe.mlir +++ b/compiler/tests/Dialect/MidLFHE/types_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler %s --round-trip 2>&1| FileCheck %s +// RUN: zamacompiler %s --entry-dialect=midlfhe --action=roundtrip 2>&1| FileCheck %s // CHECK-LABEL: func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64}{7}> { @@ -10,4 +10,4 @@ func @glwe_0(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>) -> !MidLFHE.glwe<{1024,12,64 func @glwe_1(%arg0: !MidLFHE.glwe<{_,_,_}{7}>) -> !MidLFHE.glwe<{_,_,_}{7}> { // CHECK-LABEL: return %arg0 : !MidLFHE.glwe<{_,_,_}{7}> return %arg0: !MidLFHE.glwe<{_,_,_}{7}> -} \ No newline at end of file +}