mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ANALYSIS] Fix divisibility calculation for addptr (#1453)
This commit is contained in:
@@ -220,7 +220,16 @@ private:
|
||||
// rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
|
||||
// lhs + rhs = k * d_lhs + p * d_rhs = (k * d_lhs + p * d_rhs) *
|
||||
// gcd(d_lhs, d_rhs)
|
||||
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
|
||||
auto elemSize = 1;
|
||||
if constexpr (std::is_same_v<OpTy, triton::AddPtrOp>) {
|
||||
// %ptr = addptr %lhs, %rhs
|
||||
// is equivalent to
|
||||
// %0 = mul %lhs, %elemSize
|
||||
// %ptr = add %0, %rhs
|
||||
elemSize = std::max<unsigned int>(
|
||||
1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8);
|
||||
}
|
||||
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim) * elemSize);
|
||||
}
|
||||
|
||||
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
|
||||
|
||||
@@ -32,6 +32,63 @@ func.func @add() {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @addptr
|
||||
func.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%cst1 = arith.constant 1 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%0 = tt.addptr %arg0, %cst1 : !tt.ptr<i1>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
|
||||
%1 = tt.addptr %arg1, %cst1 : !tt.ptr<i8>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = <none>
|
||||
%2 = tt.addptr %arg2, %cst1 : !tt.ptr<i16>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%3 = tt.addptr %arg3, %cst1 : !tt.ptr<i32>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
|
||||
%4 = tt.addptr %arg4, %cst1 : !tt.ptr<i64>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4
|
||||
%cst4 = arith.constant 4 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%5 = tt.addptr %arg0, %cst4 : !tt.ptr<i1>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%6 = tt.addptr %arg1, %cst4 : !tt.ptr<i8>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
|
||||
%7 = tt.addptr %arg2, %cst4 : !tt.ptr<i16>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
|
||||
%8 = tt.addptr %arg3, %cst4 : !tt.ptr<i32>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
|
||||
%9 = tt.addptr %arg4, %cst4 : !tt.ptr<i64>, i32
|
||||
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
%10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
|
||||
%11 = tt.expand_dims %10 {axis = 0: i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
|
||||
%12 = tt.broadcast %11 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%13 = tt.splat %arg0 : (!tt.ptr<i1>) -> tensor<128x128x!tt.ptr<i1>>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%14 = tt.splat %arg1 : (!tt.ptr<i8>) -> tensor<128x128x!tt.ptr<i8>>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%15 = tt.splat %arg2 : (!tt.ptr<i16>) -> tensor<128x128x!tt.ptr<i16>>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%16 = tt.splat %arg3 : (!tt.ptr<i32>) -> tensor<128x128x!tt.ptr<i32>>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
|
||||
%17 = tt.splat %arg4 : (!tt.ptr<i64>) -> tensor<128x128x!tt.ptr<i64>>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = <none>
|
||||
%18 = tt.addptr %13, %12 : tensor<128x128x!tt.ptr<i1>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = <none>
|
||||
%19 = tt.addptr %14, %12 : tensor<128x128x!tt.ptr<i8>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [2, 16], constancy = [128, 1], constant_value = <none>
|
||||
%20 = tt.addptr %15, %12 : tensor<128x128x!tt.ptr<i16>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [4, 16], constancy = [128, 1], constant_value = <none>
|
||||
%21 = tt.addptr %16, %12 : tensor<128x128x!tt.ptr<i32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [8, 16], constancy = [128, 1], constant_value = <none>
|
||||
%22 = tt.addptr %17, %12 : tensor<128x128x!tt.ptr<i64>>, tensor<128x128xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @sub
|
||||
func.func @sub() {
|
||||
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
|
||||
@@ -317,13 +374,13 @@ func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i
|
||||
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [4, 16], constancy = [1, 1], constant_value = <none>
|
||||
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
|
||||
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = <none>
|
||||
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
|
||||
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
@@ -331,11 +388,11 @@ func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i
|
||||
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
|
||||
%16 = arith.muli %14, %15 : tensor<1x128xi32>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = <none>
|
||||
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
|
||||
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = <none>
|
||||
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
|
||||
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
|
||||
Reference in New Issue
Block a user