mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): MANP Analysis of HLFHELinalg.matmul (closes #178)
This commit is contained in:
@@ -618,6 +618,73 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUMul(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHELinalg::MatMulEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
mlir::RankedTensorType rhsTy =
|
||||
op.rhs().getType().cast<mlir::RankedTensorType>();
|
||||
mlir::RankedTensorType lhsTy =
|
||||
op.lhs().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::Type iTy = rhsTy.getElementType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
"Only multiplications with signless integers are currently allowed");
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
// Initial value of the accumulator
|
||||
llvm::APInt accNorm = llvm::APInt{1, 1, false};
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
|
||||
: nullptr;
|
||||
|
||||
if (denseVals) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
// tensor<MxN> = tensor<MxP> * tensor<PxN> compute the max 2-norm of the
|
||||
// result
|
||||
int64_t M = lhsTy.getShape()[0];
|
||||
int64_t N = rhsTy.getShape()[1];
|
||||
int64_t P = rhsTy.getShape()[0];
|
||||
for (int64_t m = 0; m < M; m++) {
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
|
||||
for (int64_t p = 0; p < P; p++) {
|
||||
llvm::APInt cst = denseVals.getFlatValue<llvm::APInt>(p * N + n);
|
||||
llvm::APInt rhsNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
|
||||
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
|
||||
}
|
||||
accNorm = APIntUMax(accNorm, tmpNorm);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
llvm::APInt rhsNorm = conservativeIntNorm2Sq(iTy);
|
||||
// For tensor<MxN> = tensor<MxP> * tensor<PxN> they are P HLFHE.mul_eint_int
|
||||
// and HLFHE.add_eint operations for each elements of the result
|
||||
int64_t P = rhsTy.getShape()[0];
|
||||
for (int64_t i = 0; i < P; i++) {
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
|
||||
accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
|
||||
}
|
||||
}
|
||||
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::tensor::ExtractOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
@@ -727,6 +794,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MulEintIntOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
} else if (auto matmulEintIntOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MatMulEintIntOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(matmulEintIntOp, operands);
|
||||
} else if (llvm::isa<mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
|
||||
@@ -133,4 +133,83 @@ func @apply_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>
|
||||
// CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>>
|
||||
%res = "HLFHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>>
|
||||
return %res : tensor<8x!HLFHE.eint<3>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!HLFHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// ceil(sqrt(65)) = 9
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!HLFHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!HLFHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// p = 1
|
||||
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 65 = 129
|
||||
// ceil(sqrt(129)) = 12
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}}
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!HLFHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matmul_eint_int_cst_p_1(%arg0: tensor<3x1x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
%0 = arith.constant dense<[[3, 1]]> : tensor<1x2xi3>
|
||||
// c(m,n) = a(m,p) * b(p,n) the max cst is used for n = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 9 + 1 = 10
|
||||
// ceil(sqrt(10)) = 4
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x1x!HLFHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matmul_eint_int_cst_p_2_n_0(%arg0: tensor<3x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
%0 = arith.constant dense<[[4, 1],[3, 1]]> : tensor<2x2xi3>
|
||||
// c(m,n) = a(m,p) * b(p,n) the max csts [4,3] are used for n = 0
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 9 + 1 = 10
|
||||
// p = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17
|
||||
// manp(add_eint(mul, acc)) = 17 + 9 = 26
|
||||
// ceil(sqrt(26)) = 6
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}}
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!HLFHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
|
||||
%0 = arith.constant dense<[[1, 4],[3, 1]]> : tensor<2x2xi3>
|
||||
// c(m,n) = a(m,p) * b(p,n) the max csts [4,1] are used for n = 1
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16
|
||||
// manp(add_eint(mul, acc)) = 16 + 1 = 17
|
||||
// p = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1
|
||||
// manp(add_eint(mul, acc)) = 1 + 17 = 18
|
||||
// ceil(sqrt(18)) = 5
|
||||
// CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
|
||||
%1 = "HLFHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!HLFHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
|
||||
return %1 : tensor<3x2x!HLFHE.eint<2>>
|
||||
}
|
||||
@@ -1133,8 +1133,7 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) {
|
||||
%0 = "HLFHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!HLFHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>>
|
||||
return %0 : tensor<3x3x!HLFHE.eint<6>>
|
||||
}
|
||||
)XXX",
|
||||
"main", true);
|
||||
)XXX");
|
||||
const uint8_t A[3][2]{
|
||||
{1, 2},
|
||||
{3, 4},
|
||||
|
||||
Reference in New Issue
Block a user