feat(compiler): Add support of HLFHELinalg binary operators in MANP pass (close #172)

This commit is contained in:
Quentin Bourgerie
2021-10-26 21:49:12 +02:00
committed by Andi Drebes
parent be92b4580d
commit 2900c9a2a1
3 changed files with 310 additions and 63 deletions

View File

@@ -3,6 +3,7 @@
#include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
#include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h>
#include <zamalang/Support/math.h>
#include <limits>
@@ -422,6 +423,161 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUMul(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `HLFHELinalg.add_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHELinalg::AddEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::RankedTensorType op1Ty =
op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
mlir::Type iTy = op1Ty.getElementType();
assert(iTy.isSignlessInteger() &&
"Only additions with signless integers are currently allowed");
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt sqNorm;
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
mlir::DenseIntElementsAttr denseVals =
cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
: nullptr;
if (denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
sqNorm = APIntWidthExtendUSq(maxCst);
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
sqNorm = conservativeIntNorm2Sq(iTy);
}
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHELinalg::AddEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operands");
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
return APIntWidthExtendUAdd(a, b);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHELinalg.sub_int_eint` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHELinalg::SubIntEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::RankedTensorType op0Ty =
op->getOpOperand(0).get().getType().cast<mlir::RankedTensorType>();
mlir::Type iTy = op0Ty.getElementType();
assert(iTy.isSignlessInteger() &&
"Only subtractions with signless integers are currently allowed");
assert(
operandMANPs.size() == 2 &&
operandMANPs[1]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
llvm::APInt sqNorm;
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(0).get().getDefiningOp());
mlir::DenseIntElementsAttr denseVals =
cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
: nullptr;
if (denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
sqNorm = APIntWidthExtendUSq(maxCst);
} else {
// For dynamic plaintext operands conservatively assume that the integer has
// its maximum possible value
sqNorm = conservativeIntNorm2Sq(iTy);
}
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHELinalg::MulEintIntOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::RankedTensorType op0Ty =
op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
mlir::Type iTy = op0Ty.getElementType();
assert(iTy.isSignlessInteger() &&
"Only multiplications with signless integers are currently allowed");
assert(
operandMANPs.size() == 2 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
llvm::APInt sqNorm;
mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
mlir::DenseIntElementsAttr denseVals =
cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
: nullptr;
if (denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
sqNorm = APIntWidthExtendUSq(maxCst);
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
sqNorm = conservativeIntNorm2Sq(iTy);
}
return APIntWidthExtendUMul(sqNorm, eNorm);
}
static llvm::APInt getSqMANP(
mlir::tensor::ExtractOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
@@ -508,6 +664,23 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
norm2SqEquiv = llvm::APInt{1, 1, false};
}
// HLFHELinalg Operators
else if (auto addEintIntOp =
llvm::dyn_cast<mlir::zamalang::HLFHELinalg::AddEintIntOp>(
op)) {
norm2SqEquiv = getSqMANP(addEintIntOp, operands);
} else if (auto addEintOp =
llvm::dyn_cast<mlir::zamalang::HLFHELinalg::AddEintOp>(op)) {
norm2SqEquiv = getSqMANP(addEintOp, operands);
} else if (auto subIntEintOp =
llvm::dyn_cast<mlir::zamalang::HLFHELinalg::SubIntEintOp>(
op)) {
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
} else if (auto mulEintIntOp =
llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MulEintIntOp>(
op)) {
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
}
// Tensor Operators
// ExtractOp
else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {

View File

@@ -0,0 +1,94 @@
// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s
func @single_cst_add_eint_int(%t: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
{
%cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
// CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @single_dyn_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.add_eint_int"(%e, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @single_add_eint(%e0: tensor<8x!HLFHE.eint<2>>, %e1: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.add_eint"(%e0, %e1) : (tensor<8x!HLFHE.eint<2>>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @single_cst_sub_int_eint(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
{
%cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
// CHECK: %[[ret:.*]] = "HLFHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @single_dyn_sub_int_eint(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "HLFHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @single_cst_mul_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
{
%cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
// %0 = "HLFHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @single_dyn_mul_eint_int(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "HLFHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
return %0 : tensor<8x!HLFHE.eint<2>>
}
// -----
func @chain_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
{
%cst0 = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
%cst1 = std.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi3>
%cst2 = std.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi3>
%cst3 = std.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi3>
// CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%0 = "HLFHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
// CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%1 = "HLFHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
// CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%2 = "HLFHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
// CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
%3 = "HLFHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
return %3 : tensor<8x!HLFHE.eint<2>>
}

View File

@@ -11,12 +11,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
return %res : tensor<4x!HLFHE.eint<4>>
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
return %res : tensor<4x!HLFHE.eint<6>>
}
)XXX",
"main", true);
)XXX");
std::vector<uint8_t> a0{31, 6, 12, 9};
std::vector<uint8_t> a1{32, 9, 2, 3};
@@ -43,12 +42,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_broadcast) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
return %res : tensor<4x4x4x!HLFHE.eint<4>>
func @main(%a0: tensor<4x1x4x!HLFHE.eint<5>>, %a1: tensor<1x4x4xi6>) -> tensor<4x4x4x!HLFHE.eint<5>> {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<5>>, tensor<1x4x4xi6>) -> tensor<4x4x4x!HLFHE.eint<5>>
return %res : tensor<4x4x4x!HLFHE.eint<5>>
}
)XXX",
"main", true);
)XXX");
uint8_t a0[4][1][4]{
{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
@@ -102,8 +100,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_column) {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -149,8 +146,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_line) {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -188,8 +184,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_line_missing_dim) {
%res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -228,12 +223,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
return %res : tensor<4x!HLFHE.eint<4>>
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4x!HLFHE.eint<6>>) -> tensor<4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4x!HLFHE.eint<6>>) -> tensor<4x!HLFHE.eint<6>>
return %res : tensor<4x!HLFHE.eint<6>>
}
)XXX",
"main", true);
)XXX");
std::vector<uint8_t> a0{31, 6, 12, 9};
std::vector<uint8_t> a1{32, 9, 2, 3};
@@ -263,14 +257,13 @@ TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term_broadcast) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term addition of `%a0` with `%a1`
func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1:
tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
func @main(%a0: tensor<4x1x4x!HLFHE.eint<5>>, %a1:
tensor<1x4x4x!HLFHE.eint<5>>) -> tensor<4x4x4x!HLFHE.eint<5>> {
%res = "HLFHELinalg.add_eint"(%a0, %a1) :
(tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) ->
tensor<4x4x4x!HLFHE.eint<4>> return %res : tensor<4x4x4x!HLFHE.eint<4>>
(tensor<4x1x4x!HLFHE.eint<5>>, tensor<1x4x4x!HLFHE.eint<5>>) ->
tensor<4x4x4x!HLFHE.eint<5>> return %res : tensor<4x4x4x!HLFHE.eint<5>>
}
)XXX",
"main", true);
)XXX");
uint8_t a0[4][1][4]{
{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
@@ -325,8 +318,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_column) {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -375,8 +367,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line) {
tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> return %res :
tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -415,8 +406,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
%res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -459,8 +449,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_term_to_term) {
%res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
return %res : tensor<4x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
std::vector<uint8_t> a0{32, 9, 12, 9};
std::vector<uint8_t> a1{31, 6, 2, 3};
@@ -487,12 +476,11 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_term_to_term_broadcast) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term substraction of `%a0` with `%a1`, where dimensions equals to one are stretched.
func @main(%a0: tensor<4x1x4xi5>, %a1: tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
return %res : tensor<4x4x4x!HLFHE.eint<4>>
func @main(%a0: tensor<4x1x4xi8>, %a1: tensor<1x4x4x!HLFHE.eint<7>>) -> tensor<4x4x4x!HLFHE.eint<7>> {
%res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi8>, tensor<1x4x4x!HLFHE.eint<7>>) -> tensor<4x4x4x!HLFHE.eint<7>>
return %res : tensor<4x4x4x!HLFHE.eint<7>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[4][1][4]{
{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
@@ -547,8 +535,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_column) {
tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> return %res :
tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -597,8 +584,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_line) {
tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> return %res :
tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -637,8 +623,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_line_missing_dim) {
%res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -677,12 +662,11 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_term_to_term) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term multiplication of `%a0` with `%a1`
func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
return %res : tensor<4x!HLFHE.eint<4>>
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
return %res : tensor<4x!HLFHE.eint<6>>
}
)XXX",
"main", true);
)XXX");
std::vector<uint8_t> a0{31, 6, 12, 9};
std::vector<uint8_t> a1{2, 3, 2, 3};
@@ -709,12 +693,11 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_term_to_term_broadcast) {
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equals to one are stretched.
func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
return %res : tensor<4x4x4x!HLFHE.eint<4>>
func @main(%a0: tensor<4x1x4x!HLFHE.eint<6>>, %a1: tensor<1x4x4xi7>) -> tensor<4x4x4x!HLFHE.eint<6>> {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<6>>, tensor<1x4x4xi7>) -> tensor<4x4x4x!HLFHE.eint<6>>
return %res : tensor<4x4x4x!HLFHE.eint<6>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[4][1][4]{
{{1, 2, 3, 4}},
{{5, 6, 7, 8}},
@@ -768,8 +751,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_column) {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -816,8 +798,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line) {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -856,8 +837,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
%res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
return %res : tensor<3x3x!HLFHE.eint<4>>
}
)XXX",
"main", true);
)XXX");
const uint8_t a0[3][3]{
{1, 2, 3},
{4, 5, 6},
@@ -886,4 +866,4 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
EXPECT_EQ((*res)[i * 3 + j], a0[i][j] * a1[0][j]);
}
}
}
}