diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 9cb5ecac4..7b08adc0f 100644
--- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -436,6 +436,138 @@ struct HLFHELinalgNegEintToLinalgGeneric
   };
 };
 
+// This rewrite pattern transforms any instance of the
+// `HLFHELinalg.matmul_eint_int` operator into an instance of `linalg.generic`
+// with an appropriate region using the `HLFHE.mul_eint_int` and
+// `HLFHE.add_eint` operations, an appropriate specification for the iteration
+// dimensions, and appropriate operations managing the accumulator of
+// `linalg.generic`.
+//
+// Example:
+//
+// "HLFHELinalg.matmul_eint_int(%a, %b) :
+//     (tensor<MxPx!HLFHE.eint<p>>, tensor<PxNxip'>) ->
+//     tensor<MxNx!HLFHE.eint<p>>"
+//
+// becomes:
+//
+// #maps_0 = [
+//   (m, n, p) -> (m, p),
+//   (m, n, p) -> (p, n),
+//   (m, n, p) -> (m, n)
+// ]
+// #attributes_0 = {
+//   indexing_maps = #maps_0,
+//   iterator_types = ["parallel", "parallel", "reduction"]
+// }
+// %init = tensor.generate {
+//   ^bb0(%i : index, %j : index):
+//     %z = "HLFHE.zero" : () -> !HLFHE.eint<p>
+//     tensor.yield %z
+// }: tensor<MxNx!HLFHE.eint<p>>
+// linalg.generic #attributes_0
+//   ins(%A, %B : tensor<MxPx!HLFHE.eint<p>>,
+//                tensor<PxNxip'>)
+//   outs(%C : tensor<MxNx!HLFHE.eint<p>>)
+//   {
+//     ^bb0(%a: !HLFHE.eint<p>, %b: ip', %c: !HLFHE.eint<p>) :
+//       %d = "HLFHE.mul_eint_int"(%a, %b) :
+//                (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
+//       %e = "HLFHE.add_eint"(%c, %d):
+//                (!HLFHE.eint<p>, !HLFHE.eint<p>) -> !HLFHE.eint<p>
+//       linalg.yield %e : !HLFHE.eint<p>
+//   }
+//
+struct HLFHELinalgMatmulEintIntToLinalgGeneric
+    : public mlir::OpRewritePattern<
+          mlir::zamalang::HLFHELinalg::MatMulEintIntOp> {
+  HLFHELinalgMatmulEintIntToLinalgGeneric(::mlir::MLIRContext *context,
+                                          mlir::PatternBenefit benefit = 1)
+      : ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::MatMulEintIntOp>(
+            context, benefit) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(mlir::zamalang::HLFHELinalg::MatMulEintIntOp matmulOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+    mlir::RankedTensorType resultTy =
+        ((mlir::Type)matmulOp->getResult(0).getType())
+            .cast<mlir::RankedTensorType>();
+    // Create the tensor.generate op for the initial accumulator value
+    auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
+                            mlir::Location nestedLoc,
+                            mlir::ValueRange blockArgs) {
+      // %z = "HLFHE.zero" : () -> !HLFHE.eint<p>
+      mlir::zamalang::HLFHE::ZeroEintOp zeroOp =
+          nestedBuilder.create<mlir::zamalang::HLFHE::ZeroEintOp>(
+              matmulOp.getLoc(), resultTy.getElementType());
+      // tensor.yield %z : !HLFHE.eint<p>
+      nestedBuilder.create<mlir::tensor::YieldOp>(matmulOp.getLoc(),
+                                                  zeroOp.getResult());
+    };
+    mlir::tensor::GenerateOp init = rewriter.create<mlir::tensor::GenerateOp>(
+        matmulOp.getLoc(), (mlir::Type)resultTy, mlir::ValueRange{},
+        generateBody);
+    // linalg.init_tensor for initial value
+    // mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
+    //     matmulOp.getLoc(), resultTy.getShape(), resultTy.getElementType());
+    // Create the affine #maps_0
+    llvm::SmallVector<mlir::AffineMap> maps{
+        // (m, n, p) -> (m, p),
+        mlir::AffineMap::get(
+            3, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)},
+            rewriter.getContext()),
+        // (m, n, p) -> (p, n),
+        mlir::AffineMap::get(
+            3, 0, {rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(1)},
+            rewriter.getContext()),
+        // (m, n, p) -> (m, n)
+        mlir::AffineMap::get(
+            3, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
+            rewriter.getContext()),
+    };
+
+    // Create the iterator_types
+    llvm::SmallVector<llvm::StringRef> iteratorTypes{"parallel", "parallel",
+                                                     "reduction"};
+
+    // Create the body of the `linalg.generic` op
+    auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
+                           mlir::Location nestedLoc,
+                           mlir::ValueRange blockArgs) {
+      // "HLFHE.mul_eint_int"(%a, %b) : (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
+      mlir::zamalang::HLFHE::MulEintIntOp mulEintIntOp =
+          nestedBuilder.create<mlir::zamalang::HLFHE::MulEintIntOp>(
+              matmulOp.getLoc(), resultTy.getElementType(), blockArgs[0],
+              blockArgs[1]);
+      // "HLFHE.add_eint"(%c, %d): (!HLFHE.eint<p>, !HLFHE.eint<p>) ->
+      // !HLFHE.eint<p>
+      mlir::zamalang::HLFHE::AddEintOp addEintOp =
+          nestedBuilder.create<mlir::zamalang::HLFHE::AddEintOp>(
+              matmulOp.getLoc(), resultTy.getElementType(), blockArgs[2],
+              mulEintIntOp);
+      // linalg.yield %e : !HLFHE.eint<p>
+      nestedBuilder.create<mlir::linalg::YieldOp>(matmulOp.getLoc(),
+                                                  addEintOp.getResult());
+    };
+
+    // Create the `linalg.generic` op
+    llvm::SmallVector<mlir::Type> resTypes{init.getType()};
+    llvm::SmallVector<mlir::Value> ins{matmulOp.lhs(), matmulOp.rhs()};
+    llvm::SmallVector<mlir::Value> outs{init};
+    llvm::StringRef doc{""};
+    llvm::StringRef call{""};
+
+    mlir::linalg::GenericOp genericOp =
+        rewriter.create<mlir::linalg::GenericOp>(matmulOp.getLoc(), resTypes,
+                                                 ins, outs, maps, iteratorTypes,
+                                                 doc, call, bodyBuilder);
+
+    rewriter.replaceOp(matmulOp, {genericOp.getResult(0)});
+
+    return ::mlir::success();
+  };
+};
+
 namespace {
 struct HLFHETensorOpsToLinalg
     : public HLFHETensorOpsToLinalgBase<HLFHETensorOpsToLinalg> {
@@ -477,6 +609,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
       &getContext());
   patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
   patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
+  patterns.insert<HLFHELinalgMatmulEintIntToLinalgGeneric>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())
diff --git a/compiler/tests/Conversion/HLFHELinalgToLinalg/matmul_eint_int.mlir b/compiler/tests/Conversion/HLFHELinalgToLinalg/matmul_eint_int.mlir
new file mode 100644
index 000000000..c8bed2cbb
--- /dev/null
+++ b/compiler/tests/Conversion/HLFHELinalgToLinalg/matmul_eint_int.mlir
@@ -0,0 +1,25 @@
+// RUN: zamacompiler %s --action=dump-midlfhe --passes hlfhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-NEXT: #map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-NEXT: #map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-NEXT: module {
+// CHECK-NEXT: func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
+// CHECK-NEXT: %0 = tensor.generate {
+// CHECK-NEXT: ^bb0(%arg2: index, %arg3: index): // no predecessors
+// CHECK-NEXT: %2 = "HLFHE.zero"() : () -> !HLFHE.eint<2>
+// CHECK-NEXT: tensor.yield %2 : !HLFHE.eint<2>
+// CHECK-NEXT: } : tensor<3x2x!HLFHE.eint<2>>
+// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) outs(%0 : tensor<3x2x!HLFHE.eint<2>>) {
+// CHECK-NEXT: ^bb0(%arg2: !HLFHE.eint<2>, %arg3: i3, %arg4: !HLFHE.eint<2>): // no predecessors
+// CHECK-NEXT: %2 = "HLFHE.mul_eint_int"(%arg2, %arg3) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
+// CHECK-NEXT: %3 = "HLFHE.add_eint"(%arg4, %2) : (!HLFHE.eint<2>, !HLFHE.eint<2>) -> !HLFHE.eint<2>
+// CHECK-NEXT: linalg.yield %3 : !HLFHE.eint<2>
+// CHECK-NEXT: } -> tensor<3x2x!HLFHE.eint<2>>
+// CHECK-NEXT: return %1 : tensor<3x2x!HLFHE.eint<2>>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
+  %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
\ No newline at end of file
diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
index 9992e7c42..6c491eae9 100644
--- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
+++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc
@@ -1114,3 +1114,60 @@ TEST(End2EndJit_HLFHELinalg, neg_eint) {
     }
   }
 }
+
+///////////////////////////////////////////////////////////////////////////////
+// HLFHELinalg matmul_eint_int ////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////
+
+TEST(End2EndJit_HLFHELinalg, matmul_eint_int) {
+
+  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+  // Returns the matrix multiplication of a 3x2 matrix of encrypted integers and a 2x3 matrix of integers.
+  //          [ 1, 2, 3]
+  //          [ 2, 3, 4]
+  //        *
+  // [1,2]    [ 5, 8,11]
+  // [3,4] =  [11,18,25]
+  // [5,6]    [17,28,39]
+  func @main(%a: tensor<3x2x!HLFHE.eint<6>>, %b: tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>> {
+    %0 = "HLFHELinalg.matmul_eint_int"(%a, %b) : (tensor<3x2x!HLFHE.eint<6>>, tensor<2x3xi7>) -> tensor<3x3x!HLFHE.eint<6>>
+    return %0 : tensor<3x3x!HLFHE.eint<6>>
+  }
+)XXX",
+                                                                "main", true);
+  const uint8_t A[3][2]{
+      {1, 2},
+      {3, 4},
+      {5, 6},
+  };
+  const uint8_t B[2][3]{
+      {1, 2, 3},
+      {2, 3, 4},
+  };
+  const uint8_t expected[3][3]{
+      {5, 8, 11},
+      {11, 18, 25},
+      {17, 28, 39},
+  };
+
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      aArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)A, 3 * 2), {3, 2});
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      bArg(llvm::MutableArrayRef<uint8_t>((uint8_t *)B, 2 * 3), {2, 3});
+
+  llvm::Expected<std::vector<uint64_t>> res =
+      lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
+
+  ASSERT_EXPECTED_SUCCESS(res);
+
+  ASSERT_EQ(res->size(), 3 * 3);
+
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
+          << ", at pos(" << i << "," << j << ")";
+    }
+  }
+}
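
For reference, the three indexing maps in the new pattern encode the usual matrix-multiplication loop nest: `A` is read through (m, n, p) -> (m, p), `B` through (m, n, p) -> (p, n), and the accumulator `C` through (m, n, p) -> (m, n), with `p` as the single reduction dimension. The following standalone C++ sketch (not part of the patch; shapes and values borrowed from the unit test above) models on clear values what the generated `tensor.generate` + `linalg.generic` pair computes, whereas the actual lowering uses `HLFHE.mul_eint_int` and `HLFHE.add_eint` on encrypted values:

#include <array>
#include <cstdint>
#include <cstdio>

// Clear-value model of the lowered IR: the tensor.generate block provides a
// zero-filled accumulator C, and the linalg.generic body folds
// C[m][n] = add(C[m][n], mul(A[m][p], B[p][n])) over the reduction index p.
int main() {
  const std::array<std::array<std::uint8_t, 2>, 3> A{{{1, 2}, {3, 4}, {5, 6}}};
  const std::array<std::array<std::uint8_t, 3>, 2> B{{{1, 2, 3}, {2, 3, 4}}};
  std::array<std::array<std::uint8_t, 3>, 3> C{}; // zero init ("HLFHE.zero")

  for (int m = 0; m < 3; m++)         // parallel, accesses C[m][n] and A[m][p]
    for (int n = 0; n < 3; n++)       // parallel, accesses C[m][n] and B[p][n]
      for (int p = 0; p < 2; p++)     // reduction dimension
        C[m][n] += A[m][p] * B[p][n]; // mul then accumulate, on clear values

  for (const auto &row : C) {
    for (int v : row)
      std::printf("%3d ", v); // expected: 5 8 11 / 11 18 25 / 17 28 39
    std::printf("\n");
  }
  return 0;
}

Running the sketch prints the same 3x3 result the JIT test checks against in `expected`.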