fix: use proper broadcasting condition during matmul to linalg generic

Author: Umut
Date:   2022-03-08 12:50:02 +03:00
Parent: 1dd8cfaf48
Commit: a1e4329ca8

2 changed files with 33 additions and 6 deletions
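
The fix in one sentence: both loops below walk the output dimensions, but previously used that output index to look up the operand's shape, so whenever lhs or rhs has fewer dimensions than the result, the broadcast check `shape[dim] == 1` inspected the wrong dimension (in the new test below, rhsShape[1], the contraction dimension N, rather than the batch dimension rhsShape[0]). The corrected loops keep a separate operand-local index (`lhsDim` / `rhsDim`), never treat the lhs M dimension as broadcast, and assert that every non-broadcast dimension matches the output shape.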


@@ -924,13 +924,16 @@ struct FHELinalgMatmulToLinalgGeneric
     // and finally, we add the AffineDimExpr corresponding to `N`
     // which is at the last index of `iteratorTypes`
-    for (int64_t dim = outDims - lhsDims; dim < outDims - 1; dim++) {
-      if (lhsShape[dim] == 1) {
+    int64_t lhsDim = 0;
+    for (int64_t outDim = outDims - lhsDims; outDim < outDims - 1; outDim++) {
+      if (lhsDim < lhsDims - 2 && lhsShape[lhsDim] == 1) {
         // broadcasted so current `dim` will always be indexed with `0`
         lhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
       } else {
-        lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(dim));
+        assert(lhsShape[lhsDim] == outShape[outDim]);
+        lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim));
       }
+      lhsDim++;
     }
 
     lhsAffineExpressions.push_back(
         rewriter.getAffineDimExpr(iteratorTypes.size() - 1));
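
To see what the fixed lhs loop computes, here is a minimal standalone sketch of the logic above, with plain strings standing in for mlir::AffineExpr. The shapes come from the tensor<2x1x3x4> x tensor<5x4x2> case already in the tests (output 2x5x3x2); everything else is illustrative and not part of the commit:

#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<int64_t> lhsShape = {2, 1, 3, 4}; // [batch..., M, N]
  std::vector<int64_t> outShape = {2, 5, 3, 2}; // [batch..., M, P]
  int64_t lhsDims = lhsShape.size();
  int64_t outDims = outShape.size();

  std::vector<std::string> lhsExprs;
  int64_t lhsDim = 0;
  for (int64_t outDim = outDims - lhsDims; outDim < outDims - 1; outDim++) {
    // only batch dimensions (lhsDim < lhsDims - 2) may broadcast; the M
    // dimension must line up with the output even when it equals 1
    if (lhsDim < lhsDims - 2 && lhsShape[lhsDim] == 1) {
      lhsExprs.push_back("0"); // broadcast: always indexed with constant 0
    } else {
      assert(lhsShape[lhsDim] == outShape[outDim]);
      lhsExprs.push_back("d" + std::to_string(outDim));
    }
    lhsDim++;
  }
  // the reduction dimension N is the last iterator; iteratorTypes has
  // outDims + 1 entries, so its last index is outDims
  lhsExprs.push_back("d" + std::to_string(outDims));

  for (const auto &expr : lhsExprs)
    std::cout << expr << " "; // prints: d0 0 d2 d4
  std::cout << "\n";
}

This matches the lhs indexing map (d0, 0, d2, d4) the lowering should produce for that existing test, and shows why the guard matters: only batch dimensions broadcast, while M is checked against the output shape.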
@@ -965,13 +968,16 @@ struct FHELinalgMatmulToLinalgGeneric
     // and finally, we add the AffineDimExpr corresponding to `N` and `P`
     // which is at the last and one before last indices of `iteratorTypes`
-    for (int64_t dim = outDims - rhsDims; dim < outDims - 2; dim++) {
-      if (rhsShape[dim] == 1) {
+    int64_t rhsDim = 0;
+    for (int64_t outDim = outDims - rhsDims; outDim < outDims - 2; outDim++) {
+      if (rhsShape[rhsDim] == 1) {
         // broadcasted so current `dim` will always be indexed with `0`
         rhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
       } else {
-        rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(dim));
+        assert(rhsShape[rhsDim] == outShape[outDim]);
+        rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim));
       }
+      rhsDim++;
     }
 
     rhsAffineExpressions.push_back(
         rewriter.getAffineDimExpr(iteratorTypes.size() - 1));
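
The bug is easiest to see on the rhs with the new test case added below: rhs is tensor<1x3x2> ([batch, N, P]) and the output is tensor<2x5x4x2>, so the single rhs batch dimension is broadcast. A standalone sketch contrasting the old and new checks (again with strings in place of mlir::AffineExpr; not part of the commit):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<int64_t> rhsShape = {1, 3, 2};    // [batch..., N, P]
  std::vector<int64_t> outShape = {2, 5, 4, 2}; // [batch..., M, P]
  int64_t rhsDims = rhsShape.size();
  int64_t outDims = outShape.size();

  // old condition: indexes rhsShape with the *output* dimension index, so it
  // inspects rhsShape[1] == 3 (the contraction dimension N) instead of
  // rhsShape[0] == 1, misses the broadcast, and emits d1 instead of 0
  for (int64_t dim = outDims - rhsDims; dim < outDims - 2; dim++)
    std::cout << "old: "
              << (rhsShape[dim] == 1 ? "0" : "d" + std::to_string(dim))
              << "\n"; // prints: old: d1

  // new condition: walks rhsShape with its own local index rhsDim
  int64_t rhsDim = 0;
  for (int64_t outDim = outDims - rhsDims; outDim < outDims - 2; outDim++) {
    std::cout << "new: "
              << (rhsShape[rhsDim] == 1 ? "0" : "d" + std::to_string(outDim))
              << "\n"; // prints: new: 0
    rhsDim++;
  }
}

With the N and P expressions appended after the loop, the fixed code yields the map (0, d4, d3) that the new test checks as #m1; the old code would have produced (d1, d4, d3), indexing the size-1 batch dimension with an iterator of extent 5.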


@@ -296,6 +296,27 @@ func @main(%x: tensor<2x1x3x4x!FHE.eint<5>>, %y: tensor<5x4x2xi6>) -> tensor<2x5
+// -----
+
+// CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
+// CHECK-NEXT: #[[m1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (0, d4, d3)>
+// CHECK-NEXT: #[[m2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+
+// CHECK: func @main(%[[a0:.*]]: tensor<2x5x4x3x!FHE.eint<5>>, %[[a1:.*]]: tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>> {
+// CHECK-NEXT:   %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x5x4x2x!FHE.eint<5>>
+// CHECK-NEXT:   %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m1]], #[[m2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%[[a0]], %[[a1]] : tensor<2x5x4x3x!FHE.eint<5>>, tensor<1x3x2xi6>) outs(%[[v0]] : tensor<2x5x4x2x!FHE.eint<5>>) {
+// CHECK-NEXT:   ^bb0(%[[aa0:.*]]: !FHE.eint<5>, %[[aa1:.*]]: i6, %[[aa2:.*]]: !FHE.eint<5>):  // no predecessors
+// CHECK-NEXT:     %[[vv0:.*]] = "FHE.mul_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
+// CHECK-NEXT:     %[[vv1:.*]] = "FHE.add_eint"(%[[aa2]], %[[vv0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
+// CHECK-NEXT:     linalg.yield %[[vv1]] : !FHE.eint<5>
+// CHECK-NEXT:   } -> tensor<2x5x4x2x!FHE.eint<5>>
+// CHECK-NEXT:   return %[[v1]] : tensor<2x5x4x2x!FHE.eint<5>>
+// CHECK-NEXT: }
+
+func @main(%x: tensor<2x5x4x3x!FHE.eint<5>>, %y: tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>> {
+  %0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<2x5x4x3x!FHE.eint<5>>, tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>>
+  return %0 : tensor<2x5x4x2x!FHE.eint<5>>
+}
 
 // -----
 
 // CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
 // CHECK-NEXT: #[[m1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
 // CHECK-NEXT: #[[m2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>