diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 0467814e4..98d433079 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -112,7 +112,8 @@ private: /// The _divisibility_ information maps the `d`-th /// dimension to the largest power-of-two that - /// divides the first element of all the values along it + /// divides the first element of all groups of + // _contiguity_ values along it /// For example: /// [10, 11, 12, 13, 18, 19, 20, 21] /// [20, 21, 22, 23, 28, 29, 30, 31] @@ -123,6 +124,10 @@ private: /// [14, 18, 22, 26] /// [15, 19, 23, 27] // would have divisibility [4, 1] + // On the other hand: + // [0, 1, 2, 0, 4, 5, 6, 7] + // would have divisibility 1 because + // _contiguity_=1 DimVectorT divisibility; /// The _constancy_ information maps the `d`-th diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index c575c5ceb..dba2101ee 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -334,14 +334,11 @@ private: if (lhs.getConstantValue().has_value() && lhs.getConstantValue().value() == 0) return lhs.getDivisibility(dim); - // Case 2: rhs is constant - if (rhs.getConstantValue().has_value()) { - auto lhsDivisibility = lhs.getDivisibility(dim); - auto rhsValue = rhs.getConstantValue().value(); - if (lhsDivisibility % rhsValue == 0) - return lhsDivisibility / rhsValue; - } - // Case 3: both are not constant + // Case 2: rhs is 1 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return lhs.getDivisibility(dim); + // otherwise: return 1 return 1; } diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index af8ea6f85..5d9f6e338 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -82,7 +82,7 @@ func.func @div() { %3 = arith.divui %1, %0 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %4 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = %5 = arith.divsi %0, %4 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %6 = arith.divsi %4, %0 : tensor<128xi32> @@ -94,11 +94,12 @@ func.func @div() { %9 = arith.divui %0, %8 : tensor<128xi32> // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = %10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = %11 = arith.divsi %10, %4 : tensor<128xi32> return } + // ----- // CHECK-LABEL: @rem @@ -179,11 +180,11 @@ func.func @logic() { %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %1 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = %2 = arith.divsi %0, %1 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 %3 = arith.constant dense<8> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = %4 = arith.divsi %0, %3 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %5 = arith.andi %0, %1 : tensor<128xi32>