From 2acfa63eb7bb5a5dd9ecc6708557d01f3ed33d0b Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 24 Sep 2021 23:43:41 +0200 Subject: [PATCH] feat(compiler): Determine FHE circuit constraints instead of using default values This replaces the default FHE circuit constrains (maximum encrypted integer width of 7 bits and a Minimal Arithmetic Noise Padding of 10 with the results of the `MaxMANP` pass, which determines these values automatically from the input program. Since the maximum encrypted integer width and the maximum value for the Minimal Arithmetic Noise Padding can only be derived from HLFHE operations, the circuit constraints are determined automatically by `zamacompiler` only if the option `--entry-dialect=hlfhe` was specified. For lower-level dialects, `zamacompiler` has been provided with the options `--assume-max-eint-precision=...` and `--assume-max-manp=...` that allow a user to specify the values for the maximum required precision and maximum values for the Minimal Arithmetic Noise Padding. --- .../include/zamalang/Support/CompilerEngine.h | 4 +- compiler/include/zamalang/Support/Pipeline.h | 4 + compiler/lib/Support/CompilerEngine.cpp | 50 +++++- compiler/lib/Support/Pipeline.cpp | 46 +++++ compiler/src/main.cpp | 165 ++++++++++++++---- .../Conversion/MidLFHEToLowLFHE/add_glwe.mlir | 2 +- .../MidLFHEToLowLFHE/add_glwe_int.mlir | 2 +- .../MidLFHEToLowLFHE/apply_lookup_table.mlir | 2 +- .../apply_lookup_table_cst.mlir | 2 +- .../MidLFHEToLowLFHE/mul_glwe_int.mlir | 2 +- .../MidLFHEToLowLFHE/sub_int_glwe.mlir | 2 +- .../tests/unittest/end_to_end_jit_test.cc | 28 +-- 12 files changed, 250 insertions(+), 59 deletions(-) diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index 047a33dc8..f7dbc981f 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -20,7 +20,9 @@ public: } // Compile an mlir programs from it's textual representation. - llvm::Error compile(std::string mlirStr); + llvm::Error compile( + std::string mlirStr, + llvm::Optional overrideConstraints = {}); // Build the jit lambda argument. llvm::Expected> buildArgument(); diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h index 7d8dcd803..bdd921a0b 100644 --- a/compiler/include/zamalang/Support/Pipeline.h +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -9,9 +9,13 @@ namespace mlir { namespace zamalang { namespace pipeline { + mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context, mlir::ModuleOp &module, bool debug); +llvm::Expected> +getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module); + mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, bool verbose); diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 5b99e8dd4..a3c27430f 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -31,23 +31,55 @@ std::string CompilerEngine::getCompiledModule() { return os.str(); } -llvm::Error CompilerEngine::compile(std::string mlirStr) { +llvm::Error CompilerEngine::compile( + std::string mlirStr, + llvm::Optional overrideConstraints) { module_ref = mlir::parseSourceString(mlirStr, context); if (!module_ref) { return llvm::make_error("mlir parsing failed", llvm::inconvertibleErrorCode()); } - mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, - .p = 7}; - const mlir::zamalang::V0Parameter *parameter = - getV0Parameter(defaultGlobalFHECircuitConstraint); - - mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, - *parameter}; - mlir::ModuleOp module = module_ref.get(); + llvm::Optional fheConstraintsOpt = + overrideConstraints; + + if (!fheConstraintsOpt.hasValue()) { + llvm::Expected> + fheConstraintsOrErr = + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context, + module); + + if (auto err = fheConstraintsOrErr.takeError()) + return std::move(err); + + if (!fheConstraintsOrErr.get().hasValue()) { + return llvm::make_error( + "Could not determine maximum required precision for encrypted " + "integers " + "and maximum value for the Minimal Arithmetic Noise Padding", + llvm::inconvertibleErrorCode()); + } + + fheConstraintsOpt = fheConstraintsOrErr.get(); + } + + mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue(); + const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); + + if (!parameter) { + std::string buffer; + llvm::raw_string_ostream strs(buffer); + strs << "Could not determine V0 parameters for 2-norm of " + << fheConstraints.norm2 << " and p of " << fheConstraints.p; + + return llvm::make_error(strs.str(), + llvm::inconvertibleErrorCode()); + } + + mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter}; + // Lower to MLIR Std if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext, false) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 07ed73c7d..f606bd409 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -13,6 +13,7 @@ #include #include #include +#include namespace mlir { namespace zamalang { @@ -35,6 +36,51 @@ mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context, return pm.run(module); } +llvm::Expected> +getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) { + llvm::Optional oMax2norm; + llvm::Optional oMaxWidth; + + mlir::PassManager pm(&context); + + addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass()); + addPotentiallyNestedPass( + pm, mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP, + unsigned currMaxWidth) { + assert((uint64_t)currMaxWidth < std::numeric_limits::max() && + "Maximum width does not fit into size_t"); + + assert(sizeof(uint64_t) >= sizeof(size_t) && + currMaxMANP.ult(std::numeric_limits::max()) && + "Maximum MANP does not fit into size_t"); + + size_t manp = (size_t)currMaxMANP.getZExtValue(); + size_t width = (size_t)currMaxWidth; + + if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp) + oMax2norm.emplace(manp); + + if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width) + oMaxWidth.emplace(width); + })); + + if (pm.run(module.getOperation()).failed()) { + return llvm::make_error( + "Failed to determine the maximum Arithmetic Noise Padding and maximum" + "required precision", + llvm::inconvertibleErrorCode()); + } + + llvm::Optional ret; + + if (oMax2norm.hasValue() && oMaxWidth.hasValue()) { + ret = llvm::Optional( + {.norm2 = ceilLog2(oMax2norm.getValue()), .p = oMaxWidth.getValue()}); + } + + return ret; +} + mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, bool verbose) { mlir::PassManager pm(&context); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 3df1c9873..afc0abf6a 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "mlir/IR/BuiltinOps.h" #include "zamalang/Conversion/Passes.h" @@ -41,6 +42,26 @@ enum Action { }; namespace cmdline { +class OptionalSizeTParser : public llvm::cl::parser> { +public: + OptionalSizeTParser(llvm::cl::Option &option) + : llvm::cl::parser>(option) {} + + bool parse(llvm::cl::Option &option, llvm::StringRef argName, + llvm::StringRef arg, llvm::Optional &value) { + size_t parsedVal; + std::istringstream iss(arg.str()); + + iss >> parsedVal; + + if (iss.fail()) + return option.error("Invalid value " + arg); + + value.emplace(parsedVal); + + return false; + } +}; llvm::cl::list inputs(llvm::cl::Positional, llvm::cl::desc(""), @@ -126,6 +147,17 @@ llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); + +llvm::cl::opt, false, OptionalSizeTParser> + assumeMaxEintPrecision( + "assume-max-eint-precision", + llvm::cl::desc("Assume a maximum precision for encrypted integers")); + +llvm::cl::opt, false, OptionalSizeTParser> assumeMaxMANP( + "assume-max-manp", + llvm::cl::desc( + "Assume a maximum for the Minimum Arithmetic Noise Padding")); + }; // namespace cmdline std::function defaultOptPipeline = @@ -171,6 +203,64 @@ generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, return std::move(maybeKeySet.get()); } +llvm::Expected buildFHEContext( + llvm::Optional autoFHEConstraints, + llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP) { + if (!autoFHEConstraints.hasValue() && + (!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) { + return llvm::make_error( + "Maximum encrypted integer precision and maximum for the Minimal" + "Arithmetic Noise Passing are required, but were neither specified" + "explicitly nor determined automatically", + llvm::inconvertibleErrorCode()); + } + + mlir::zamalang::V0FHEConstraint fheConstraints{ + .norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() + : autoFHEConstraints.getValue().norm2, + .p = overrideMaxEintPrecision.hasValue() + ? overrideMaxEintPrecision.getValue() + : autoFHEConstraints.getValue().p}; + + const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); + + if (!parameter) { + std::string buffer; + llvm::raw_string_ostream strs(buffer); + strs << "Could not determine V0 parameters for 2-norm of " + << fheConstraints.norm2 << " and p of " << fheConstraints.p; + + return llvm::make_error(strs.str(), + llvm::inconvertibleErrorCode()); + } + + return mlir::zamalang::V0FHEContext{fheConstraints, *parameter}; +} + +mlir::LogicalResult buildAssignFHEContext( + llvm::Optional &fheContext, + llvm::Optional autoFHEConstraints, + llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP) { + + if (fheContext.hasValue()) + return mlir::success(); + + llvm::Expected fheContextOrErr = + buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision, + overrideMaxMANP); + + if (auto err = fheContextOrErr.takeError()) { + mlir::zamalang::log_error() << err; + return mlir::failure(); + } + + fheContext.emplace(fheContextOrErr.get()); + + return mlir::success(); +} + // Process a single source buffer // // The parameter `entryDialect` must specify the FHE dialect to which @@ -190,6 +280,12 @@ generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, // `entryDialect` and `action` does not involve any MidlFHE // manipulation, this parameter does not have any effect. // +// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if +// set, override the values for the maximum required precision of +// encrypted integers and the maximum value for the Minimum Arithmetic +// Noise Padding otherwise determined automatically if the entry +// dialect is HLFHE.. +// // If `verifyDiagnostics` is `true`, the procedure only checks if the // diagnostic messages provided in the source buffer using // `expected-error` are produced. If `verifyDiagnostics` is `false`, @@ -204,8 +300,9 @@ mlir::LogicalResult processInputBuffer( mlir::MLIRContext &context, std::unique_ptr buffer, enum EntryDialect entryDialect, enum Action action, const std::string &jitFuncName, llvm::ArrayRef jitArgs, - bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose, - llvm::raw_ostream &os) { + bool parametrizeMidlHFE, llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP, bool verifyDiagnostics, + bool verbose, llvm::raw_ostream &os) { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); @@ -213,28 +310,11 @@ mlir::LogicalResult processInputBuffer( &context); mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); - // This is temporary until we have the high-level verification pass - // determining these parameters automatically - mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, - .p = 7}; + llvm::Optional fheConstraints; + llvm::Optional fheContext; std::unique_ptr keySet = nullptr; - const mlir::zamalang::V0Parameter *parameter = - getV0Parameter(defaultGlobalFHECircuitConstraint); - - if (!parameter) { - mlir::zamalang::log_error() - << "Could not determine V0 parameters for 2-norm of " - << defaultGlobalFHECircuitConstraint.norm2 << " and p of " - << defaultGlobalFHECircuitConstraint.p << "\n"; - - return mlir::failure(); - } - - mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, - *parameter}; - if (verbose) context.disableMultithreading(); @@ -258,14 +338,25 @@ mlir::LogicalResult processInputBuffer( // points from the pipeline. switch (entryDialect) { case EntryDialect::HLFHE: - if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) - .failed()) { - return mlir::failure(); - } - if (action == Action::DUMP_HLFHE_MANP) { + if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) + .failed()) { + return mlir::failure(); + } + module.print(os); return mlir::success(); + } else { + llvm::Expected> + fheConstraintsOrErr = + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context, + module); + if (auto err = fheConstraintsOrErr.takeError()) { + mlir::zamalang::log_error() << err; + return mlir::failure(); + } else { + fheConstraints = fheConstraintsOrErr.get(); + } } if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) @@ -279,8 +370,14 @@ mlir::LogicalResult processInputBuffer( return mlir::success(); } + if (buildAssignFHEContext(fheContext, fheConstraints, + overrideMaxEintPrecision, overrideMaxMANP) + .failed()) { + return mlir::failure(); + } + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( - context, module, fheContext, parametrizeMidlHFE) + context, module, fheContext.getValue(), parametrizeMidlHFE) .failed()) return mlir::failure(); @@ -300,7 +397,13 @@ mlir::LogicalResult processInputBuffer( module.print(os); return mlir::success(); } else if (action == Action::JIT_INVOKE) { - keySet = generateKeySet(module, fheContext, jitFuncName); + if (buildAssignFHEContext(fheContext, fheConstraints, + overrideMaxEintPrecision, overrideMaxMANP) + .failed()) { + return mlir::failure(); + } + + keySet = generateKeySet(module, fheContext.getValue(), jitFuncName); } if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, @@ -422,8 +525,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return processInputBuffer( context, std::move(inputBuffer), cmdline::entryDialect, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, - cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics, - cmdline::verbose, os); + cmdline::parametrizeMidLFHE, + cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, + cmdline::verifyDiagnostics, cmdline::verbose, os); }, output->os()))) return mlir::failure(); @@ -431,6 +535,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return processInputBuffer( context, std::move(file), cmdline::entryDialect, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, + cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, cmdline::verbose, output->os()); } } diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index 7814e2064..a3ec7b838 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index ee1d3e53a..42b77bcf4 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 926c788ce..468b02001 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index adc811bef..e0d691a8f 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index c13c73353..24fc2ff96 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index a40db4b79..1359aaa1e 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index f7791facc..e49516602 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -2,6 +2,8 @@ #include "zamalang/Support/CompilerEngine.h" +mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7}; + #define ASSERT_LLVM_ERROR(err) \ if (err) { \ llvm::errs() << "error: " << std::move(err) << "\n"; \ @@ -31,7 +33,7 @@ func @main(%t: tensor<10xi64>, %i: index) -> i64{ return %c : i64 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF, 0, @@ -68,7 +70,7 @@ func @main(%t: tensor<10xi32>, %i: index) -> i32{ return %c : i32 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90, 197864, 698735, 72132, 87474, 42}; @@ -97,7 +99,7 @@ func @main(%t: tensor<10xi16>, %i: index) -> i16{ return %c : i16 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227, 63269, 36435, 52380, 7401, 13313}; @@ -126,7 +128,7 @@ func @main(%t: tensor<10xi8>, %i: index) -> i8{ return %c : i8 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; for (size_t i = 0; i < size; i++) { @@ -154,7 +156,7 @@ func @main(%t: tensor<10xi5>, %i: index) -> i5{ return %c : i5 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; for (size_t i = 0; i < size; i++) { @@ -182,7 +184,7 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{ return %c : i1 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; for (size_t i = 0; i < size; i++) { @@ -210,7 +212,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{ return %c : !HLFHE.eint<5> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; for (size_t i = 0; i < size; i++) { @@ -240,7 +242,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5 return %c : !HLFHE.eint<5> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; for (size_t i = 0; i < size; i++) { @@ -273,7 +275,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{ return %c : index } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; auto maybeArgument = engine.buildArgument(); @@ -297,7 +299,7 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> { return %t: tensor<1x!HLFHE.eint<5>> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); auto maybeArgument = engine.buildArgument(); ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); @@ -327,7 +329,7 @@ func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> { return %out: tensor<3x!HLFHE.eint<5>> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); auto maybeArgument = engine.buildArgument(); ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); @@ -364,7 +366,7 @@ func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.ei return %ret : !HLFHE.eint<7> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); auto maybeArgument = engine.buildArgument(); ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); @@ -459,4 +461,4 @@ func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { ASSERT_TRUE((bool)maybeResult); result = maybeResult.get(); ASSERT_EQ(result, 6); -} \ No newline at end of file +}