mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Replace inline assembly in commonShflSync with intrinsics (#418)
Inline assembly does not take the surrounding instructions into account and, in general, cannot avoid data hazards. Replacing the inline asm with intrinsics solves this problem. This particular code behaved incorrectly in one of the mfma dot tests.
Code generated with inline assembly:
```
v_mfma_f32_4x4x4f16 v[4:7], v[4:5], v[6:7], 0
ds_swizzle_b32 v3, v4, offset:swizzle(SWAP:4)
```
Correct code generated with intrinsics (note the `s_nop 4` inserted between the two instructions):
```
v_mfma_f32_4x4x4f16 v[4:7], v[4:5], v[6:7], 0
s_nop 4
ds_swizzle_b32 v3, v4, offset:swizzle(SWAP:4)
```
This commit is contained in:
@@ -2,6 +2,10 @@
|
||||
#include "TypeConverter.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "triton/Dialect/NVGPU/IR/Dialect.h"
|
||||
|
||||
#if USE_ROCM
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
#endif
|
||||
namespace mlir {
|
||||
|
||||
namespace LLVM {
|
||||
@@ -286,10 +290,20 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
#ifdef USE_ROCM
|
||||
//On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32bit/dwords
|
||||
//so we need to promote to 32 bits here.
|
||||
if (bits == 8) {
|
||||
Value i32Val = sext(i32_ty, val);
|
||||
Value result = commonShflSync(loc, rewriter, i32Val, i, strideInt, mode, clamp, laneId);
|
||||
return trunc(i8_ty, result);
|
||||
auto valType = val.getType();
|
||||
if (!valType.isInteger(32) && bits <= 32) {
|
||||
if (!valType.isIntOrIndex())
|
||||
val = bitcast(val, int_ty(bits));
|
||||
if (bits < 32)
|
||||
val = sext(i32_ty, val);
|
||||
|
||||
val = commonShflSync(loc, rewriter, val, i, strideInt, mode, clamp, laneId);
|
||||
|
||||
if (bits < 32)
|
||||
val = trunc(int_ty(bits), val);
|
||||
if (!valType.isIntOrIndex())
|
||||
val = bitcast(val, valType);
|
||||
return val;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -307,20 +321,12 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
GCNBuilder builder;
|
||||
|
||||
auto permute = [&](Value lane, StringRef permuteInstStr) {
|
||||
assert(permuteInstStr == "ds_permute_b32" ||
|
||||
permuteInstStr == "ds_bpermute_b32");
|
||||
auto bpermute = [&](Value lane) {
|
||||
// Multiply lineId by 4. (More on permute instruction semantics:
|
||||
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180
|
||||
Value byteOffset = i32_val(2);
|
||||
Value permuteAddr = shl(lane, byteOffset);
|
||||
auto shfl = builder.create(permuteInstStr.str());
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto addrOpr = builder.newOperand(permuteAddr, "v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
(*shfl)(dOpr, addrOpr, aOpr);
|
||||
return rewriter.create<ROCDL::DsBpermuteOp>(loc, valType, permuteAddr, val);
|
||||
};
|
||||
|
||||
switch (mode) {
|
||||
@@ -334,39 +340,30 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
|
||||
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
|
||||
.getResult(0);
|
||||
Value stride = i32_val(32);
|
||||
Value lineId = add(threadId, stride);
|
||||
permute(lineId, "ds_permute_b32");
|
||||
Value lineId = xor_(threadId, stride);
|
||||
return bpermute(lineId);
|
||||
} else {
|
||||
// This map facilitates the butterfly shuffle pattern for a stride less
|
||||
// than 16. The pattern stride is the key of the map.
|
||||
DenseMap<short, unsigned int> masks{
|
||||
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
|
||||
auto shfl = builder.create("ds_swizzle_b32");
|
||||
auto dOpr = builder.newOperand("=v");
|
||||
auto aOpr = builder.newOperand(val, "v");
|
||||
auto maskOpr =
|
||||
builder.newConstantOperand("offset:" + std::to_string(masks[strideInt]));
|
||||
(*shfl)(dOpr, aOpr, maskOpr);
|
||||
Value offset = i32_val(masks[strideInt]);
|
||||
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
|
||||
}
|
||||
break;
|
||||
case NVVM::ShflKind::up: {
|
||||
Value mask = icmp_slt(laneId, i);
|
||||
Value delta = sub(laneId, i);
|
||||
Value index = select(mask, laneId, delta);
|
||||
permute(index, "ds_bpermute_b32");
|
||||
break;
|
||||
return bpermute(index);
|
||||
}
|
||||
case NVVM::ShflKind::idx:
|
||||
permute(i, "ds_bpermute_b32");
|
||||
break;
|
||||
return bpermute(i);
|
||||
default:
|
||||
assert(false && "Unsupported ShflKind");
|
||||
break;
|
||||
}
|
||||
|
||||
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
|
||||
(*swait)();
|
||||
return builder.launch(rewriter, loc, val.getType(), true);
|
||||
return Value();
|
||||
#else
|
||||
Type type = val.getType();
|
||||
if (type != i32_ty) {
|
||||
|
||||
@@ -1961,11 +1961,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
|
||||
// PTX: nvvm.shfl.sync bfly
|
||||
// PTX: nvvm.barrier0
|
||||
|
||||
// GCN-COUNT-4: ds_swizzle_b32
|
||||
// GCN-COUNT-4: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
|
||||
// GCN: llvm.store
|
||||
// GCN: rocdl.barrier
|
||||
// GCN: llvm.load
|
||||
// GCN-COUNT-2: ds_swizzle_b32
|
||||
// GCN-COUNT-2: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
|
||||
// GCN: llvm.store
|
||||
// GCN: rocdl.barrier
|
||||
// GCN: llvm.load
|
||||
|
||||
Reference in New Issue
Block a user