[OPTIMIZER] Calculate a proper divisibility for ExpandDims (#2397)
Previously, ExpandDims always inserted 1 as the divisibility of the new dimension, which made writing (x * stride)[:, None] far slower than (x[:, None] * stride). A better divisibility can be obtained by computing the GCD of the divisibilities of the existing dimensions. With this change the two forms above are equally fast; for example, conv kernels generated by the PyTorch inductor may get faster. --------- Co-authored-by: Yuheng XIE <thinelephant@gmail.com>
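Below is a minimal Python sketch of the new rule, assuming divisibility is tracked as one per-dimension integer bound; the function name is hypothetical, and the actual change lives in the C++ AxisInfo analysis:

from functools import reduce
from math import gcd

def expand_dims_divisibility(divisibility, axis):
    # Old behavior: the inserted size-1 axis always got divisibility 1.
    # New behavior (this commit): it gets the GCD of the divisibilities
    # of the existing axes, a valid and usually much tighter bound.
    new_axis_div = reduce(gcd, divisibility)
    return divisibility[:axis] + [new_axis_div] + divisibility[axis:]

# Matches the @expanddims test below: [2] -> [2, 2] for axis = 1,
# and the @broadcast test: [64] -> [64, 64].
assert expand_dims_divisibility([2], 1) == [2, 2]
assert expand_dims_divisibility([64], 1) == [64, 64]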
@@ -184,13 +184,28 @@ tt.func @rem() {
+// -----
+
+// CHECK-LABEL: @expanddims
+tt.func @expanddims() {
+  // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
+  %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2
+  %1 = arith.constant dense<2> : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = <none>
+  %2 = arith.muli %0, %1 : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [2, 2], constancy = [1, 1], constant_value = <none>
+  %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
+  tt.return
+}
+
 // -----
 
 // CHECK-LABEL: @broadcast
 tt.func @broadcast() {
   // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
   %0 = arith.constant dense<64> : tensor<128xi32>
-  // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 1], constant_value = 64
   %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
-  // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 128], constant_value = 64
   %2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
   tt.return
 }
@@ -290,9 +305,9 @@ tt.func @shift() {
   %1 = arith.constant dense<8> : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
   %2 = arith.constant dense<4> : tensor<128xi32>
-  // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = <none>
+  // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [1], constant_value = <none>
   %3 = arith.shli %0, %1 : tensor<128xi32>
-  // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = <none>
+  // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
   %4 = arith.shrsi %0, %2 : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
   %5 = arith.shli %1, %2 : tensor<128xi32>
@@ -362,7 +377,7 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
   %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
   // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
   %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
-  // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = <none>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
   %4 = arith.muli %2, %3 : tensor<128x1xi32>
   // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
   %5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
@@ -386,11 +401,11 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
   %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
   // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
   %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
-  // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
   %16 = arith.muli %14, %15 : tensor<1x128xi32>
   // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = <none>
   %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
-  // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
+  // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
   %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
   // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = <none>
   %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>