[BACKEND] Dedup elementwise in LLVM IR based on constancy (#2512)

### Summary

When Triton GPU IR is lowered into LLVM IR, we can make use of the
constancy information about the result of the elementwise ops to
deduplicate otherwise redundant computation. That is the contribution of
this PR: the constancy is checked and, if possible, some of the values
in LLVM IR are reused multiple times instead of computing equal values
separately.

The change is beneficial for the PyTorch 2 / TorchInductor-generated
Triton code, as the leftmost sub-indices extracted from the flat index
by div / mod operations can be equal, given sufficiently large 2^n
factor in the rightmost rightmost dimension(s). This makes the
computation resulting in those sub-indices redundant. Consequently,
under the necessary constancy conditions, the redundant indexing
arithmetics can be deduplicated. We observe up to 29% decrease in the
latency of some of our jagged tensor kernels
This commit is contained in:
Adnan Akhundov
2023-10-25 17:25:29 +02:00
committed by GitHub
parent e70e11e834
commit 7d55968fee
2 changed files with 216 additions and 23 deletions

View File

@@ -0,0 +1,72 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s
// CHECK-LABEL: dedup_by_constancy_full
// CHECK-COUNT-5: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]]
// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<256> : tensor<1024xi32, #blocked>
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
tt.return
}
}
// -----
// CHECK-LABEL: dedup_by_constancy_partial
// CHECK-COUNT-8: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK-COUNT-2: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<4> : tensor<1024xi32, #blocked>
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
%8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
%10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
%11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
%12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
tt.return
}
}