feat(compiler): Determine FHE circuit constraints instead of using default values

This replaces the default FHE circuit constrains (maximum encrypted
integer width of 7 bits and a Minimal Arithmetic Noise Padding of 10
with the results of the `MaxMANP` pass, which determines these values
automatically from the input program.

Since the maximum encrypted integer width and the maximum value for
the Minimal Arithmetic Noise Padding can only be derived from HLFHE
operations, the circuit constraints are determined automatically by
`zamacompiler` only if the option `--entry-dialect=hlfhe` was
specified.

For lower-level dialects, `zamacompiler` has been provided with the
options `--assume-max-eint-precision=...` and `--assume-max-manp=...`
that allow a user to specify the values for the maximum required
precision and maximum values for the Minimal Arithmetic Noise Padding.
This commit is contained in:
Andi Drebes
2021-09-24 23:43:41 +02:00
committed by Quentin Bourgerie
parent 2ed1720234
commit 2acfa63eb7
12 changed files with 250 additions and 59 deletions

View File

@@ -20,7 +20,9 @@ public:
}
// Compile an mlir programs from it's textual representation.
llvm::Error compile(std::string mlirStr);
llvm::Error compile(
std::string mlirStr,
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints = {});
// Build the jit lambda argument.
llvm::Expected<std::unique_ptr<JITLambda::Argument>> buildArgument();

View File

@@ -9,9 +9,13 @@
namespace mlir {
namespace zamalang {
namespace pipeline {
mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool debug);
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module);
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose);

View File

