mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[HOPPER][WS] fix missing WS attrs when lowering to llvm (#2063)
This commit is contained in:
@@ -138,8 +138,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
barSync(*builder, op, barId, 128);
|
||||
} else {
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
blockInfo->sync();
|
||||
}
|
||||
blockInfo->sync();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -203,8 +203,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
|
||||
barSync(*builder, op, barId, 128);
|
||||
} else {
|
||||
builder->create<gpu::BarrierOp>(op->getLoc());
|
||||
blockInfo->sync();
|
||||
}
|
||||
blockInfo->sync();
|
||||
}
|
||||
// Update the region info, even if barrier is inserted, we have to maintain
|
||||
// the current op's read/write buffers.
|
||||
|
||||
@@ -49,6 +49,14 @@ namespace ttng = mlir::triton::nvidia_gpu;
|
||||
|
||||
namespace {
|
||||
|
||||
// pass ws related named attrs.
|
||||
static void addWSNamedAttrs(Operation *op,
|
||||
ArrayRef<mlir::NamedAttribute> attrs) {
|
||||
for (const NamedAttribute attr : attrs)
|
||||
if (attr.getName() == "async_agent" || attr.getName() == "agent.mutex_role")
|
||||
op->setAttr(attr.getName(), attr.getValue());
|
||||
}
|
||||
|
||||
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM)
|
||||
@@ -617,10 +625,13 @@ private:
|
||||
auto newCvtType = RankedTensorType::get(shape, F16Ty, cvtEncoding);
|
||||
auto newArg = builder.create<mlir::triton::FpToFpOp>(
|
||||
cvtOp.getLoc(), newArgType, cvtOp.getOperand());
|
||||
addWSNamedAttrs(newArg, cvtOp->getAttrs());
|
||||
auto newCvt = builder.create<mlir::triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), newCvtType, newArg);
|
||||
addWSNamedAttrs(newCvt, cvtOp->getAttrs());
|
||||
auto newRet = builder.create<mlir::triton::FpToFpOp>(
|
||||
cvtOp.getLoc(), cvtOp.getType(), newCvt.getResult());
|
||||
addWSNamedAttrs(newRet, cvtOp->getAttrs());
|
||||
cvtOp.replaceAllUsesWith(newRet.getResult());
|
||||
cvtOp.erase();
|
||||
});
|
||||
@@ -646,8 +657,10 @@ private:
|
||||
getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
addWSNamedAttrs(tmp, cvtOp->getAttrs());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
@@ -674,8 +687,10 @@ private:
|
||||
srcType.getElementType()));
|
||||
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
|
||||
addWSNamedAttrs(tmp, cvtOp->getAttrs());
|
||||
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
cvtOp.getLoc(), dstType, tmp);
|
||||
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
|
||||
cvtOp.replaceAllUsesWith(newConvert.getResult());
|
||||
cvtOp.erase();
|
||||
}
|
||||
@@ -750,6 +765,7 @@ private:
|
||||
/*boundaryCheck=*/nullptr, /*padding=*/nullptr,
|
||||
insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
|
||||
insertSliceAsyncOp.getIsVolatile());
|
||||
addWSNamedAttrs(loadOp, insertSliceAsyncOp->getAttrs());
|
||||
|
||||
// insert_slice
|
||||
auto axis = insertSliceAsyncOp.getAxis();
|
||||
@@ -765,6 +781,7 @@ private:
|
||||
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
|
||||
insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.getDst(),
|
||||
offsets, sizes, strides);
|
||||
addWSNamedAttrs(insertSliceOp, insertSliceAsyncOp->getAttrs());
|
||||
|
||||
// Replace
|
||||
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
|
||||
@@ -784,7 +801,9 @@ private:
|
||||
} else if (decomposed) {
|
||||
// Wait for all previous async ops
|
||||
OpBuilder builder(asyncWaitOp);
|
||||
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
auto newWaitOp =
|
||||
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
|
||||
addWSNamedAttrs(newWaitOp, asyncWaitOp->getAttrs());
|
||||
asyncWaitOp.erase();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -855,19 +855,11 @@ def full_static_persistent_matmul_kernel(
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()
|
||||
[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
pytest.skip("known failure, will fix it later!!!")
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
|
||||
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
|
||||
]:
|
||||
pytest.skip('out of resource: shared memory, Required: 263168')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, USE_TMA_STORE, ENABLE_WS])) in ([
|
||||
'64-16-16-4-1-512-256-256-True-True',
|
||||
] + [
|
||||
f'128-128-64-4-1-{m}-{n}-{k}-True-True' for m in range(512, 4096, 360) for n in range(512, 4096, 360) for k in [512, 1024]
|
||||
]):
|
||||
pytest.skip('known kernel hang problem when tma store is enabled')
|
||||
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
@@ -876,6 +868,16 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
|
||||
]:
|
||||
pytest.skip('shapePerCTA[1] < 16 not supported')
|
||||
|
||||
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
|
||||
'16-32-64-4-1-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-False',
|
||||
'16-32-64-4-2-256-256-256-True',
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
]:
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
|
||||
if epilogue == 'chain-dot':
|
||||
pytest.skip('known failure: Assertion !region.empty() && unexpected empty region.')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user