Replace inline assembly in commonShflSync with intrinsics (#418)

Inline assembly does not take into account instructions around,
and in general can not avoid data hazards.
Replacing inline asm with intrinsics solves this problem.
This particular code behaved incorrectly in one of mfma dot tests:

Code generated with help of inline assembly:

```
  v_mfma_f32_4x4x4f16 v[4:7], v[4:5], v[6:7], 0
  ds_swizzle_b32 v3, v4, offset:swizzle(SWAP:4)
```

Correct code generated with intrinsics:

```
  v_mfma_f32_4x4x4f16 v[4:7], v[4:5], v[6:7], 0
  s_nop 4
  ds_swizzle_b32 v3, v4, offset:swizzle(SWAP:4)
```
This commit is contained in:
Alexander Efimov
2023-12-11 16:41:39 +01:00
committed by GitHub
parent 2be6ec771e
commit a944811b6d
2 changed files with 29 additions and 32 deletions

View File

@@ -2,6 +2,10 @@
#include "TypeConverter.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#if USE_ROCM
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#endif
namespace mlir {
namespace LLVM {
@@ -286,10 +290,20 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
#ifdef USE_ROCM
//On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on 32bit/dwords
//so we need promote to 32 here.
if (bits == 8) {
Value i32Val = sext(i32_ty, val);
Value result = commonShflSync(loc, rewriter, i32Val, i, strideInt, mode, clamp, laneId);
return trunc(i8_ty, result);
auto valType = val.getType();
if (!valType.isInteger(32) && bits <= 32) {
if (!valType.isIntOrIndex())
val = bitcast(val, int_ty(bits));
if (bits < 32)
val = sext(i32_ty, val);
val = commonShflSync(loc, rewriter, val, i, strideInt, mode, clamp, laneId);
if (bits < 32)
val = trunc(int_ty(bits), val);
if (!valType.isIntOrIndex())
val = bitcast(val, valType);
return val;
}
#endif
@@ -307,20 +321,12 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
}
#ifdef USE_ROCM
GCNBuilder builder;
auto permute = [&](Value lane, StringRef permuteInstStr) {
assert(permuteInstStr == "ds_permute_b32" ||
permuteInstStr == "ds_bpermute_b32");
auto bpermute = [&](Value lane) {
// Multiple lineId by 4. (More on permute instruction semantics:
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf#page=180
Value byteOffset = i32_val(2);
Value permuteAddr = shl(lane, byteOffset);
auto shfl = builder.create(permuteInstStr.str());
auto dOpr = builder.newOperand("=v");
auto addrOpr = builder.newOperand(permuteAddr, "v");
auto aOpr = builder.newOperand(val, "v");
(*shfl)(dOpr, addrOpr, aOpr);
return rewriter.create<ROCDL::DsBpermuteOp>(loc, valType, permuteAddr, val);
};
switch (mode) {
@@ -334,39 +340,30 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter,
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
.getResult(0);
Value stride = i32_val(32);
Value lineId = add(threadId, stride);
permute(lineId, "ds_permute_b32");
Value lineId = xor_(threadId, stride);
return bpermute(lineId);
} else {
// This map facilates the butterfly shuffle pattern for a stride less
// than 16. The pattern stride is the key of the map.
DenseMap<short, unsigned int> masks{
{16, 0x401F}, {8, 0x201F}, {4, 0x101F}, {2, 0x081F}, {1, 0x041F}};
auto shfl = builder.create("ds_swizzle_b32");
auto dOpr = builder.newOperand("=v");
auto aOpr = builder.newOperand(val, "v");
auto maskOpr =
builder.newConstantOperand("offset:" + std::to_string(masks[strideInt]));
(*shfl)(dOpr, aOpr, maskOpr);
Value offset = i32_val(masks[strideInt]);
return rewriter.create<ROCDL::DsSwizzleOp>(loc, valType, val, offset);
}
break;
case NVVM::ShflKind::up: {
Value mask = icmp_slt(laneId, i);
Value delta = sub(laneId, i);
Value index = select(mask, laneId, delta);
permute(index, "ds_bpermute_b32");
break;
return bpermute(index);
}
case NVVM::ShflKind::idx:
permute(i, "ds_bpermute_b32");
break;
return bpermute(i);
default:
assert(false && "Unsupported ShflKind");
break;
}
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
(*swait)();
return builder.launch(rewriter, loc, val.getType(), true);
return Value();
#else
Type type = val.getType();
if (type != i32_ty) {

View File

@@ -1961,11 +1961,11 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// PTX: nvvm.shfl.sync bfly
// PTX: nvvm.barrier0
// GCN-COUNT-4: ds_swizzle_b32
// GCN-COUNT-4: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
// GCN: llvm.store
// GCN: rocdl.barrier
// GCN: llvm.load
// GCN-COUNT-2: ds_swizzle_b32
// GCN-COUNT-2: rocdl.ds_swizzle %{{.*}} : (i32, i32) -> i32
// GCN: llvm.store
// GCN: rocdl.barrier
// GCN: llvm.load