[OPTIMIZER] Calculate a proper divisibility for ExpandDims (#2397)

Previously ExpandDims always inserted 1 as the new divisibility, which
made writing (x * stride)[:, None] far slower than (x[:, None] *
stride). A better divisibility can be obtained by computing the GCD of
the old dims. Now the two forms above are equally fast. E.g. the conv
inductor in PyTorch may become faster.

---------

Co-authored-by: Yuheng XIE <thinelephant@gmail.com>
This commit is contained in:
Yuheng XIE
2023-09-28 14:10:01 +08:00
committed by GitHub
parent 9073a393e0
commit 1e093fbfff
2 changed files with 67 additions and 20 deletions

View File

@@ -184,13 +184,28 @@ tt.func @rem() {
// -----
// CHECK-LABEL: @expanddims
tt.func @expanddims() {
// CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = <none>
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2
%1 = arith.constant dense<2> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = <none>
%2 = arith.muli %0, %1 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [2, 2], constancy = [1, 1], constant_value = <none>
%3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
tt.return
}
// -----
// CHECK-LABEL: @broadcast
tt.func @broadcast() {
// CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%0 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 1], constant_value = 64
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64
// CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 128], constant_value = 64
%2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32>
tt.return
}
@@ -290,9 +305,9 @@ tt.func @shift() {
%1 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4
%2 = arith.constant dense<4> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [1], constant_value = <none>
%3 = arith.shli %0, %1 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%4 = arith.shrsi %0, %2 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
%5 = arith.shli %1, %2 : tensor<128xi32>
@@ -362,7 +377,7 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
%2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = <none>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
%4 = arith.muli %2, %3 : tensor<128x1xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
@@ -386,11 +401,11 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = <none>
%15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = <none>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = <none>
%16 = arith.muli %14, %15 : tensor<1x128xi32>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = <none>
%17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr<f32>>) -> tensor<128x128x!tt.ptr<f32>>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = <none>
// CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = <none>
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
// CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = <none>
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>