mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[OPTIMIZER] Fixed up divisibility analysis in div operation (#1341)
This commit is contained in:
@@ -112,7 +112,8 @@ private:
|
||||
|
||||
/// The _divisibility_ information maps the `d`-th
|
||||
/// dimension to the largest power-of-two that
|
||||
/// divides the first element of all the values along it
|
||||
/// divides the first element of all groups of
|
||||
// _contiguity_ values along it
|
||||
/// For example:
|
||||
/// [10, 11, 12, 13, 18, 19, 20, 21]
|
||||
/// [20, 21, 22, 23, 28, 29, 30, 31]
|
||||
@@ -123,6 +124,10 @@ private:
|
||||
/// [14, 18, 22, 26]
|
||||
/// [15, 19, 23, 27]
|
||||
// would have divisibility [4, 1]
|
||||
// On the other hand:
|
||||
// [0, 1, 2, 0, 4, 5, 6, 7]
|
||||
// would have divisibility 1 because
|
||||
// _contiguity_=1
|
||||
DimVectorT divisibility;
|
||||
|
||||
/// The _constancy_ information maps the `d`-th
|
||||
|
||||
@@ -334,14 +334,11 @@ private:
|
||||
if (lhs.getConstantValue().has_value() &&
|
||||
lhs.getConstantValue().value() == 0)
|
||||
return lhs.getDivisibility(dim);
|
||||
// Case 2: rhs is constant
|
||||
if (rhs.getConstantValue().has_value()) {
|
||||
auto lhsDivisibility = lhs.getDivisibility(dim);
|
||||
auto rhsValue = rhs.getConstantValue().value();
|
||||
if (lhsDivisibility % rhsValue == 0)
|
||||
return lhsDivisibility / rhsValue;
|
||||
}
|
||||
// Case 3: both are not constant
|
||||
// Case 2: rhs is 1
|
||||
if (rhs.getConstantValue().has_value() &&
|
||||
rhs.getConstantValue().value() == 1)
|
||||
return lhs.getDivisibility(dim);
|
||||
// otherwise: return 1
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ func.func @div() {
|
||||
%3 = arith.divui %1, %0 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%4 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%5 = arith.divsi %0, %4 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%6 = arith.divsi %4, %0 : tensor<128xi32>
|
||||
@@ -94,11 +94,12 @@ func.func @div() {
|
||||
%9 = arith.divui %0, %8 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
|
||||
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%11 = arith.divsi %10, %4 : tensor<128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @rem
|
||||
@@ -179,11 +180,11 @@ func.func @logic() {
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
|
||||
%1 = arith.constant dense<64> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
|
||||
%2 = arith.divsi %0, %1 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
|
||||
%3 = arith.constant dense<8> : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
|
||||
%4 = arith.divsi %0, %3 : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%5 = arith.andi %0, %1 : tensor<128xi32>
|
||||
|
||||
Reference in New Issue
Block a user