From a1e4329ca8151b2d11cccf51b3bdcf9f4ac0edc9 Mon Sep 17 00:00:00 2001
From: Umut
Date: Tue, 8 Mar 2022 12:50:02 +0300
Subject: [PATCH] fix: use proper broadcasting condition during matmul to
 linalg generic conversion

The loops building the operand indexing maps indexed `lhsShape` and
`rhsShape` with the output dimension index, which is wrong whenever an
operand has fewer dimensions than the output. Keep a separate
per-operand dimension index, treat only genuine batch dimensions of
size 1 as broadcast, and assert that every non-broadcast extent matches
the corresponding output extent. Also adds a regression test in which
the right-hand side has a size-1 batch dimension.
---
 .../TensorOpsToLinalg.cpp         | 18 ++++++++++------
 .../FHELinalgToLinalg/matmul.mlir | 21 +++++++++++++++++++
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index fa629a85a..85b2153a5 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -924,13 +924,16 @@ struct FHELinalgMatmulToLinalgGeneric
     //   and finally, we add the AffineDimExpr corresponding to `N`
     //   which is at the last index of `iteratorTypes`
 
-    for (int64_t dim = outDims - lhsDims; dim < outDims - 1; dim++) {
-      if (lhsShape[dim] == 1) {
+    int64_t lhsDim = 0;
+    for (int64_t outDim = outDims - lhsDims; outDim < outDims - 1; outDim++) {
+      if (lhsDim < lhsDims - 2 && lhsShape[lhsDim] == 1) {
         // broadcasted so current `dim` will always be indexed with `0`
         lhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
       } else {
-        lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(dim));
+        assert(lhsShape[lhsDim] == outShape[outDim]);
+        lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim));
       }
+      lhsDim++;
     }
     lhsAffineExpressions.push_back(
         rewriter.getAffineDimExpr(iteratorTypes.size() - 1));
@@ -965,13 +968,16 @@ struct FHELinalgMatmulToLinalgGeneric
     //   and finally, we add the AffineDimExpr corresponding to `N` and `P`
     //   which is at the last and one before last indices of `iteratorTypes`
 
-    for (int64_t dim = outDims - rhsDims; dim < outDims - 2; dim++) {
-      if (rhsShape[dim] == 1) {
+    int64_t rhsDim = 0;
+    for (int64_t outDim = outDims - rhsDims; outDim < outDims - 2; outDim++) {
+      if (rhsShape[rhsDim] == 1) {
         // broadcasted so current `dim` will always be indexed with `0`
         rhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
       } else {
-        rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(dim));
+        assert(rhsShape[rhsDim] == outShape[outDim]);
+        rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim));
       }
+      rhsDim++;
     }
     rhsAffineExpressions.push_back(
         rewriter.getAffineDimExpr(iteratorTypes.size() - 1));
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul.mlir
index 06843ba33..75be0a5b2 100644
--- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul.mlir
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/matmul.mlir
@@ -296,6 +296,27 @@ func @main(%x: tensor<2x1x3x4x!FHE.eint<5>>, %y: tensor<5x4x2xi6>) -> tensor<2x5
 
 // -----
 
+// CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
+// CHECK-NEXT: #[[m1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (0, d4, d3)>
+// CHECK-NEXT: #[[m2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x5x4x3x!FHE.eint<5>>, %[[a1:.*]]: tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x5x4x2x!FHE.eint<5>>
+// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m1]], #[[m2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%[[a0]], %[[a1]] : tensor<2x5x4x3x!FHE.eint<5>>, tensor<1x3x2xi6>) outs(%[[v0]] : tensor<2x5x4x2x!FHE.eint<5>>) {
+// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<5>, %[[aa1:.*]]: i6, %[[aa2:.*]]: !FHE.eint<5>): // no predecessors
+// CHECK-NEXT: %[[vv0:.*]] = "FHE.mul_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
+// CHECK-NEXT: %[[vv1:.*]] = "FHE.add_eint"(%[[aa2]], %[[vv0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
+// CHECK-NEXT: linalg.yield %[[vv1]] : !FHE.eint<5>
+// CHECK-NEXT: } -> tensor<2x5x4x2x!FHE.eint<5>>
+// CHECK-NEXT: return %[[v1]] : tensor<2x5x4x2x!FHE.eint<5>>
+// CHECK-NEXT: }
+func @main(%x: tensor<2x5x4x3x!FHE.eint<5>>, %y: tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>> {
+  %0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<2x5x4x3x!FHE.eint<5>>, tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>>
+  return %0 : tensor<2x5x4x2x!FHE.eint<5>>
+}
+
+// -----
+
 // CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-NEXT: #[[m1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK-NEXT: #[[m2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
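
Below, for reference only (not part of the patch), is a minimal standalone C++
sketch of the corrected LHS loop. Plain strings stand in for MLIR AffineExpr
values, the shapes are taken from the new test case, and everything else is an
assumed simplification for illustration; it prints the index list
`d0 d1 d2 d4`, matching the `#m0` affine map checked above.

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Shapes from the new test: lhs <2x5x4x3>, out <2x5x4x2>.
      std::vector<int64_t> lhsShape = {2, 5, 4, 3};
      std::vector<int64_t> outShape = {2, 5, 4, 2};
      int64_t lhsDims = static_cast<int64_t>(lhsShape.size());
      int64_t outDims = static_cast<int64_t>(outShape.size());

      std::vector<std::string> lhsExprs;
      int64_t lhsDim = 0;
      for (int64_t outDim = outDims - lhsDims; outDim < outDims - 1; outDim++) {
        if (lhsDim < lhsDims - 2 && lhsShape[lhsDim] == 1) {
          // Size-1 batch dimension: broadcast, always indexed with constant 0.
          lhsExprs.push_back("0");
        } else {
          // Non-broadcast dimension: its extent must match the output's.
          assert(lhsShape[lhsDim] == outShape[outDim]);
          lhsExprs.push_back("d" + std::to_string(outDim));
        }
        lhsDim++;
      }
      // The reduction iterator (`N`) is always the last one; with 4 output
      // dimensions there are 5 iterators, so this is d4.
      lhsExprs.push_back("d" + std::to_string(outDims));

      for (const auto &expr : lhsExprs)
        std::cout << expr << " "; // prints: d0 d1 d2 d4
      std::cout << "\n";
      return 0;
    }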