From ddbafd713db1b0b9a1bbb6b18ebdfe92bda26faf Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Tue, 16 Nov 2021 13:18:49 +0100 Subject: [PATCH] feat(compiler): Add the HLFHELinalg.matmul_int_eint operator --- .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.td | 41 ++++++- .../TensorOpsToLinalg.cpp | 83 ++++++++----- compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 69 +++++++++++ .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp | 18 +-- .../Dialect/HLFHE/Analysis/MANP_linalg.mlir | 113 ++++++++++++++++-- .../Dialect/HLFHELinalg/ops.invalid.mlir | 34 ++++++ compiler/tests/Dialect/HLFHELinalg/ops.mlir | 13 ++ .../unittest/end_to_end_jit_hlfhelinalg.cc | 61 +++++++++- 8 files changed, 376 insertions(+), 56 deletions(-) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index 32c6808ef..729b1e801 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -402,8 +402,47 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> { let results = (outs Type.predicate, HasStaticShapePred]>>); let verifier = [{ - return ::mlir::zamalang::HLFHELinalg::verifyMatmul(*this); + return ::mlir::zamalang::HLFHELinalg::verifyMatmul(*this); }]; } +def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> { + let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of clear integers and a matrix of encrypted integers."; + + let description = [{ + Performs a matrix multiplication of a matrix of clear integers and a matrix of encrypted integers. + The width of the clear integers must be less than or equals to the witdh of encrypted integers. 
+ + ```mlir + "HLFHELinalg.matmul_int_eint(%a, %b) : (tensor, tensor>) -> tensor>" + ``` + + Examples: + ```mlir + // Returns the matrix multiplication of a 3x2 matrix of clear integers and a 2x3 matrix of encrypted integers. + // [ 1, 2, 3] + // [ 2, 3, 4] + // * + // [1,2] [ 5, 8,11] + // [3,4] = [11,18,25] + // [5,6] [17,28,39] + // + "HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>> + + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let verifier = [{ + return ::mlir::zamalang::HLFHELinalg::verifyMatmul(*this); + }]; +} + + #endif diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index da85e9219..a4db2c5de 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -598,11 +598,12 @@ struct HLFHELinalgNegEintToLinalgGeneric }; }; -// This rewrite pattern transforms any instance of -// operators `HLFHELinalg.matmul_eint_int` to an instance of `linalg.generic` -// with an appropriate region using `HLFHE.mul_eint_int` and `HLFHE.add_eint` -// operation, an appropriate specification for the iteration dimensions and -// appropriate operations managing the accumulator of `linalg.generic`. +// This template rewrite pattern transforms any instance of +// operators `HLFHELinalgMatmulOp` to an instance of `linalg.generic` +// with an appropriate region using a builder that create the multiplication +// operators and `HLFHE.add_eint` operation, an appropriate specification for +// the iteration dimensions and appropriate operations managing the accumulator +// of `linalg.generic`. 
// // Example: // @@ -633,27 +634,33 @@ struct HLFHELinalgNegEintToLinalgGeneric // outs(%C : tensor>) // { // ^bb0(%a: !HLFHE.eint

, %b: ip', %c: !HLFHE.eint

) : -// %d = "HLFHE.mul_eint_int"(%a, %b) : -// (!HLFHE.eint

, ip') -> !HLFHE.eint

+// %d = createMulOp(%a, %b): !HLFHE.eint

// %e = "HLFHE.add_eint"(%c, %d): // (!HLFHE.eint

, !HLFHE.eint

) -> !HLFHE.eint

// linalg.yield %e : !HLFHE.eint

// } // -struct HLFHELinalgMatmulEintIntToLinalgGeneric - : public mlir::OpRewritePattern< - mlir::zamalang::HLFHELinalg::MatMulEintIntOp> { - HLFHELinalgMatmulEintIntToLinalgGeneric(::mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : ::mlir::OpRewritePattern( - context, benefit) {} +template +struct HLFHELinalgMatmulToLinalgGeneric + : public mlir::OpRewritePattern { + HLFHELinalgMatmulToLinalgGeneric( + mlir::MLIRContext *context, + std::function + createMulOp, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit), + createMulOp(createMulOp) {} ::mlir::LogicalResult - matchAndRewrite(mlir::zamalang::HLFHELinalg::MatMulEintIntOp matmulOp, + matchAndRewrite(HLFHELinalgMatmulOp matmulOp, ::mlir::PatternRewriter &rewriter) const override { + mlir::Location matmulLoc = matmulOp.getLoc(); mlir::RankedTensorType resultTy = ((mlir::Type)matmulOp->getResult(0).getType()) .cast(); + mlir::Type resultElementTy = resultTy.getElementType(); // Create tensor.generate for initial value auto generateBody = [&](mlir::OpBuilder &nestedBuilder, mlir::Location nestedLoc, @@ -661,17 +668,13 @@ struct HLFHELinalgMatmulEintIntToLinalgGeneric // %z = "HLFHE.zero" : () -> !HLFHE.eint<2> mlir::zamalang::HLFHE::ZeroEintOp zeroOp = nestedBuilder.create( - matmulOp.getLoc(), resultTy.getElementType()); + matmulLoc, resultElementTy); // linalg.yield %z : !HLFHE.eint

- nestedBuilder.create(matmulOp.getLoc(), + nestedBuilder.create(matmulLoc, zeroOp.getResult()); }; mlir::tensor::GenerateOp init = rewriter.create( - matmulOp.getLoc(), (mlir::Type)resultTy, mlir::ValueRange{}, - generateBody); - // linalg.init_tensor for initial value - // mlir::Value init = rewriter.create( - // matmulOp.getLoc(), resultTy.getShape(), resultTy.getElementType()); + matmulLoc, (mlir::Type)resultTy, mlir::ValueRange{}, generateBody); // Create the affine #maps_0 llvm::SmallVector maps{ // (m, n, p) -> (m, p), @@ -698,17 +701,15 @@ struct HLFHELinalgMatmulEintIntToLinalgGeneric mlir::ValueRange blockArgs) { // "HLFHE.mul_eint_int"(%a, %b) : (!HLFHE.eint

, ip') -> !HLFHE.eint

mlir::zamalang::HLFHE::MulEintIntOp mulEintIntOp = - nestedBuilder.create( - matmulOp.getLoc(), resultTy.getElementType(), blockArgs[0], - blockArgs[1]); + createMulOp(nestedBuilder, matmulLoc, resultElementTy, blockArgs[0], + blockArgs[1]); // "HLFHE.add_eint"(%c, %d): (!HLFHE.eint

, !HLFHE.eint

) -> // !HLFHE.eint

mlir::zamalang::HLFHE::AddEintOp addEintOp = nestedBuilder.create( - matmulOp.getLoc(), resultTy.getElementType(), blockArgs[2], - mulEintIntOp); + matmulLoc, resultElementTy, blockArgs[2], mulEintIntOp); // linalg.yield %e : !HLFHE.eint

- nestedBuilder.create(matmulOp.getLoc(), + nestedBuilder.create(matmulLoc, addEintOp.getResult()); }; @@ -720,14 +721,19 @@ struct HLFHELinalgMatmulEintIntToLinalgGeneric llvm::StringRef call{""}; mlir::linalg::GenericOp genericOp = - rewriter.create(matmulOp.getLoc(), resTypes, - ins, outs, maps, iteratorTypes, - doc, call, bodyBuilder); + rewriter.create(matmulLoc, resTypes, ins, outs, + maps, iteratorTypes, doc, call, + bodyBuilder); rewriter.replaceOp(matmulOp, {genericOp.getResult(0)}); return ::mlir::success(); }; + +private: + std::function + createMulOp; }; namespace { @@ -771,7 +777,20 @@ void HLFHETensorOpsToLinalg::runOnFunction() { &getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); - patterns.insert(&getContext()); + patterns.insert>( + &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value arg0, mlir::Value arg1) { + return builder.create(loc, type, + arg0, arg1); + }); + patterns.insert>( + &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc, + mlir::Type type, mlir::Value arg0, mlir::Value arg1) { + return builder.create(loc, type, + arg1, arg0); + }); patterns.insert( &getContext()); diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index fe4a792f9..aa8845fd6 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -686,6 +686,71 @@ static llvm::APInt getSqMANP( return accNorm; } +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHELinalg::MatMulIntEintOp op, + llvm::ArrayRef *> operandMANPs) { + + mlir::RankedTensorType rhsTy = + op.rhs().getType().cast(); + mlir::RankedTensorType lhsTy = + op.lhs().getType().cast(); + + mlir::Type iTy = lhsTy.getElementType(); + + assert(iTy.isSignlessInteger() && + "Only multiplications with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[1]->getValue().getMANP().hasValue() 
&& + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().getValue(); + // Initial value of the accumulator + llvm::APInt accNorm = llvm::APInt{1, 1, false}; + + mlir::arith::ConstantOp cstOp = + llvm::dyn_cast_or_null( + op->getOpOperand(0).get().getDefiningOp()); + mlir::DenseIntElementsAttr denseVals = + cstOp ? cstOp->getAttrOfType("value") + : nullptr; + + if (denseVals) { + // For a constant operand use actual constant to calculate 2-norm + // tensor = tensor * tensor compute the max 2-norm of the + // result + int64_t M = lhsTy.getShape()[0]; + int64_t N = rhsTy.getShape()[1]; + int64_t P = rhsTy.getShape()[0]; + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + llvm::APInt tmpNorm = llvm::APInt{1, 1, false}; + for (int64_t p = 0; p < P; p++) { + llvm::APInt cst = denseVals.getFlatValue(m * P + p); + llvm::APInt lhsNorm = APIntWidthExtendUSq(cst); + llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); + tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); + } + accNorm = APIntUMax(accNorm, tmpNorm); + } + } + } else { + // For a dynamic operand conservatively assume that the value is + // the maximum for the integer width + llvm::APInt lhsNorm = conservativeIntNorm2Sq(iTy); + // For tensor = tensor * tensor they are P HLFHE.mul_eint_int + // and HLFHE.add_eint operations for each elements of the result + int64_t P = rhsTy.getShape()[0]; + for (int64_t i = 0; i < P; i++) { + llvm::APInt mulNorm = APIntWidthExtendUMul(rhsNorm, lhsNorm); + accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); + } + } + + return accNorm; +} + static llvm::APInt getSqMANP( mlir::tensor::ExtractOp op, llvm::ArrayRef *> operandMANPs) { @@ -823,6 +888,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(matmulEintIntOp, operands); + } else if (auto matmulIntEintOp = + llvm::dyn_cast( + op)) { + norm2SqEquiv = 
getSqMANP(matmulIntEintOp, operands); } else if (llvm::isa< mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp, mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>( diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp index f64e122e5..e9cc7114a 100644 --- a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp +++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp @@ -309,14 +309,15 @@ verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) { return ::mlir::success(); } -/// Verify the matmul shapes, the type of tensor elements are checked by -/// TensorBinaryEintInt -mlir::LogicalResult verifyMatmul(MatMulEintIntOp &op) { - auto lhsTy = op.lhs().getType().cast(); +/// Verify the matmul shapes, the type of tensor elements should be checked by +/// something else +template mlir::LogicalResult verifyMatmul(MatMulOp &op) { + auto lhsTy = ((mlir::Type)op.lhs().getType()).cast(); - auto rhsTy = op.rhs().getType().cast(); + auto rhsTy = ((mlir::Type)op.rhs().getType()).cast(); - auto resultTy = op.getResult().getType().cast(); + auto resultTy = + ((mlir::Type)op.getResult().getType()).cast(); if (lhsTy.getShape().size() != 2 || rhsTy.getShape().size() != 2) { op.emitOpError() << "should have 2D tensors as operands"; @@ -333,9 +334,8 @@ mlir::LogicalResult verifyMatmul(MatMulEintIntOp &op) { rhsTy.getDimSize(1)}; if (!resultTy.hasStaticShape(expectedShape)) { op.emitOpError() << "should have the result shape compatible with operands " - "shape, expect " - << expectedShape[0] << "x" << expectedShape[1] - << " as the shape of the result"; + << "shape, expect " << expectedShape[0] << "x" + << expectedShape[1] << " as the shape of the result"; return mlir::failure(); } return mlir::success(); diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir index 7893787e8..6135662e8 100644 --- 
a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir @@ -137,6 +137,32 @@ func @apply_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3> // ----- +///////////////////////////////////////////////// +// HLFHELinalg.apply_multi_lookup_table +///////////////////////////////////////////////// + +func @apply_multi_lookup_table(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> { + // CHECK: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> + return %res : tensor<3x3x!HLFHE.eint<3>> +} + +// ----- + +func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> { + // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> + // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> + %res = "HLFHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> + return %res : tensor<8x!HLFHE.eint<3>> +} + +// ----- + +///////////////////////////////////////////////// +// HLFHELinalg.matmul_ent_int +///////////////////////////////////////////////// + func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!HLFHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> { // p = 0 // acc = manp(0) = 1 @@ -214,20 +240,85 @@ func 
@matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!HLFHE.eint<2>>) -> tensor<3 return %1 : tensor<3x2x!HLFHE.eint<2>> } +///////////////////////////////////////////////// +// HLFHELinalg.matmul_int_eint +///////////////////////////////////////////////// + // ----- -func @apply_multi_lookup_table(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> { - // CHECK: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> - %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> - return %res : tensor<3x3x!HLFHE.eint<3>> +func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 + // manp(add_eint(mul, acc)) = 64 + 1 = 65 + // ceil(sqrt(65)) = 9 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> } // ----- -func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> { - // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> - %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>> - // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> - %res = "HLFHELinalg.apply_multi_lookup_table"(%0, %luts) : 
(tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> - return %res : tensor<8x!HLFHE.eint<3>> -} \ No newline at end of file +func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 + // manp(add_eint(mul, acc)) = 64 + 1 = 65 + // p = 1 + // manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 + // manp(add_eint(mul, acc)) = 64 + 65 = 129 + // ceil(sqrt(129)) = 12 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_int_eint_cst_p_1(%arg0: tensor<1x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> { + %0 = arith.constant dense<[[3], [1]]> : tensor<2x1xi3> + // c(m,n) = a(m,p) * b(p,n) the max cst is used for m = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9 + // manp(add_eint(mul, acc)) = 9 + 1 = 10 + // ceil(sqrt(10)) = 4 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x1xi3>, tensor<1x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> + return %1 : tensor<2x3x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_int_eint_cst_p_2_n_0(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> { + %0 = arith.constant dense<[[3, 4],[1, 1]]> : tensor<2x2xi3> + // c(m,n) = a(m,p) * b(p,n) the max csts [4,3] are used for m = 0 + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9 + // manp(add_eint(mul, acc)) = 9 + 1 = 10 + // p = 1 + // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17 + // manp(add_eint(mul, acc)) = 17 + 9 = 26 + // ceil(sqrt(26)) = 6 + // CHECK: 
%[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> + return %1 : tensor<2x3x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> { + %0 = arith.constant dense<[[4, 1],[3, 1]]> : tensor<2x2xi3> + // c(m,n) = a(m,p) * b(p,n) the max csts [4,1] are used for m = 1 + // p = 0 + // acc = manp(0) = 1 + // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16 + // manp(add_eint(mul, acc)) = 16 + 1 = 17 + // p = 1 + // mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1 + // manp(add_eint(mul, acc)) = 1 + 17 = 18 + // ceil(sqrt(18)) = 5 + // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} + %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> + return %1 : tensor<2x3x!HLFHE.eint<2>> +} diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir index 62b0506e9..6131e9107 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir @@ -194,4 +194,38 @@ func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) return %1 : tensor<4x2x!HLFHE.eint<2>> } +// ----- +///////////////////////////////////////////////// +// HLFHELinalg.matmul_int_eint +///////////////////////////////////////////////// + +func @matmul_int_eint(%arg0: tensor<2x3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have 2D tensors as operands}} + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<2x3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// 
----- + +func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<2x4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have 2D tensors as operands}} + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<2x4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<5x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have the dimension #0 of operand #1equals to the dimension #1 of operand #0, expect 4 got 5}} + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<5x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} + +// ----- + +func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>> { + // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have the result shape compatible with operands shape, expect 3x2 as the shape of the result}} + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>> + return %1 : tensor<4x2x!HLFHE.eint<2>> +} diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index ebaa00c77..b854b13f1 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -316,3 +316,16 @@ func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> return %1 : tensor<3x2x!HLFHE.eint<2>> } + +///////////////////////////////////////////////// +// HLFHELinalg.matmul_int_eint +///////////////////////////////////////////////// + +// CHECK-LABEL: @matmul_int_eint(%arg0: 
tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> +func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> { + // CHECK-NEXT: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1) : (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + // CHECK-NEXT: return %[[V1]] : tensor<3x2x!HLFHE.eint<2>> + + %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> + return %1 : tensor<3x2x!HLFHE.eint<2>> +} \ No newline at end of file diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index 1bbb8cbfb..8d2b82457 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -1132,8 +1132,7 @@ TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_boradcast) { mlir::zamalang::TensorLambdaArgument< mlir::zamalang::IntLambdaArgument> - lutsArg(llvm::MutableArrayRef((uint64_t *)luts, 3 * 4), - {3, 4}); + lutsArg(llvm::MutableArrayRef((uint64_t *)luts, 3 * 4), {3, 4}); llvm::Expected> res = lambda.operator()>({&tArg, &lutsArg}); @@ -1276,6 +1275,62 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) { } } +/////////////////////////////////////////////////////////////////////////////// +// HLFHELinalg matmul_int_eint //////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_HLFHELinalg, matmul_int_eint) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the matrix multiplication of a 3x2 matrix of clear integers and a 2x3 matrix of encrypted integers. 
+ // [ 1, 2, 3] + // [ 2, 3, 4] + // * + // [1,2] [ 5, 8,11] + // [3,4] = [11,18,25] + // [5,6] [17,28,39] + func @main(%a: tensor<3x2xi7>, %b: tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>> { + %0 = "HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>> + return %0 : tensor<3x3x!HLFHE.eint<6>> + } +)XXX"); + const uint8_t A[3][2]{ + {1, 2}, + {3, 4}, + {5, 6}, + }; + const uint8_t B[2][3]{ + {1, 2, 3}, + {2, 3, 4}, + }; + const uint8_t expected[3][3]{ + {5, 8, 11}, + {11, 18, 25}, + {17, 28, 39}, + }; + + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + aArg(llvm::ArrayRef((const uint8_t *)A, 3 * 2), {3, 2}); + mlir::zamalang::TensorLambdaArgument< + mlir::zamalang::IntLambdaArgument> + bArg(llvm::ArrayRef((const uint8_t *)B, 2 * 3), {2, 3}); + + llvm::Expected> res = + lambda.operator()>({&aArg, &bArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + /////////////////////////////////////////////////////////////////////////////// // linalg.tensor_collapse_shape /////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -1376,4 +1431,4 @@ func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> { } } } -} \ No newline at end of file +}