mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: use proper broadcasting condition during matmul to linalg generic
This commit is contained in:
@@ -924,13 +924,16 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
// and finally, we add the AffineDimExpr corresponding to `N`
|
||||
// which is at the last index of `iteratorTypes`
|
||||
|
||||
for (int64_t dim = outDims - lhsDims; dim < outDims - 1; dim++) {
|
||||
if (lhsShape[dim] == 1) {
|
||||
int64_t lhsDim = 0;
|
||||
for (int64_t outDim = outDims - lhsDims; outDim < outDims - 1; outDim++) {
|
||||
if (lhsDim < lhsDims - 2 && lhsShape[lhsDim] == 1) {
|
||||
// broadcasted, so this dimension will always be indexed with `0`
|
||||
lhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
|
||||
} else {
|
||||
lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(dim));
|
||||
assert(lhsShape[lhsDim] == outShape[outDim]);
|
||||
lhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim));
|
||||
}
|
||||
lhsDim++;
|
||||
}
|
||||
lhsAffineExpressions.push_back(
|
||||
rewriter.getAffineDimExpr(iteratorTypes.size() - 1));
|
||||
@@ -965,13 +968,16 @@ struct FHELinalgMatmulToLinalgGeneric
|
||||
// and finally, we add the AffineDimExpr corresponding to `N` and `P`
|
||||
// which are at the last and one before last indices of `iteratorTypes`
|
||||
|
||||
for (int64_t dim = outDims - rhsDims; dim < outDims - 2; dim++) {
|
||||
if (rhsShape[dim] == 1) {
|
||||
int64_t rhsDim = 0;
|
||||
for (int64_t outDim = outDims - rhsDims; outDim < outDims - 2; outDim++) {
|
||||
if (rhsShape[rhsDim] == 1) {
|
||||
// broadcasted, so this dimension will always be indexed with `0`
|
||||
rhsAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
|
||||
} else {
|
||||
rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(dim));
|
||||
assert(rhsShape[rhsDim] == outShape[outDim]);
|
||||
rhsAffineExpressions.push_back(rewriter.getAffineDimExpr(outDim));
|
||||
}
|
||||
rhsDim++;
|
||||
}
|
||||
rhsAffineExpressions.push_back(
|
||||
rewriter.getAffineDimExpr(iteratorTypes.size() - 1));
|
||||
|
||||
@@ -296,6 +296,27 @@ func @main(%x: tensor<2x1x3x4x!FHE.eint<5>>, %y: tensor<5x4x2xi6>) -> tensor<2x5
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4)>
|
||||
// CHECK-NEXT: #[[m1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (0, d4, d3)>
|
||||
// CHECK-NEXT: #[[m2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
|
||||
|
||||
// CHECK: func @main(%[[a0:.*]]: tensor<2x5x4x3x!FHE.eint<5>>, %[[a1:.*]]: tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<2x5x4x2x!FHE.eint<5>>
|
||||
// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[m0]], #[[m1]], #[[m2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%[[a0]], %[[a1]] : tensor<2x5x4x3x!FHE.eint<5>>, tensor<1x3x2xi6>) outs(%[[v0]] : tensor<2x5x4x2x!FHE.eint<5>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<5>, %[[aa1:.*]]: i6, %[[aa2:.*]]: !FHE.eint<5>): // no predecessors
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.mul_eint_int"(%[[aa0]], %[[aa1]]) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
|
||||
// CHECK-NEXT: %[[vv1:.*]] = "FHE.add_eint"(%[[aa2]], %[[vv0]]) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
// CHECK-NEXT: linalg.yield %[[vv1]] : !FHE.eint<5>
|
||||
// CHECK-NEXT: } -> tensor<2x5x4x2x!FHE.eint<5>>
|
||||
// CHECK-NEXT: return %[[v1]] : tensor<2x5x4x2x!FHE.eint<5>>
|
||||
// CHECK-NEXT: }
|
||||
// Test input: matmul of a 4-D encrypted tensor (2x5x4x3) with a 3-D clear
// tensor (1x3x2) whose leading dimension has size 1; the declared result
// type is 2x5x4x2.
func @main(%x: tensor<2x5x4x3x!FHE.eint<5>>, %y: tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>> {
|
||||
%0 = "FHELinalg.matmul_eint_int"(%x, %y): (tensor<2x5x4x3x!FHE.eint<5>>, tensor<1x3x2xi6>) -> tensor<2x5x4x2x!FHE.eint<5>>
|
||||
return %0 : tensor<2x5x4x2x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[m0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-NEXT: #[[m1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
// CHECK-NEXT: #[[m2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
Reference in New Issue
Block a user