mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Init values before load to avoid ptxas issues (#1396)
This commit is contained in:
@@ -122,6 +122,7 @@ struct LoadOpConversion
|
||||
const size_t width = std::min(totalWidth, maxWordWidth);
|
||||
const size_t nWords = std::max<size_t>(1, totalWidth / width);
|
||||
const size_t wordNElems = width / valueElemNbits;
|
||||
const size_t movWidth = width < 16 ? 16 : width;
|
||||
assert(wordNElems * nWords * numVecs == numElems);
|
||||
|
||||
// TODO(Superjomn) Add cache policy fields to StoreOp.
|
||||
@@ -137,11 +138,18 @@ struct LoadOpConversion
|
||||
const std::string writeConstraint =
|
||||
(width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
|
||||
PTXInstr &init =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
PTXInstr::Operand *zero = ptxBuilder.newConstantOperand(0);
|
||||
|
||||
// prepare asm operands
|
||||
auto *dstsOpr = ptxBuilder.newListOperand();
|
||||
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
|
||||
auto *opr = ptxBuilder.newOperand(writeConstraint); // =r operations
|
||||
dstsOpr->listAppend(opr);
|
||||
// Initialize the destination register, otherwise the register will
|
||||
// be undefined if the predicate is false.
|
||||
init(opr, zero);
|
||||
}
|
||||
|
||||
auto *addrOpr =
|
||||
@@ -175,7 +183,6 @@ struct LoadOpConversion
|
||||
if (other) {
|
||||
for (size_t ii = 0; ii < nWords; ++ii) {
|
||||
// PTX doesn't support mov.u8, so we need to use mov.u16
|
||||
auto movWidth = width < 16 ? 16 : width;
|
||||
PTXInstr &mov =
|
||||
ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
|
||||
// RUN: triton-opt %s -split-input-file --convert-scf-to-cf --convert-triton-gpu-to-llvm | FileCheck %s
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||
@@ -97,16 +97,24 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
|
||||
|
||||
// Load 4 elements from vector0
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
|
||||
// Load 4 elements from vector1
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
// CHECK: mov.u32 $0, 0x0
|
||||
// CHECK: @${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
|
||||
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
|
||||
|
||||
Reference in New Issue
Block a user