From 2900c9a2a1076c2e6a454febca24745db34b8770 Mon Sep 17 00:00:00 2001
From: Quentin Bourgerie
Date: Tue, 26 Oct 2021 21:49:12 +0200
Subject: [PATCH] feat(compiler): Add support of HLFHELinalg binary operators
 in MANP pass (close #172)

---
 compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp  | 173 ++++++++++++++++++
 .../Dialect/HLFHE/Analysis/MANP_linalg.mlir   |  94 ++++++++++
 .../unittest/end_to_end_jit_hlfhelinalg.cc    | 106 +++++------
 3 files changed, 310 insertions(+), 63 deletions(-)
 create mode 100644 compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir

diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp
index 151adc859..2c071919c 100644
--- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -422,6 +423,161 @@ static llvm::APInt getSqMANP(
   return APIntWidthExtendUMul(sqNorm, eNorm);
 }
 
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.add_eint_int` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::AddEintIntOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType op1Ty =
+      op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = op1Ty.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only additions with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[0]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
+  llvm::APInt sqNorm;
+
+  mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
+      op->getOpOperand(1).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand, use the largest of the actual constants to
+    // calculate the 2-norm
+    llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
+    for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
+      llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
+      if (maxCst.ult(iCst)) {
+        maxCst = iCst;
+      }
+    }
+    sqNorm = APIntWidthExtendUSq(maxCst);
+  } else {
+    // For a dynamic operand, conservatively assume that the value is
+    // the maximum for the integer width
+    sqNorm = conservativeIntNorm2Sq(iTy);
+  }
+
+  return APIntWidthExtendUAdd(sqNorm, eNorm);
+}
+
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.add_eint` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::AddEintOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+  assert(operandMANPs.size() == 2 &&
+         operandMANPs[0]->getValue().getMANP().hasValue() &&
+         operandMANPs[1]->getValue().getMANP().hasValue() &&
+         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
+         "operands");
+
+  llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
+  llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
+
+  return APIntWidthExtendUAdd(a, b);
+}
+
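+// Note: squared 2-norms combine additively for additions, and the MANP
+// attribute reported on an operation is the ceiling of the square root of
+// the squared norm computed here. For example, adding a constant tensor
+// whose largest entry is 3 to an operand of norm 1 yields
+// ceil(sqrt(3^2 + 1^2)) = ceil(sqrt(10)) = 4, as checked in
+// MANP_linalg.mlir below.
+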
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.sub_int_eint` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::SubIntEintOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType op0Ty =
+      op->getOpOperand(0).get().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = op0Ty.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only subtractions with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[1]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
+  llvm::APInt sqNorm;
+
+  mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
+      op->getOpOperand(0).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand, use the largest of the actual constants to
+    // calculate the 2-norm
+    llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
+    for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
+      llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
+      if (maxCst.ult(iCst)) {
+        maxCst = iCst;
+      }
+    }
+    sqNorm = APIntWidthExtendUSq(maxCst);
+  } else {
+    // For a dynamic plaintext operand, conservatively assume that the integer
+    // has its maximum possible value
+    sqNorm = conservativeIntNorm2Sq(iTy);
+  }
+  return APIntWidthExtendUAdd(sqNorm, eNorm);
+}
+
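+// Note: the additive operations above combine squared norms with
+// APIntWidthExtendUAdd, whereas the multiplication below combines them with
+// APIntWidthExtendUMul; this matches the usual 2-norm algebra, where
+// independent terms add up in the squared norm while scaling by a cleartext
+// constant c multiplies the norm by |c|.
+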
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.mul_eint_int` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::MulEintIntOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType op1Ty =
+      op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = op1Ty.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only multiplications with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[0]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
+  llvm::APInt sqNorm;
+
+  mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
+      op->getOpOperand(1).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand, use the largest of the actual constants to
+    // calculate the 2-norm
+    llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
+    for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
+      llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
+      if (maxCst.ult(iCst)) {
+        maxCst = iCst;
+      }
+    }
+    sqNorm = APIntWidthExtendUSq(maxCst);
+  } else {
+    // For a dynamic operand, conservatively assume that the value is
+    // the maximum for the integer width
+    sqNorm = conservativeIntNorm2Sq(iTy);
+  }
+
+  return APIntWidthExtendUMul(sqNorm, eNorm);
+}
+
 static llvm::APInt getSqMANP(
     mlir::tensor::ExtractOp op,
     llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -508,6 +664,23 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
       llvm::isa(op)) {
     norm2SqEquiv = llvm::APInt{1, 1, false};
   }
+  // HLFHELinalg Operators
+  else if (auto addEintIntOp =
+               llvm::dyn_cast<mlir::zamalang::HLFHELinalg::AddEintIntOp>(
+                   op)) {
+    norm2SqEquiv = getSqMANP(addEintIntOp, operands);
+  } else if (auto addEintOp =
+                 llvm::dyn_cast<mlir::zamalang::HLFHELinalg::AddEintOp>(op)) {
+    norm2SqEquiv = getSqMANP(addEintOp, operands);
+  } else if (auto subIntEintOp =
+                 llvm::dyn_cast<mlir::zamalang::HLFHELinalg::SubIntEintOp>(
+                     op)) {
+    norm2SqEquiv = getSqMANP(subIntEintOp, operands);
+  } else if (auto mulEintIntOp =
+                 llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MulEintIntOp>(
+                     op)) {
+    norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
+  }
   // Tensor Operators
   // ExtractOp
   else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir
new file mode 100644
index 000000000..926041558
--- /dev/null
+++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir
@@ -0,0 +1,94 @@
+// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s
+
+func @single_cst_add_eint_int(%t: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  %cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
+
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_dyn_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+{
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint_int"(%e, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_add_eint(%e0: tensor<8x!HLFHE.eint<2>>, %e1: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint"(%e0, %e1) : (tensor<8x!HLFHE.eint<2>>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
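+// The MANP values checked in this file are consistent with
+// ceil(sqrt(squared 2-norm)): a constant operand contributes the square of
+// its largest entry, and a dynamic iN operand is bounded conservatively
+// (the checks match a bound of 2^N, e.g. ceil(sqrt(8^2 + 1^2)) = 9 for a
+// dynamic i3 operand).
+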
"HLFHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>> + + return %0 : tensor<8x!HLFHE.eint<2>> +} + +// ----- + +func @single_dyn_sub_int_eint(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> +{ + // CHECK: %[[ret:.*]] = "HLFHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>> + + return %0 : tensor<8x!HLFHE.eint<2>> +} + +// ----- + +func @single_cst_mul_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>> +{ + %cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> + + // %0 = "HLFHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + + return %0 : tensor<8x!HLFHE.eint<2>> +} + +// ----- + +func @single_dyn_mul_eint_int(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> +{ + // CHECK: %[[ret:.*]] = "HLFHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + + return %0 : tensor<8x!HLFHE.eint<2>> +} + +// ----- + +func @chain_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>> +{ + %cst0 = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> + %cst1 = std.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi3> + %cst2 = std.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi3> + %cst3 = std.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi3> + // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + // CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %1 = "HLFHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + // CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %2 = "HLFHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + // CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %3 = "HLFHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + return %3 : tensor<8x!HLFHE.eint<2>> +} diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index bf648425e..ab0eaed87 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -11,12 +11,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term) { 
   std::vector<uint8_t> a0{31, 6, 12, 9};
   std::vector<uint8_t> a1{32, 9, 2, 3};
@@ -43,12 +42,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_broadcast) {
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
-  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<5>>, %a1: tensor<1x4x4xi6>) -> tensor<4x4x4x!HLFHE.eint<5>> {
+  %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<5>>, tensor<1x4x4xi6>) -> tensor<4x4x4x!HLFHE.eint<5>>
+  return %res : tensor<4x4x4x!HLFHE.eint<5>>
 }
-)XXX",
-           "main", true);
+)XXX");
   uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
       {{5, 6, 7, 8}},
@@ -102,8 +100,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_column) {
   %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   uint8_t a0[3][3]{
       {1, 2, 3},
      {4, 5, 6},
@@ -149,8 +146,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_line) {
   %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -188,8 +184,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_line_missing_dim) {
   %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-  )XXX",
-             "main", true);
+  )XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -228,12 +223,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term) {
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
-  return %res : tensor<4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4x!HLFHE.eint<6>>) -> tensor<4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4x!HLFHE.eint<6>>) -> tensor<4x!HLFHE.eint<6>>
+  return %res : tensor<4x!HLFHE.eint<6>>
 }
-)XXX",
-           "main", true);
+)XXX");
   std::vector<uint8_t> a0{31, 6, 12, 9};
   std::vector<uint8_t> a1{32, 9, 2, 3};
@@ -263,14 +257,13 @@ TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term_broadcast) {
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1:
-tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<5>>, %a1:
+tensor<1x4x4x!HLFHE.eint<5>>) -> tensor<4x4x4x!HLFHE.eint<5>> {
   %res = "HLFHELinalg.add_eint"(%a0, %a1) :
-  (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) ->
-  tensor<4x4x4x!HLFHE.eint<4>> return %res : tensor<4x4x4x!HLFHE.eint<4>>
+  (tensor<4x1x4x!HLFHE.eint<5>>, tensor<1x4x4x!HLFHE.eint<5>>) ->
+  tensor<4x4x4x!HLFHE.eint<5>> return %res : tensor<4x4x4x!HLFHE.eint<5>>
 }
-)XXX",
-           "main", true);
+)XXX");
   uint8_t a0[4][1][4]{
      {{1, 2, 3, 4}},
      {{5, 6, 7, 8}},
@@ -325,8 +318,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_column) {
   %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -375,8 +367,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line) {
   tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -415,8 +406,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
   %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -459,8 +449,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_term_to_term) {
   %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
   return %res : tensor<4x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   std::vector<uint8_t> a0{32, 9, 12, 9};
   std::vector<uint8_t> a1{31, 6, 2, 3};
@@ -487,12 +476,11 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_term_to_term_broadcast) {
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
-func @main(%a0: tensor<4x1x4xi5>, %a1: tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
-  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x1x4xi8>, %a1: tensor<1x4x4x!HLFHE.eint<7>>) -> tensor<4x4x4x!HLFHE.eint<7>> {
+  %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi8>, tensor<1x4x4x!HLFHE.eint<7>>) -> tensor<4x4x4x!HLFHE.eint<7>>
+  return %res : tensor<4x4x4x!HLFHE.eint<7>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
       {{5, 6, 7, 8}},
@@ -547,8 +535,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_column) {
   tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -597,8 +584,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_line) {
   tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -637,8 +623,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_line_missing_dim) {
   %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -677,12 +662,11 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_term_to_term) {
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term multiplication of `%a0` with `%a1`
-func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
-  return %res : tensor<4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
+  return %res : tensor<4x!HLFHE.eint<6>>
 }
-)XXX",
-           "main", true);
+)XXX");
   std::vector<uint8_t> a0{31, 6, 12, 9};
   std::vector<uint8_t> a1{2, 3, 2, 3};
@@ -709,12 +693,11 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_term_to_term_broadcast) {
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equal to one are stretched.
-func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
-  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<6>>, %a1: tensor<1x4x4xi7>) -> tensor<4x4x4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<6>>, tensor<1x4x4xi7>) -> tensor<4x4x4x!HLFHE.eint<6>>
+  return %res : tensor<4x4x4x!HLFHE.eint<6>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
      {{5, 6, 7, 8}},
@@ -768,8 +751,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_column) {
   %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -816,8 +798,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line) {
   %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -856,8 +837,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
   %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-           "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},
@@ -886,4 +866,4 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
       EXPECT_EQ((*res)[i * 3 + j], a0[i][j] * a1[0][j]);
     }
   }
-}
\ No newline at end of file
+}