feat(compiler): Add support of HLFHELinalg binary operators in MANP pass (close #172)
Committed by Andi Drebes
Parent: be92b4580d
Commit: 2900c9a2a1

@@ -3,6 +3,7 @@
 #include <zamalang/Dialect/HLFHE/IR/HLFHEDialect.h>
 #include <zamalang/Dialect/HLFHE/IR/HLFHEOps.h>
 #include <zamalang/Dialect/HLFHE/IR/HLFHETypes.h>
+#include <zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.h>
 #include <zamalang/Support/math.h>
 
 #include <limits>

@@ -422,6 +423,161 @@ static llvm::APInt getSqMANP(
   return APIntWidthExtendUMul(sqNorm, eNorm);
 }
+
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.add_eint_int` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::AddEintIntOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType op1Ty =
+      op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = op1Ty.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only additions with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[0]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
+  llvm::APInt sqNorm;
+
+  mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
+      op->getOpOperand(1).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand, use the actual constant values to calculate
+    // the 2-norm
+    llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
+    for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
+      llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
+      if (maxCst.ult(iCst)) {
+        maxCst = iCst;
+      }
+    }
+    sqNorm = APIntWidthExtendUSq(maxCst);
+  } else {
+    // For a dynamic operand, conservatively assume that the value is
+    // the maximum for the integer width
+    sqNorm = conservativeIntNorm2Sq(iTy);
+  }
+
+  return APIntWidthExtendUAdd(sqNorm, eNorm);
+}
+
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.add_eint` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::AddEintOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+  assert(operandMANPs.size() == 2 &&
+         operandMANPs[0]->getValue().getMANP().hasValue() &&
+         operandMANPs[1]->getValue().getMANP().hasValue() &&
+         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
+         "operands");
+
+  llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
+  llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
+
+  return APIntWidthExtendUAdd(a, b);
+}
+
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.sub_int_eint` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::SubIntEintOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType op0Ty =
+      op->getOpOperand(0).get().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = op0Ty.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only subtractions with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[1]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt eNorm = operandMANPs[1]->getValue().getMANP().getValue();
+  llvm::APInt sqNorm;
+
+  mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
+      op->getOpOperand(0).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand, use the actual constant values to calculate
+    // the 2-norm
+    llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
+    for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
+      llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
+      if (maxCst.ult(iCst)) {
+        maxCst = iCst;
+      }
+    }
+    sqNorm = APIntWidthExtendUSq(maxCst);
+  } else {
+    // For a dynamic plaintext operand, conservatively assume that the
+    // integer has its maximum possible value
+    sqNorm = conservativeIntNorm2Sq(iTy);
+  }
+  return APIntWidthExtendUAdd(sqNorm, eNorm);
+}
+
+// Calculates the squared Minimal Arithmetic Noise Padding of an
+// `HLFHELinalg.mul_eint_int` operation.
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::MulEintIntOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType op1Ty =
+      op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = op1Ty.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only multiplications with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[0]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
+  llvm::APInt sqNorm;
+
+  mlir::ConstantOp cstOp = llvm::dyn_cast_or_null<mlir::ConstantOp>(
+      op->getOpOperand(1).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand, use the actual constant values to calculate
+    // the 2-norm
+    llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
+    for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
+      llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
+      if (maxCst.ult(iCst)) {
+        maxCst = iCst;
+      }
+    }
+    sqNorm = APIntWidthExtendUSq(maxCst);
+  } else {
+    // For a dynamic operand, conservatively assume that the value is
+    // the maximum for the integer width
+    sqNorm = conservativeIntNorm2Sq(iTy);
+  }
+
+  return APIntWidthExtendUMul(sqNorm, eNorm);
+}
 
 static llvm::APInt getSqMANP(
     mlir::tensor::ExtractOp op,
     llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
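
Taken together, the rules above compute the squared MANP of each result from
the squared MANP of the encrypted operand(s) and the squared 2-norm of the
plaintext operand. As a summary (a reading of the code above, with the
conservative dynamic-operand norm inferred from conservativeIntNorm2Sq and
the tests below):

    MANP^2(add_eint_int(e, i)) = MANP^2(e) + ||i||^2
    MANP^2(add_eint(a, b))     = MANP^2(a) + MANP^2(b)
    MANP^2(sub_int_eint(i, e)) = MANP^2(e) + ||i||^2
    MANP^2(mul_eint_int(e, i)) = MANP^2(e) * ||i||^2

where ||i||^2 is the square of the largest constant when `i` is a
compile-time constant tensor, and (2^w)^2 for a dynamic operand of signless
integer width w.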

@@ -508,6 +664,23 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                llvm::isa<mlir::zamalang::HLFHE::ApplyLookupTableEintOp>(op)) {
       norm2SqEquiv = llvm::APInt{1, 1, false};
     }
+    // HLFHELinalg Operators
+    else if (auto addEintIntOp =
+                 llvm::dyn_cast<mlir::zamalang::HLFHELinalg::AddEintIntOp>(
+                     op)) {
+      norm2SqEquiv = getSqMANP(addEintIntOp, operands);
+    } else if (auto addEintOp =
+                   llvm::dyn_cast<mlir::zamalang::HLFHELinalg::AddEintOp>(op)) {
+      norm2SqEquiv = getSqMANP(addEintOp, operands);
+    } else if (auto subIntEintOp =
+                   llvm::dyn_cast<mlir::zamalang::HLFHELinalg::SubIntEintOp>(
+                       op)) {
+      norm2SqEquiv = getSqMANP(subIntEintOp, operands);
+    } else if (auto mulEintIntOp =
+                   llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MulEintIntOp>(
+                       op)) {
+      norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
+    }
     // Tensor Operators
     // ExtractOp
     else if (auto extractOp = llvm::dyn_cast<mlir::tensor::ExtractOp>(op)) {
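
To make the dataflow concrete, the following standalone toy (not part of the
commit; all names in it are illustrative) forward-propagates squared MANP
values through a small expression tree using the same per-operator rules the
dispatch above selects:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct Node {
      enum Kind { Input, AddEintInt, AddEint, MulEintInt } kind;
      std::vector<const Node *> operands;
      uint64_t cstSqNorm = 0;      // squared 2-norm of the plaintext operand
      mutable uint64_t sqManp = 0; // analysis result, filled in by visit()
    };

    // Recursively computes the squared MANP of a node, mirroring the
    // getSqMANP() rules: additions accumulate squared norms, multiplication
    // scales by the squared norm of the plaintext operand.
    static uint64_t visit(const Node &n) {
      switch (n.kind) {
      case Node::Input:
        n.sqManp = 1; // fresh ciphertext
        break;
      case Node::AddEintInt:
        n.sqManp = visit(*n.operands[0]) + n.cstSqNorm;
        break;
      case Node::AddEint:
        n.sqManp = visit(*n.operands[0]) + visit(*n.operands[1]);
        break;
      case Node::MulEintInt:
        n.sqManp = visit(*n.operands[0]) * n.cstSqNorm;
        break;
      }
      return n.sqManp;
    }

    int main() {
      Node in{Node::Input, {}};
      Node add{Node::AddEintInt, {&in}, 3 * 3};  // constants up to 3
      Node mul{Node::MulEintInt, {&add}, 2 * 2}; // constants up to 2
      std::cout << visit(mul) << "\n";           // (1 + 9) * 4 = 40
      return 0;
    }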

compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir (new file, 94 lines)
@@ -0,0 +1,94 @@
+// RUN: zamacompiler --passes MANP --action=dump-hlfhe --split-input-file %s 2>&1 | FileCheck %s
+
+func @single_cst_add_eint_int(%t: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  %cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
+
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_dyn_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+{
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint_int"(%e, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_add_eint(%e0: tensor<8x!HLFHE.eint<2>>, %e1: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint"(%e0, %e1) : (tensor<8x!HLFHE.eint<2>>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_cst_sub_int_eint(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  %cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
+
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_dyn_sub_int_eint(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+{
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_cst_mul_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  %cst = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
+
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @single_dyn_mul_eint_int(%e: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+{
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+
+  return %0 : tensor<8x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @chain_add_eint_int(%e: tensor<8x!HLFHE.eint<2>>) -> tensor<8x!HLFHE.eint<2>>
+{
+  %cst0 = std.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
+  %cst1 = std.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi3>
+  %cst2 = std.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi3>
+  %cst3 = std.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi3>
+  // CHECK: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  // CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %1 = "HLFHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  // CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %2 = "HLFHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  // CHECK-NEXT: %[[ret:.*]] = "HLFHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %3 = "HLFHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  return %3 : tensor<8x!HLFHE.eint<2>>
+}
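
The attribute values checked above are consistent with the final MANP being
the ceiling of the square root of the squared norm tracked by the pass (an
inference from these tests; the conversion itself is not shown in this diff).
A minimal standalone C++ cross-check, with the conservative dynamic norm
(2^w)^2 assumed as before and the helper names hypothetical:

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Hypothetical stand-in for conservativeIntNorm2Sq: squared 2-norm of a
    // dynamic signless integer operand of width w, over-approximated as (2^w)^2.
    static uint64_t conservativeSqNorm(unsigned w) {
      uint64_t bound = 1ull << w; // 2^w
      return bound * bound;
    }

    // MANP attribute value: ceil(sqrt(squared MANP)).
    static uint64_t manp(uint64_t sq) {
      return static_cast<uint64_t>(std::ceil(std::sqrt(double(sq))));
    }

    int main() {
      const uint64_t fresh = 1; // squared MANP of a fresh encrypted tensor
      assert(manp(fresh + 3 * 3) == 4);                 // single_cst_add_eint_int
      assert(manp(fresh + conservativeSqNorm(3)) == 9); // single_dyn_add_eint_int
      assert(manp(fresh + fresh) == 2);                 // single_add_eint
      assert(manp(3 * 3 * fresh) == 3);                 // single_cst_mul_eint_int
      assert(manp(conservativeSqNorm(3) * fresh) == 8); // single_dyn_mul_eint_int
      assert(manp(fresh + 3 * 3 + 7 * 7) == 8);         // chain_add_eint_int, %1
      return 0;
    }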

@@ -11,12 +11,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
-  return %res : tensor<4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
+  return %res : tensor<4x!HLFHE.eint<6>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   std::vector<uint8_t> a0{31, 6, 12, 9};
   std::vector<uint8_t> a1{32, 9, 2, 3};
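
A plausible reading of the widened types in this hunk (an inference; the
commit itself does not state the reason): the inputs peak at 31 and 32, so a
term-to-term sum can reach

    31 + 32 = 63 = 2^6 - 1,

which fits an eint<6> result (with a correspondingly wider i7 plaintext
operand) but not the previous eint<4>.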

@@ -43,12 +42,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_term_to_term_broadcast) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
-  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<5>>, %a1: tensor<1x4x4xi6>) -> tensor<4x4x4x!HLFHE.eint<5>> {
+  %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<5>>, tensor<1x4x4xi6>) -> tensor<4x4x4x!HLFHE.eint<5>>
+  return %res : tensor<4x4x4x!HLFHE.eint<5>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
       {{5, 6, 7, 8}},

@@ -102,8 +100,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_column) {
   %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   uint8_t a0[3][3]{
       {1, 2, 3},
      {4, 5, 6},

@@ -149,8 +146,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_line) {
   %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -188,8 +184,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_int_matrix_line_missing_dim) {
   %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -228,12 +223,11 @@ TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
-  return %res : tensor<4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4x!HLFHE.eint<6>>) -> tensor<4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4x!HLFHE.eint<6>>) -> tensor<4x!HLFHE.eint<6>>
+  return %res : tensor<4x!HLFHE.eint<6>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
 
   std::vector<uint8_t> a0{31, 6, 12, 9};
   std::vector<uint8_t> a1{32, 9, 2, 3};

@@ -263,14 +257,13 @@ TEST(End2EndJit_HLFHELinalg, add_eint_term_to_term_broadcast) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term addition of `%a0` with `%a1`
-func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1:
-tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<5>>, %a1:
+tensor<1x4x4x!HLFHE.eint<5>>) -> tensor<4x4x4x!HLFHE.eint<5>> {
 %res = "HLFHELinalg.add_eint"(%a0, %a1) :
-(tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4x!HLFHE.eint<4>>) ->
-tensor<4x4x4x!HLFHE.eint<4>> return %res : tensor<4x4x4x!HLFHE.eint<4>>
+(tensor<4x1x4x!HLFHE.eint<5>>, tensor<1x4x4x!HLFHE.eint<5>>) ->
+tensor<4x4x4x!HLFHE.eint<5>> return %res : tensor<4x4x4x!HLFHE.eint<5>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
       {{5, 6, 7, 8}},

@@ -325,8 +318,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_column) {
   %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -375,8 +367,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line) {
 tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> return %res :
 tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -415,8 +406,7 @@ TEST(End2EndJit_HLFHELinalg, add_eint_matrix_line_missing_dim) {
   %res = "HLFHELinalg.add_eint"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -459,8 +449,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_term_to_term) {
   %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!HLFHE.eint<4>>) -> tensor<4x!HLFHE.eint<4>>
   return %res : tensor<4x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   std::vector<uint8_t> a0{32, 9, 12, 9};
   std::vector<uint8_t> a1{31, 6, 2, 3};

@@ -487,12 +476,11 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_term_to_term_broadcast) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
-func @main(%a0: tensor<4x1x4xi5>, %a1: tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!HLFHE.eint<4>>) -> tensor<4x4x4x!HLFHE.eint<4>>
-  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x1x4xi8>, %a1: tensor<1x4x4x!HLFHE.eint<7>>) -> tensor<4x4x4x!HLFHE.eint<7>> {
+  %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi8>, tensor<1x4x4x!HLFHE.eint<7>>) -> tensor<4x4x4x!HLFHE.eint<7>>
+  return %res : tensor<4x4x4x!HLFHE.eint<7>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
      {{5, 6, 7, 8}},

@@ -547,8 +535,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_column) {
 tensor<3x1x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> return %res :
 tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -597,8 +584,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_line) {
 tensor<1x3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>> return %res :
 tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -637,8 +623,7 @@ TEST(End2EndJit_HLFHELinalg, sub_int_eint_matrix_line_missing_dim) {
   %res = "HLFHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x!HLFHE.eint<4>>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -677,12 +662,11 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_term_to_term) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term multiplication of `%a0` with `%a1`
-func @main(%a0: tensor<4x!HLFHE.eint<4>>, %a1: tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<4>>, tensor<4xi5>) -> tensor<4x!HLFHE.eint<4>>
-  return %res : tensor<4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
+  return %res : tensor<4x!HLFHE.eint<6>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   std::vector<uint8_t> a0{31, 6, 12, 9};
   std::vector<uint8_t> a1{2, 3, 2, 3};

@@ -709,12 +693,11 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_term_to_term_broadcast) {
 
   mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
 // Returns the term to term multiplication of `%a0` with `%a1`, where dimensions equal to one are stretched.
-func @main(%a0: tensor<4x1x4x!HLFHE.eint<4>>, %a1: tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>> {
-  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<4>>, tensor<1x4x4xi5>) -> tensor<4x4x4x!HLFHE.eint<4>>
-  return %res : tensor<4x4x4x!HLFHE.eint<4>>
+func @main(%a0: tensor<4x1x4x!HLFHE.eint<6>>, %a1: tensor<1x4x4xi7>) -> tensor<4x4x4x!HLFHE.eint<6>> {
+  %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x1x4x!HLFHE.eint<6>>, tensor<1x4x4xi7>) -> tensor<4x4x4x!HLFHE.eint<6>>
+  return %res : tensor<4x4x4x!HLFHE.eint<6>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[4][1][4]{
       {{1, 2, 3, 4}},
       {{5, 6, 7, 8}},

@@ -768,8 +751,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_column) {
   %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -816,8 +798,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line) {
   %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -856,8 +837,7 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
   %res = "HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<3x3x!HLFHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!HLFHE.eint<4>>
   return %res : tensor<3x3x!HLFHE.eint<4>>
 }
-)XXX",
-                                   "main", true);
+)XXX");
   const uint8_t a0[3][3]{
       {1, 2, 3},
       {4, 5, 6},

@@ -886,4 +866,4 @@ TEST(End2EndJit_HLFHELinalg, mul_eint_int_matrix_line_missing_dim) {
       EXPECT_EQ((*res)[i * 3 + j], a0[i][j] * a1[0][j]);
     }
   }
-}
+}