diff --git a/compiler/include/zamalang/Support/CompilerEngine.h b/compiler/include/zamalang/Support/CompilerEngine.h index 047a33dc8..f7dbc981f 100644 --- a/compiler/include/zamalang/Support/CompilerEngine.h +++ b/compiler/include/zamalang/Support/CompilerEngine.h @@ -20,7 +20,9 @@ public: } // Compile an mlir programs from it's textual representation. - llvm::Error compile(std::string mlirStr); + llvm::Error compile( + std::string mlirStr, + llvm::Optional overrideConstraints = {}); // Build the jit lambda argument. llvm::Expected> buildArgument(); diff --git a/compiler/include/zamalang/Support/Pipeline.h b/compiler/include/zamalang/Support/Pipeline.h index 7d8dcd803..bdd921a0b 100644 --- a/compiler/include/zamalang/Support/Pipeline.h +++ b/compiler/include/zamalang/Support/Pipeline.h @@ -9,9 +9,13 @@ namespace mlir { namespace zamalang { namespace pipeline { + mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context, mlir::ModuleOp &module, bool debug); +llvm::Expected> +getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module); + mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, bool verbose); diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 5b99e8dd4..a3c27430f 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -31,23 +31,55 @@ std::string CompilerEngine::getCompiledModule() { return os.str(); } -llvm::Error CompilerEngine::compile(std::string mlirStr) { +llvm::Error CompilerEngine::compile( + std::string mlirStr, + llvm::Optional overrideConstraints) { module_ref = mlir::parseSourceString(mlirStr, context); if (!module_ref) { return llvm::make_error("mlir parsing failed", llvm::inconvertibleErrorCode()); } - mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, - .p = 7}; - const mlir::zamalang::V0Parameter *parameter = - getV0Parameter(defaultGlobalFHECircuitConstraint); - - mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, - *parameter}; - mlir::ModuleOp module = module_ref.get(); + llvm::Optional fheConstraintsOpt = + overrideConstraints; + + if (!fheConstraintsOpt.hasValue()) { + llvm::Expected> + fheConstraintsOrErr = + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context, + module); + + if (auto err = fheConstraintsOrErr.takeError()) + return std::move(err); + + if (!fheConstraintsOrErr.get().hasValue()) { + return llvm::make_error( + "Could not determine maximum required precision for encrypted " + "integers " + "and maximum value for the Minimal Arithmetic Noise Padding", + llvm::inconvertibleErrorCode()); + } + + fheConstraintsOpt = fheConstraintsOrErr.get(); + } + + mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue(); + const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); + + if (!parameter) { + std::string buffer; + llvm::raw_string_ostream strs(buffer); + strs << "Could not determine V0 parameters for 2-norm of " + << fheConstraints.norm2 << " and p of " << fheConstraints.p; + + return llvm::make_error(strs.str(), + llvm::inconvertibleErrorCode()); + } + + mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter}; + // Lower to MLIR Std if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext, false) diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 07ed73c7d..f606bd409 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -13,6 +13,7 @@ #include #include #include +#include namespace mlir { namespace zamalang { @@ -35,6 +36,51 @@ mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context, return pm.run(module); } +llvm::Expected> +getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) { + llvm::Optional oMax2norm; + llvm::Optional oMaxWidth; + + mlir::PassManager pm(&context); + + addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass()); + addPotentiallyNestedPass( + pm, mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP, + unsigned currMaxWidth) { + assert((uint64_t)currMaxWidth < std::numeric_limits::max() && + "Maximum width does not fit into size_t"); + + assert(sizeof(uint64_t) >= sizeof(size_t) && + currMaxMANP.ult(std::numeric_limits::max()) && + "Maximum MANP does not fit into size_t"); + + size_t manp = (size_t)currMaxMANP.getZExtValue(); + size_t width = (size_t)currMaxWidth; + + if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp) + oMax2norm.emplace(manp); + + if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width) + oMaxWidth.emplace(width); + })); + + if (pm.run(module.getOperation()).failed()) { + return llvm::make_error( + "Failed to determine the maximum Arithmetic Noise Padding and maximum" + "required precision", + llvm::inconvertibleErrorCode()); + } + + llvm::Optional ret; + + if (oMax2norm.hasValue() && oMaxWidth.hasValue()) { + ret = llvm::Optional( + {.norm2 = ceilLog2(oMax2norm.getValue()), .p = oMaxWidth.getValue()}); + } + + return ret; +} + mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, bool verbose) { mlir::PassManager pm(&context); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 3df1c9873..afc0abf6a 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "mlir/IR/BuiltinOps.h" #include "zamalang/Conversion/Passes.h" @@ -41,6 +42,26 @@ enum Action { }; namespace cmdline { +class OptionalSizeTParser : public llvm::cl::parser> { +public: + OptionalSizeTParser(llvm::cl::Option &option) + : llvm::cl::parser>(option) {} + + bool parse(llvm::cl::Option &option, llvm::StringRef argName, + llvm::StringRef arg, llvm::Optional &value) { + size_t parsedVal; + std::istringstream iss(arg.str()); + + iss >> parsedVal; + + if (iss.fail()) + return option.error("Invalid value " + arg); + + value.emplace(parsedVal); + + return false; + } +}; llvm::cl::list inputs(llvm::cl::Positional, llvm::cl::desc(""), @@ -126,6 +147,17 @@ llvm::cl::list jitArgs("jit-args", llvm::cl::desc("Value of arguments to pass to the main func"), llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore); + +llvm::cl::opt, false, OptionalSizeTParser> + assumeMaxEintPrecision( + "assume-max-eint-precision", + llvm::cl::desc("Assume a maximum precision for encrypted integers")); + +llvm::cl::opt, false, OptionalSizeTParser> assumeMaxMANP( + "assume-max-manp", + llvm::cl::desc( + "Assume a maximum for the Minimum Arithmetic Noise Padding")); + }; // namespace cmdline std::function defaultOptPipeline = @@ -171,6 +203,64 @@ generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, return std::move(maybeKeySet.get()); } +llvm::Expected buildFHEContext( + llvm::Optional autoFHEConstraints, + llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP) { + if (!autoFHEConstraints.hasValue() && + (!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) { + return llvm::make_error( + "Maximum encrypted integer precision and maximum for the Minimal" + "Arithmetic Noise Passing are required, but were neither specified" + "explicitly nor determined automatically", + llvm::inconvertibleErrorCode()); + } + + mlir::zamalang::V0FHEConstraint fheConstraints{ + .norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue() + : autoFHEConstraints.getValue().norm2, + .p = overrideMaxEintPrecision.hasValue() + ? overrideMaxEintPrecision.getValue() + : autoFHEConstraints.getValue().p}; + + const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints); + + if (!parameter) { + std::string buffer; + llvm::raw_string_ostream strs(buffer); + strs << "Could not determine V0 parameters for 2-norm of " + << fheConstraints.norm2 << " and p of " << fheConstraints.p; + + return llvm::make_error(strs.str(), + llvm::inconvertibleErrorCode()); + } + + return mlir::zamalang::V0FHEContext{fheConstraints, *parameter}; +} + +mlir::LogicalResult buildAssignFHEContext( + llvm::Optional &fheContext, + llvm::Optional autoFHEConstraints, + llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP) { + + if (fheContext.hasValue()) + return mlir::success(); + + llvm::Expected fheContextOrErr = + buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision, + overrideMaxMANP); + + if (auto err = fheContextOrErr.takeError()) { + mlir::zamalang::log_error() << err; + return mlir::failure(); + } + + fheContext.emplace(fheContextOrErr.get()); + + return mlir::success(); +} + // Process a single source buffer // // The parameter `entryDialect` must specify the FHE dialect to which @@ -190,6 +280,12 @@ generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext, // `entryDialect` and `action` does not involve any MidlFHE // manipulation, this parameter does not have any effect. // +// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if +// set, override the values for the maximum required precision of +// encrypted integers and the maximum value for the Minimum Arithmetic +// Noise Padding otherwise determined automatically if the entry +// dialect is HLFHE.. +// // If `verifyDiagnostics` is `true`, the procedure only checks if the // diagnostic messages provided in the source buffer using // `expected-error` are produced. If `verifyDiagnostics` is `false`, @@ -204,8 +300,9 @@ mlir::LogicalResult processInputBuffer( mlir::MLIRContext &context, std::unique_ptr buffer, enum EntryDialect entryDialect, enum Action action, const std::string &jitFuncName, llvm::ArrayRef jitArgs, - bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose, - llvm::raw_ostream &os) { + bool parametrizeMidlHFE, llvm::Optional overrideMaxEintPrecision, + llvm::Optional overrideMaxMANP, bool verifyDiagnostics, + bool verbose, llvm::raw_ostream &os) { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc()); @@ -213,28 +310,11 @@ mlir::LogicalResult processInputBuffer( &context); mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context); - // This is temporary until we have the high-level verification pass - // determining these parameters automatically - mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10, - .p = 7}; + llvm::Optional fheConstraints; + llvm::Optional fheContext; std::unique_ptr keySet = nullptr; - const mlir::zamalang::V0Parameter *parameter = - getV0Parameter(defaultGlobalFHECircuitConstraint); - - if (!parameter) { - mlir::zamalang::log_error() - << "Could not determine V0 parameters for 2-norm of " - << defaultGlobalFHECircuitConstraint.norm2 << " and p of " - << defaultGlobalFHECircuitConstraint.p << "\n"; - - return mlir::failure(); - } - - mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint, - *parameter}; - if (verbose) context.disableMultithreading(); @@ -258,14 +338,25 @@ mlir::LogicalResult processInputBuffer( // points from the pipeline. switch (entryDialect) { case EntryDialect::HLFHE: - if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) - .failed()) { - return mlir::failure(); - } - if (action == Action::DUMP_HLFHE_MANP) { + if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false) + .failed()) { + return mlir::failure(); + } + module.print(os); return mlir::success(); + } else { + llvm::Expected> + fheConstraintsOrErr = + mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context, + module); + if (auto err = fheConstraintsOrErr.takeError()) { + mlir::zamalang::log_error() << err; + return mlir::failure(); + } else { + fheConstraints = fheConstraintsOrErr.get(); + } } if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose) @@ -279,8 +370,14 @@ mlir::LogicalResult processInputBuffer( return mlir::success(); } + if (buildAssignFHEContext(fheContext, fheConstraints, + overrideMaxEintPrecision, overrideMaxMANP) + .failed()) { + return mlir::failure(); + } + if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE( - context, module, fheContext, parametrizeMidlHFE) + context, module, fheContext.getValue(), parametrizeMidlHFE) .failed()) return mlir::failure(); @@ -300,7 +397,13 @@ mlir::LogicalResult processInputBuffer( module.print(os); return mlir::success(); } else if (action == Action::JIT_INVOKE) { - keySet = generateKeySet(module, fheContext, jitFuncName); + if (buildAssignFHEContext(fheContext, fheConstraints, + overrideMaxEintPrecision, overrideMaxMANP) + .failed()) { + return mlir::failure(); + } + + keySet = generateKeySet(module, fheContext.getValue(), jitFuncName); } if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module, @@ -422,8 +525,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return processInputBuffer( context, std::move(inputBuffer), cmdline::entryDialect, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, - cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics, - cmdline::verbose, os); + cmdline::parametrizeMidLFHE, + cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, + cmdline::verifyDiagnostics, cmdline::verbose, os); }, output->os()))) return mlir::failure(); @@ -431,6 +535,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { return processInputBuffer( context, std::move(file), cmdline::entryDialect, cmdline::action, cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE, + cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP, cmdline::verifyDiagnostics, cmdline::verbose, output->os()); } } diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir index 7814e2064..a3ec7b838 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7> func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir index ee1d3e53a..42b77bcf4 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/add_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir index 926c788ce..468b02001 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4> func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir index adc811bef..e0d691a8f 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/apply_lookup_table_cst.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4> func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir index c13c73353..24fc2ff96 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/mul_glwe_int.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir index a40db4b79..1359aaa1e 100644 --- a/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir +++ b/compiler/tests/Conversion/MidLFHEToLowLFHE/sub_int_glwe.mlir @@ -1,4 +1,4 @@ -// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s +// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s // CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7> func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> { diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index f7791facc..e49516602 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -2,6 +2,8 @@ #include "zamalang/Support/CompilerEngine.h" +mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7}; + #define ASSERT_LLVM_ERROR(err) \ if (err) { \ llvm::errs() << "error: " << std::move(err) << "\n"; \ @@ -31,7 +33,7 @@ func @main(%t: tensor<10xi64>, %i: index) -> i64{ return %c : i64 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF, 0, @@ -68,7 +70,7 @@ func @main(%t: tensor<10xi32>, %i: index) -> i32{ return %c : i32 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90, 197864, 698735, 72132, 87474, 42}; @@ -97,7 +99,7 @@ func @main(%t: tensor<10xi16>, %i: index) -> i16{ return %c : i16 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227, 63269, 36435, 52380, 7401, 13313}; @@ -126,7 +128,7 @@ func @main(%t: tensor<10xi8>, %i: index) -> i8{ return %c : i8 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93}; for (size_t i = 0; i < size; i++) { @@ -154,7 +156,7 @@ func @main(%t: tensor<10xi5>, %i: index) -> i5{ return %c : i5 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; for (size_t i = 0; i < size; i++) { @@ -182,7 +184,7 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{ return %c : i1 } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0}; for (size_t i = 0; i < size; i++) { @@ -210,7 +212,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{ return %c : !HLFHE.eint<5> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; for (size_t i = 0; i < size; i++) { @@ -240,7 +242,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5 return %c : !HLFHE.eint<5> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; for (size_t i = 0; i < size; i++) { @@ -273,7 +275,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{ return %c : index } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); const size_t size = 10; uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7}; auto maybeArgument = engine.buildArgument(); @@ -297,7 +299,7 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> { return %t: tensor<1x!HLFHE.eint<5>> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); auto maybeArgument = engine.buildArgument(); ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); @@ -327,7 +329,7 @@ func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> { return %out: tensor<3x!HLFHE.eint<5>> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); auto maybeArgument = engine.buildArgument(); ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); @@ -364,7 +366,7 @@ func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.ei return %ret : !HLFHE.eint<7> } )XXX"; - ASSERT_LLVM_ERROR(engine.compile(mlirStr)); + ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints)); auto maybeArgument = engine.buildArgument(); ASSERT_LLVM_ERROR(maybeArgument.takeError()); auto argument = std::move(maybeArgument.get()); @@ -459,4 +461,4 @@ func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> { ASSERT_TRUE((bool)maybeResult); result = maybeResult.get(); ASSERT_EQ(result, 6); -} \ No newline at end of file +}