mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Now that the attribute-related bug is fixed in MLIR, we can use arith ops for the cmp and select ops.
32 lines
2.3 KiB
MLIR
32 lines
2.3 KiB
MLIR
// RUN: triton-translate %s --mlir-print-ir-after-all -o %t 2>&1 | FileCheck %s

// Lit test: run the Triton translation pipeline with per-pass IR dumps and
// verify that the kernel function is still present (not dropped or renamed)
// both after the SCF-to-ControlFlow lowering and after the Index-to-LLVM
// conversion.
// CHECK: IR Dump After SCFToControlFlow (convert-scf-to-cf)
// CHECK: tt.func public @add_kernel_0d1d2d3de
// CHECK: IR Dump After ConvertIndexToLLVMPass (convert-index-to-llvm)
// CHECK: tt.func public @add_kernel_0d1d2d3de

// 1-D blocked layout: 4 elements per thread, 32 threads per warp,
// 4 warps per CTA, single CTA per CGA.
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>

module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {

// Element-wise vector addition kernel: each program instance handles a
// 1024-element tile.  %arg0/%arg1 are the input pointers, %arg2 the output
// pointer, %arg3 the total element count used to build the bounds mask.
tt.func public @add_kernel_0d1d2d3de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
// Tile base offset = program_id * 1024.
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked>
// Per-element global indices for this tile.
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked>
// Mask off lanes whose index is >= the element count %arg3.
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
// Masked load from the first input.
%7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
// Masked load from the second input.
%10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
%12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked>
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked>
// Masked store of the sum to the output.
%14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<1024x!tt.ptr<f32, 1>, #blocked>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32, 1>, #blocked>, tensor<1024xi32, #blocked>
tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked>
tt.return
}
}