mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Dedup elementwise in LLVM IR based on constancy (#2512)
### Summary When Triton GPU IR is lowered into LLVM IR, we can make use of the constancy information about the result of the elementwise ops to deduplicate otherwise redundant computation. That is the contribution of this PR: the constancy is checked and, if possible, some of the values in LLVM IR are reused multiple times instead of computing equal values separately. The change is beneficial for the PyTorch 2 / TorchInductor-generated Triton code, as the leftmost sub-indices extracted from the flat index by div / mod operations can be equal, given a sufficiently large 2^n factor in the rightmost dimension(s). This makes the computation resulting in those sub-indices redundant. Consequently, under the necessary constancy conditions, the redundant indexing arithmetic can be deduplicated. We observe up to a 29% decrease in the latency of some of our jagged tensor kernels.
This commit is contained in:
72
test/Conversion/dedup-by-constancy.mlir
Normal file
72
test/Conversion/dedup-by-constancy.mlir
Normal file
@@ -0,0 +1,72 @@
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --llvm-optimize-for-nvvm-target | FileCheck %s

// Fully-deduplicable case: divsi by 256 with sizePerThread = 8 makes every
// per-thread element of %7 equal, so the lowering should emit the address
// arithmetic once and reuse a single GEP register for all 8 loads.
// CHECK-LABEL: dedup_by_constancy_full
// CHECK-COUNT-5: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER:%[0-9]+]]]
// CHECK-COUNT-7: llvm.getelementptr %arg0[[[REGISTER]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @dedup_by_constancy_full(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<256> : tensor<1024xi32, #blocked>
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
    %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
    %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
    %8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
    %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
    %10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
    %11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
    %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
    tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
    tt.return
  }
}
|
||||
|
||||
// -----

// Partially-deduplicable case: divsi by 4 with sizePerThread = 8 makes the
// per-thread elements of %7 equal only in groups of 4, so two sdiv results
// (and two distinct GEP registers, each reused 4 times) are expected.
// CHECK-LABEL: dedup_by_constancy_partial
// CHECK-COUNT-8: llvm.add
// CHECK-NOT: llvm.add
// CHECK: llvm.icmp "slt"
// CHECK-NOT: llvm.icmp "slt"
// CHECK-COUNT-2: llvm.sdiv
// CHECK-NOT: llvm.sdiv
// CHECK: llvm.getelementptr %arg0[[[REGISTER1:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER1]]]
// CHECK: llvm.getelementptr %arg0[[[REGISTER2:%[0-9]+]]]
// CHECK-COUNT-3: llvm.getelementptr %arg0[[[REGISTER2]]]
// CHECK-NOT: llvm.getelementptr %arg0[[[REGISTER2]]]
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @dedup_by_constancy_partial(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<4> : tensor<1024xi32, #blocked>
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
    %5 = tt.splat %arg2 : (i32) -> tensor<1024xi32, #blocked>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
    %7 = arith.divsi %4, %cst : tensor<1024xi32, #blocked>
    %8 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
    %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
    %10 = tt.load %9, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked>
    %11 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<1024x!tt.ptr<f16, 1>, #blocked>
    %12 = tt.addptr %11, %4 : tensor<1024x!tt.ptr<f16, 1>, #blocked>, tensor<1024xi32, #blocked>
    tt.store %12, %10, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked>
    tt.return
  }
}
|
||||
Reference in New Issue
Block a user