From 5df775a51b39f2f35c7ccfea2faf399f8a84fd66 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 12 Nov 2021 11:36:09 +0100 Subject: [PATCH] feat(compiler): MANP Analysis of HLFHELinalg.matmul (closes #178) --- compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 71 +++++++++++++++++ .../Dialect/HLFHE/Analysis/MANP_linalg.mlir | 79 +++++++++++++++++++ .../unittest/end_to_end_jit_hlfhelinalg.cc | 3 +- 3 files changed, 151 insertions(+), 2 deletions(-) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 3e2774981..36cf53b4a 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -618,6 +618,73 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUMul(sqNorm, eNorm); } +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `HLFHE.mul_eint_int` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHELinalg::MatMulEintIntOp op, + llvm::ArrayRef *> operandMANPs) { + + mlir::RankedTensorType rhsTy = + op.rhs().getType().cast(); + mlir::RankedTensorType lhsTy = + op.lhs().getType().cast(); + + mlir::Type iTy = rhsTy.getElementType(); + + assert(iTy.isSignlessInteger() && + "Only multiplications with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt lhsNorm = operandMANPs[0]->getValue().getMANP().getValue(); + // Initial value of the accumulator + llvm::APInt accNorm = llvm::APInt{1, 1, false}; + + mlir::arith::ConstantOp cstOp = + llvm::dyn_cast_or_null( + op->getOpOperand(1).get().getDefiningOp()); + mlir::DenseIntElementsAttr denseVals = + cstOp ? cstOp->getAttrOfType("value") + : nullptr; + + if (denseVals) { + // For a constant operand use actual constant to calculate 2-norm + // tensor = tensor * tensor compute the max 2-norm of the + // result + int64_t M = lhsTy.getShape()[0]; + int64_t N = rhsTy.getShape()[1]; + int64_t P = rhsTy.getShape()[0]; + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + llvm::APInt tmpNorm = llvm::APInt{1, 1, false}; + for (int64_t p = 0; p < P; p++) { + llvm::APInt cst = denseVals.getFlatValue(p * N + n); + llvm::APInt rhsNorm = APIntWidthExtendUSq(cst); + llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); + tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); + } + accNorm = APIntUMax(accNorm, tmpNorm); + } + } + } else { + // For a dynamic operand conservatively assume that the value is + // the maximum for the integer width + llvm::APInt rhsNorm = conservativeIntNorm2Sq(iTy); + // For tensor = tensor * tensor they are P HLFHE.mul_eint_int + // and HLFHE.add_eint operations for each elements of the result + int64_t P = rhsTy.getShape()[0]; + for (int64_t i = 0; i < P; i++) { + llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); + accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); + } + } + + return accNorm; +} + static llvm::APInt getSqMANP( mlir::tensor::ExtractOp op, llvm::ArrayRef *> operandMANPs) { @@ -727,6 +794,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); + } else if (auto matmulEintIntOp = + llvm::dyn_cast( + op)) { + norm2SqEquiv = getSqMANP(matmulEintIntOp, operands); } else if (llvm::isa( op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir index c4c77e56a..860dfb0f1 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir @@ -133,4 +133,83 @@ func @apply_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3> // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>> %res = "HLFHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<8x!HLFHE.eint<3>> return %res : tensor<8x!HLFHE.eint<3>> +} + +// ----- + +func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!HLFHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> { + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 + // manp(add_eint(mul, acc)) = 64 + 1 = 65 + // ceil(sqrt(65)) = 9 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!HLFHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!HLFHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> { + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 + // manp(add_eint(mul, acc)) = 64 + 1 = 65 + // p = 1 + // manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 + // manp(add_eint(mul, acc)) = 64 + 65 = 129 + // ceil(sqrt(129)) = 12 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!HLFHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_eint_int_cst_p_1(%arg0: tensor<3x1x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + %0 = arith.constant dense<[[3, 1]]> : tensor<1x2xi3> + // c(m,n) = a(m,p) * b(p,n) the max cst is used for n = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9 + // manp(add_eint(mul, acc)) = 9 + 1 = 10 + // ceil(sqrt(10)) = 4 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x1x!HLFHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_eint_int_cst_p_2_n_0(%arg0: tensor<3x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + %0 = arith.constant dense<[[4, 1],[3, 1]]> : tensor<2x2xi3> + // c(m,n) = a(m,p) * b(p,n) the max csts [4,3] are used for n = 0 + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9 + // manp(add_eint(mul, acc)) = 9 + 1 = 10 + // p = 1 + // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17 + // manp(add_eint(mul, acc)) = 17 + 9 = 26 + // ceil(sqrt(26)) = 6 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!HLFHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + %0 = arith.constant dense<[[1, 4],[3, 1]]> : tensor<2x2xi3> + // c(m,n) = a(m,p) * b(p,n) the max csts [4,1] are used for n = 1 + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16 + // manp(add_eint(mul, acc)) = 16 + 1 = 17 + // p = 1 + // mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1 + // manp(add_eint(mul, acc)) = 1 + 17 = 18 + // ceil(sqrt(18)) = 5 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!HLFHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> } \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index 6c491eae9..311c488ca 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -1133,8 +1133,7 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) { %0 = "HLFHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!HLFHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>> return %0 : tensor<3x3x!HLFHE.eint<6>> } -)XXX", - "main", true); +)XXX"); const uint8_t A[3][2]{ {1, 2}, {3, 4},