@@ -31,23 +31,55 @@ std::string CompilerEngine::getCompiledModule() {
return os.str();
}
llvm::Error CompilerEngine::compile(std::string mlirStr) {
llvm::Error CompilerEngine::compile(
std::string mlirStr,
llvm::Optional<mlir::zamalang::V0FHEConstraint> overrideConstraints) {
module_ref = mlir::parseSourceString(mlirStr, context);
if (!module_ref) {
return llvm::make_error<llvm::StringError>("mlir parsing failed",
llvm::inconvertibleErrorCode());
}
mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10,
.p = 7};
const mlir::zamalang::V0Parameter *parameter =
getV0Parameter(defaultGlobalFHECircuitConstraint);
mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint,
*parameter};
mlir::ModuleOp module = module_ref.get();
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraintsOpt =
overrideConstraints;
if (!fheConstraintsOpt.hasValue()) {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(*context,
module);
if (auto err = fheConstraintsOrErr.takeError())
return std::move(err);
if (!fheConstraintsOrErr.get().hasValue()) {
return llvm::make_error<llvm::StringError>(
"Could not determine maximum required precision for encrypted "
"integers "
"and maximum value for the Minimal Arithmetic Noise Padding",
llvm::inconvertibleErrorCode());
}
fheConstraintsOpt = fheConstraintsOrErr.get();
}
mlir::zamalang::V0FHEConstraint fheConstraints = fheConstraintsOpt.getValue();
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
if (!parameter) {
std::string buffer;
llvm::raw_string_ostream strs(buffer);
strs << "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
return llvm::make_error<llvm::StringError>(strs.str(),
llvm::inconvertibleErrorCode());
}
mlir::zamalang::V0FHEContext fheContext{fheConstraints, *parameter};
// Lower to MLIR Std
if (mlir::zamalang::pipeline::lowerHLFHEToStd(*context, module, fheContext,
false)

View File

@@ -13,6 +13,7 @@
#include <zamalang/Dialect/HLFHE/Analysis/MANP.h>
#include <zamalang/Support/Pipeline.h>
#include <zamalang/Support/logging.h>
#include <zamalang/Support/math.h>
namespace mlir {
namespace zamalang {
@@ -35,6 +36,51 @@ mlir::LogicalResult invokeMANPPass(mlir::MLIRContext &context,
return pm.run(module);
}
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
getFHEConstraintsFromHLFHE(mlir::MLIRContext &context, mlir::ModuleOp &module) {
llvm::Optional<size_t> oMax2norm;
llvm::Optional<size_t> oMaxWidth;
mlir::PassManager pm(&context);
addPotentiallyNestedPass(pm, mlir::zamalang::createMANPPass());
addPotentiallyNestedPass(
pm, mlir::zamalang::createMaxMANPPass([&](const llvm::APInt &currMaxMANP,
unsigned currMaxWidth) {
assert((uint64_t)currMaxWidth < std::numeric_limits<size_t>::max() &&
"Maximum width does not fit into size_t");
assert(sizeof(uint64_t) >= sizeof(size_t) &&
currMaxMANP.ult(std::numeric_limits<size_t>::max()) &&
"Maximum MANP does not fit into size_t");
size_t manp = (size_t)currMaxMANP.getZExtValue();
size_t width = (size_t)currMaxWidth;
if (!oMax2norm.hasValue() || oMax2norm.getValue() < manp)
oMax2norm.emplace(manp);
if (!oMaxWidth.hasValue() || oMaxWidth.getValue() < width)
oMaxWidth.emplace(width);
}));
if (pm.run(module.getOperation()).failed()) {
return llvm::make_error<llvm::StringError>(
"Failed to determine the maximum Arithmetic Noise Padding and maximum"
"required precision",
llvm::inconvertibleErrorCode());
}
llvm::Optional<mlir::zamalang::V0FHEConstraint> ret;
if (oMax2norm.hasValue() && oMaxWidth.hasValue()) {
ret = llvm::Optional<mlir::zamalang::V0FHEConstraint>(
{.norm2 = ceilLog2(oMax2norm.getValue()), .p = oMaxWidth.getValue()});
}
return ret;
}
mlir::LogicalResult lowerHLFHEToMidLFHE(mlir::MLIRContext &context,
mlir::ModuleOp &module, bool verbose) {
mlir::PassManager pm(&context);

View File

@@ -11,6 +11,7 @@
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include <sstream>
#include "mlir/IR/BuiltinOps.h"
#include "zamalang/Conversion/Passes.h"
@@ -41,6 +42,26 @@ enum Action {
};
namespace cmdline {
class OptionalSizeTParser : public llvm::cl::parser<llvm::Optional<size_t>> {
public:
OptionalSizeTParser(llvm::cl::Option &option)
: llvm::cl::parser<llvm::Optional<size_t>>(option) {}
bool parse(llvm::cl::Option &option, llvm::StringRef argName,
llvm::StringRef arg, llvm::Optional<size_t> &value) {
size_t parsedVal;
std::istringstream iss(arg.str());
iss >> parsedVal;
if (iss.fail())
return option.error("Invalid value " + arg);
value.emplace(parsedVal);
return false;
}
};
llvm::cl::list<std::string> inputs(llvm::cl::Positional,
llvm::cl::desc("<Input files>"),
@@ -126,6 +147,17 @@ llvm::cl::list<uint64_t>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("argument(uint64)"), llvm::cl::ZeroOrMore);
llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser>
assumeMaxEintPrecision(
"assume-max-eint-precision",
llvm::cl::desc("Assume a maximum precision for encrypted integers"));
llvm::cl::opt<llvm::Optional<size_t>, false, OptionalSizeTParser> assumeMaxMANP(
"assume-max-manp",
llvm::cl::desc(
"Assume a maximum for the Minimum Arithmetic Noise Padding"));
}; // namespace cmdline
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
@@ -171,6 +203,64 @@ generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
return std::move(maybeKeySet.get());
}
llvm::Expected<mlir::zamalang::V0FHEContext> buildFHEContext(
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP) {
if (!autoFHEConstraints.hasValue() &&
(!overrideMaxMANP.hasValue() || !overrideMaxEintPrecision.hasValue())) {
return llvm::make_error<llvm::StringError>(
"Maximum encrypted integer precision and maximum for the Minimal"
"Arithmetic Noise Passing are required, but were neither specified"
"explicitly nor determined automatically",
llvm::inconvertibleErrorCode());
}
mlir::zamalang::V0FHEConstraint fheConstraints{
.norm2 = overrideMaxMANP.hasValue() ? overrideMaxMANP.getValue()
: autoFHEConstraints.getValue().norm2,
.p = overrideMaxEintPrecision.hasValue()
? overrideMaxEintPrecision.getValue()
: autoFHEConstraints.getValue().p};
const mlir::zamalang::V0Parameter *parameter = getV0Parameter(fheConstraints);
if (!parameter) {
std::string buffer;
llvm::raw_string_ostream strs(buffer);
strs << "Could not determine V0 parameters for 2-norm of "
<< fheConstraints.norm2 << " and p of " << fheConstraints.p;
return llvm::make_error<llvm::StringError>(strs.str(),
llvm::inconvertibleErrorCode());
}
return mlir::zamalang::V0FHEContext{fheConstraints, *parameter};
}
mlir::LogicalResult buildAssignFHEContext(
llvm::Optional<mlir::zamalang::V0FHEContext> &fheContext,
llvm::Optional<mlir::zamalang::V0FHEConstraint> autoFHEConstraints,
llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP) {
if (fheContext.hasValue())
return mlir::success();
llvm::Expected<mlir::zamalang::V0FHEContext> fheContextOrErr =
buildFHEContext(autoFHEConstraints, overrideMaxEintPrecision,
overrideMaxMANP);
if (auto err = fheContextOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
}
fheContext.emplace(fheContextOrErr.get());
return mlir::success();
}
// Process a single source buffer
//
// The parameter `entryDialect` must specify the FHE dialect to which
@@ -190,6 +280,12 @@ generateKeySet(mlir::ModuleOp &module, mlir::zamalang::V0FHEContext &fheContext,
// `entryDialect` and `action` does not involve any MidlFHE
// manipulation, this parameter does not have any effect.
//
// The parameters `overrideMaxEintPrecision` and `overrideMaxMANP`, if
// set, override the values for the maximum required precision of
// encrypted integers and the maximum value for the Minimum Arithmetic
// Noise Padding otherwise determined automatically if the entry
// dialect is HLFHE..
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
// diagnostic messages provided in the source buffer using
// `expected-error` are produced. If `verifyDiagnostics` is `false`,
@@ -204,8 +300,9 @@ mlir::LogicalResult processInputBuffer(
mlir::MLIRContext &context, std::unique_ptr<llvm::MemoryBuffer> buffer,
enum EntryDialect entryDialect, enum Action action,
const std::string &jitFuncName, llvm::ArrayRef<uint64_t> jitArgs,
bool parametrizeMidlHFE, bool verifyDiagnostics, bool verbose,
llvm::raw_ostream &os) {
bool parametrizeMidlHFE, llvm::Optional<size_t> overrideMaxEintPrecision,
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
bool verbose, llvm::raw_ostream &os) {
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
@@ -213,28 +310,11 @@ mlir::LogicalResult processInputBuffer(
&context);
mlir::OwningModuleRef moduleRef = mlir::parseSourceFile(sourceMgr, &context);
// This is temporary until we have the high-level verification pass
// determining these parameters automatically
mlir::zamalang::V0FHEConstraint defaultGlobalFHECircuitConstraint{.norm2 = 10,
.p = 7};
llvm::Optional<mlir::zamalang::V0FHEConstraint> fheConstraints;
llvm::Optional<mlir::zamalang::V0FHEContext> fheContext;
std::unique_ptr<mlir::zamalang::KeySet> keySet = nullptr;
const mlir::zamalang::V0Parameter *parameter =
getV0Parameter(defaultGlobalFHECircuitConstraint);
if (!parameter) {
mlir::zamalang::log_error()
<< "Could not determine V0 parameters for 2-norm of "
<< defaultGlobalFHECircuitConstraint.norm2 << " and p of "
<< defaultGlobalFHECircuitConstraint.p << "\n";
return mlir::failure();
}
mlir::zamalang::V0FHEContext fheContext{defaultGlobalFHECircuitConstraint,
*parameter};
if (verbose)
context.disableMultithreading();
@@ -258,14 +338,25 @@ mlir::LogicalResult processInputBuffer(
// points from the pipeline.
switch (entryDialect) {
case EntryDialect::HLFHE:
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
.failed()) {
return mlir::failure();
}
if (action == Action::DUMP_HLFHE_MANP) {
if (mlir::zamalang::pipeline::invokeMANPPass(context, module, false)
.failed()) {
return mlir::failure();
}
module.print(os);
return mlir::success();
} else {
llvm::Expected<llvm::Optional<mlir::zamalang::V0FHEConstraint>>
fheConstraintsOrErr =
mlir::zamalang::pipeline::getFHEConstraintsFromHLFHE(context,
module);
if (auto err = fheConstraintsOrErr.takeError()) {
mlir::zamalang::log_error() << err;
return mlir::failure();
} else {
fheConstraints = fheConstraintsOrErr.get();
}
}
if (mlir::zamalang::pipeline::lowerHLFHEToMidLFHE(context, module, verbose)
@@ -279,8 +370,14 @@ mlir::LogicalResult processInputBuffer(
return mlir::success();
}
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}
if (mlir::zamalang::pipeline::lowerMidLFHEToLowLFHE(
context, module, fheContext, parametrizeMidlHFE)
context, module, fheContext.getValue(), parametrizeMidlHFE)
.failed())
return mlir::failure();
@@ -300,7 +397,13 @@ mlir::LogicalResult processInputBuffer(
module.print(os);
return mlir::success();
} else if (action == Action::JIT_INVOKE) {
keySet = generateKeySet(module, fheContext, jitFuncName);
if (buildAssignFHEContext(fheContext, fheConstraints,
overrideMaxEintPrecision, overrideMaxMANP)
.failed()) {
return mlir::failure();
}
keySet = generateKeySet(module, fheContext.getValue(), jitFuncName);
}
if (mlir::zamalang::pipeline::lowerStdToLLVMDialect(context, module,
@@ -422,8 +525,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
return processInputBuffer(
context, std::move(inputBuffer), cmdline::entryDialect,
cmdline::action, cmdline::jitFuncName, cmdline::jitArgs,
cmdline::parametrizeMidLFHE, cmdline::verifyDiagnostics,
cmdline::verbose, os);
cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, os);
},
output->os())))
return mlir::failure();
@@ -431,6 +535,7 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
return processInputBuffer(
context, std::move(file), cmdline::entryDialect, cmdline::action,
cmdline::jitFuncName, cmdline::jitArgs, cmdline::parametrizeMidLFHE,
cmdline::assumeMaxEintPrecision, cmdline::assumeMaxMANP,
cmdline::verifyDiagnostics, cmdline::verbose, output->os());
}
}

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe(%arg0: !LowLFHE.lwe_ciphertext<2048,7>, %arg1: !LowLFHE.lwe_ciphertext<2048,7>) -> !LowLFHE.lwe_ciphertext<2048,7>
func @add_glwe(%arg0: !MidLFHE.glwe<{2048,1,64}{7}>, %arg1: !MidLFHE.glwe<{2048,1,64}{7}>) -> !MidLFHE.glwe<{2048,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @add_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @add_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table(%arg0: !LowLFHE.lwe_ciphertext<1024,4>, %arg1: tensor<16xi4>) -> !LowLFHE.lwe_ciphertext<1024,4>
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,1,64}{4}>, %arg1: tensor<16xi4>) -> !MidLFHE.glwe<{1024,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @apply_lookup_table_cst(%arg0: !LowLFHE.lwe_ciphertext<2048,4>) -> !LowLFHE.lwe_ciphertext<2048,4>
func @apply_lookup_table_cst(%arg0: !MidLFHE.glwe<{2048,1,64}{4}>) -> !MidLFHE.glwe<{2048,1,64}{4}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @mul_glwe_const_int(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @mul_glwe_const_int(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -1,4 +1,4 @@
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false %s 2>&1| FileCheck %s
// RUN: zamacompiler --entry-dialect=midlfhe --action=dump-lowlfhe --parametrize-midlfhe=false --assume-max-eint-precision=7 --assume-max-manp=10 %s 2>&1| FileCheck %s
// CHECK-LABEL: func @sub_const_int_glwe(%arg0: !LowLFHE.lwe_ciphertext<1024,7>) -> !LowLFHE.lwe_ciphertext<1024,7>
func @sub_const_int_glwe(%arg0: !MidLFHE.glwe<{1024,1,64}{7}>) -> !MidLFHE.glwe<{1024,1,64}{7}> {

View File

@@ -2,6 +2,8 @@
#include "zamalang/Support/CompilerEngine.h"
mlir::zamalang::V0FHEConstraint defaultV0Constraints = {.norm2 = 10, .p = 7};
#define ASSERT_LLVM_ERROR(err) \
if (err) { \
llvm::errs() << "error: " << std::move(err) << "\n"; \
@@ -31,7 +33,7 @@ func @main(%t: tensor<10xi64>, %i: index) -> i64{
return %c : i64
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint64_t t_arg[size]{0xFFFFFFFFFFFFFFFF,
0,
@@ -68,7 +70,7 @@ func @main(%t: tensor<10xi32>, %i: index) -> i32{
return %c : i32
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint32_t t_arg[size]{0xFFFFFFFF, 0, 8978, 2587490, 90,
197864, 698735, 72132, 87474, 42};
@@ -97,7 +99,7 @@ func @main(%t: tensor<10xi16>, %i: index) -> i16{
return %c : i16
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint16_t t_arg[size]{0xFFFF, 0, 59589, 47826, 16227,
63269, 36435, 52380, 7401, 13313};
@@ -126,7 +128,7 @@ func @main(%t: tensor<10xi8>, %i: index) -> i8{
return %c : i8
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{0xFF, 0, 120, 225, 14, 177, 131, 84, 174, 93};
for (size_t i = 0; i < size; i++) {
@@ -154,7 +156,7 @@ func @main(%t: tensor<10xi5>, %i: index) -> i5{
return %c : i5
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
@@ -182,7 +184,7 @@ func @main(%t: tensor<10xi1>, %i: index) -> i1{
return %c : i1
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{0, 0, 1, 0, 1, 1, 0, 1, 1, 0};
for (size_t i = 0; i < size; i++) {
@@ -210,7 +212,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index) -> !HLFHE.eint<5>{
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
@@ -240,7 +242,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>, %i: index, %j: index) -> !HLFHE.eint<5
return %c : !HLFHE.eint<5>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
for (size_t i = 0; i < size; i++) {
@@ -273,7 +275,7 @@ func @main(%t: tensor<10x!HLFHE.eint<5>>) -> index{
return %c : index
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
const size_t size = 10;
uint8_t t_arg[size]{32, 0, 10, 25, 14, 25, 18, 28, 14, 7};
auto maybeArgument = engine.buildArgument();
@@ -297,7 +299,7 @@ func @main(%0: !HLFHE.eint<5>) -> tensor<1x!HLFHE.eint<5>> {
return %t: tensor<1x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
@@ -327,7 +329,7 @@ func @main(%in: tensor<2x!HLFHE.eint<5>>) -> tensor<3x!HLFHE.eint<5>> {
return %out: tensor<3x!HLFHE.eint<5>>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
@@ -364,7 +366,7 @@ func @main(%arg0: tensor<2x!HLFHE.eint<7>>, %arg1: tensor<2xi8>, %acc: !HLFHE.ei
return %ret : !HLFHE.eint<7>
}
)XXX";
ASSERT_LLVM_ERROR(engine.compile(mlirStr));
ASSERT_LLVM_ERROR(engine.compile(mlirStr, defaultV0Constraints));
auto maybeArgument = engine.buildArgument();
ASSERT_LLVM_ERROR(maybeArgument.takeError());
auto argument = std::move(maybeArgument.get());
@@ -459,4 +461,4 @@ func @main(%arg0: !HLFHE.eint<7>) -> !HLFHE.eint<7> {
ASSERT_TRUE((bool)maybeResult);
result = maybeResult.get();
ASSERT_EQ(result, 6);
}
}