[ANALYSIS] Fix divisibility calculation for addptr (#1453)

This commit is contained in:
Keren Zhou
2023-03-31 17:57:31 -07:00
committed by GitHub
parent 859952a0aa
commit 801bb9d3b5
2 changed files with 71 additions and 5 deletions

View File

@@ -220,7 +220,16 @@ private:
// rhs = p * d_rhs = p * p' * gcd(d_lhs, d_rhs)
// lhs + rhs = k * d_lhs + p * d_rhs = (k * k' + p * p') *
// gcd(d_lhs, d_rhs)
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim));
auto elemSize = 1;
if constexpr (std::is_same_v<OpTy, triton::AddPtrOp>) {
// %ptr = addptr %lhs, %rhs
// is equivalent to
// %0 = mul %rhs, %elemSize
// %ptr = add %lhs, %0
elemSize = std::max<unsigned int>(
1, triton::getPointeeBitWidth(op.getPtr().getType()) / 8);
}
return gcd(lhs.getDivisibility(dim), rhs.getDivisibility(dim) * elemSize);
}
int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,

View File

@@ -32,6 +32,63 @@ func.func @add() {
// -----
// CHECK-LABEL: @addptr
// Axis-info expectations for tt.addptr over pointee types i1, i8, i16, i32
// and i64. Every pointer argument is declared with tt.divisibility = 16.
// The expected result divisibility below is consistent with
//   gcd(pointer divisibility, offset divisibility * element size in bytes),
// with the sub-byte i1 element counted as 1 byte.
func.func @addptr(%arg0: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
// Scalar offset with constant value 1 (divisibility 1): the expected result
// divisibility is 1, 1, 2, 4, 8 for i1, i8, i16, i32, i64 respectively.
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst1 = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%0 = tt.addptr %arg0, %cst1 : !tt.ptr<i1>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%1 = tt.addptr %arg1, %cst1 : !tt.ptr<i8>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = <none>
%2 = tt.addptr %arg2, %cst1 : !tt.ptr<i16>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%3 = tt.addptr %arg3, %cst1 : !tt.ptr<i32>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
%4 = tt.addptr %arg4, %cst1 : !tt.ptr<i64>, i32
// Scalar offset with constant value 4 (divisibility 4): the expected result
// divisibility is 4, 4, 8, 16, 16. The i64 case stays at 16, i.e. it does not
// exceed the divisibility declared on the pointer argument.
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4
%cst4 = arith.constant 4 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%5 = tt.addptr %arg0, %cst4 : !tt.ptr<i1>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%6 = tt.addptr %arg1, %cst4 : !tt.ptr<i8>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
%7 = tt.addptr %arg2, %cst4 : !tt.ptr<i16>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
%8 = tt.addptr %arg3, %cst4 : !tt.ptr<i32>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = <none>
%9 = tt.addptr %arg4, %cst4 : !tt.ptr<i64>, i32
// Tensor case: build a 128x128 i32 offset tensor from the range [0, 128),
// laid out along dimension 1 and broadcast along dimension 0.
// CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%10 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
%11 = tt.expand_dims %10 {axis = 0: i32} : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
%12 = tt.broadcast %11 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// Splat each scalar pointer into a 128x128 tensor of pointers; the splat
// keeps divisibility 16 in both dimensions.
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%13 = tt.splat %arg0 : (!tt.ptr<i1>) -> tensor<128x128x!tt.ptr<i1>>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%14 = tt.splat %arg1 : (!tt.ptr<i8>) -> tensor<128x128x!tt.ptr<i8>>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%15 = tt.splat %arg2 : (!tt.ptr<i16>) -> tensor<128x128x!tt.ptr<i16>>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%16 = tt.splat %arg3 : (!tt.ptr<i32>) -> tensor<128x128x!tt.ptr<i32>>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = <none>
%17 = tt.splat %arg4 : (!tt.ptr<i64>) -> tensor<128x128x!tt.ptr<i64>>
// Pointer tensor plus offset tensor: the expected divisibility is
// [1, 16], [1, 16], [2, 16], [4, 16], [8, 16] for i1, i8, i16, i32, i64.
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = <none>
%18 = tt.addptr %13, %12 : tensor<128x128x!tt.ptr<i1>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [128, 1], constant_value = <none>
%19 = tt.addptr %14, %12 : tensor<128x128x!tt.ptr<i8>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [2, 16], constancy = [128, 1], constant_value = <none>
%20 = tt.addptr %15, %12 : tensor<128x128x!tt.ptr<i16>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [4, 16], constancy = [128, 1], constant_value = <none>
%21 = tt.addptr %16, %12 : tensor<128x128x!tt.ptr<i32>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [8, 16], constancy = [128, 1], constant_value = <none>
%22 = tt.addptr %17, %12 : tensor<128x128x!tt.ptr<i64>>, tensor<128x128xi32>
return
}
// -----
// CHECK-LABEL: @sub
func.func @sub() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
@@ -317,13 +374,13 @@ func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i
%8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = <none>
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = <none>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [4, 16], constancy = [1, 1], constant_value = <none>
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = <none>
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = <none>
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>, tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = <none>
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
@@ -331,11 +388,11 @@ func.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
%16 = arith.muli %14, %15 : tensor<1x128xi32>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = <none>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = <none>
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = <none>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = <none>
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = <none>
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>