[OPTIMIZER] Fixed up divisibility analysis in div operation (#1341)

This commit is contained in:
Philippe Tillet
2023-03-14 18:17:05 -07:00
committed by GitHub
parent da0b0bfde6
commit 082828af47
3 changed files with 16 additions and 13 deletions

View File

@@ -112,7 +112,8 @@ private:
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// divides the first element of all groups of
/// _contiguity_ values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
@@ -123,6 +124,10 @@ private:
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
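/// would have divisibility [4, 1]
/// On the other hand:
/// [0, 1, 2, 0, 4, 5, 6, 7]
/// would have divisibility 1 because
/// _contiguity_=1: every element is then the first
/// element of its own group, and odd elements such
/// as 1 are only divisible by 1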
DimVectorT divisibility;
/// The _constancy_ information maps the `d`-th

View File

@@ -334,14 +334,11 @@ private:
if (lhs.getConstantValue().has_value() &&
lhs.getConstantValue().value() == 0)
return lhs.getDivisibility(dim);
// Case 2: rhs is constant
if (rhs.getConstantValue().has_value()) {
auto lhsDivisibility = lhs.getDivisibility(dim);
auto rhsValue = rhs.getConstantValue().value();
if (lhsDivisibility % rhsValue == 0)
return lhsDivisibility / rhsValue;
}
// Case 3: both are not constant
// Case 2: rhs is 1
if (rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 1)
return lhs.getDivisibility(dim);
// otherwise: return 1
return 1;
}

View File

@@ -82,7 +82,7 @@ func.func @div() {
%3 = arith.divui %1, %0 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%4 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%5 = arith.divsi %0, %4 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%6 = arith.divsi %4, %0 : tensor<128xi32>
@@ -94,11 +94,12 @@ func.func @div() {
%9 = arith.divui %0, %8 : tensor<128xi32>
// CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%11 = arith.divsi %10, %4 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: @rem
@@ -179,11 +180,11 @@ func.func @logic() {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%1 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%2 = arith.divsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
%3 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%4 = arith.divsi %0, %3 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%5 = arith.andi %0, %1 : tensor<128xi32>