[HOPPER][WS] fix missing WS attrs when lowering to llvm (#2063)

allatit23
2023-08-09 15:45:44 +08:00
committed by GitHub
parent 1c45836d5d
commit 6d98a0899f
3 changed files with 32 additions and 11 deletions
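Note: the fix below adds a small helper, addWSNamedAttrs, and calls it after each builder.create<>() in the decomposition patterns that run before LLVM lowering. builder.create<>() does not copy the discardable attributes of the op being rewritten, so the warp-specialization markers "async_agent" and "agent.mutex_role" were silently dropped from the replacement ops. A minimal sketch of the pattern applied at every rewrite site (the op type and the names origOp/newType are illustrative, not verbatim from the diff):

    // Build the replacement op, then re-attach the WS markers that
    // builder.create<>() does not carry over from the op being rewritten.
    auto newOp = builder.create<triton::gpu::ConvertLayoutOp>(
        origOp.getLoc(), newType, origOp.getOperand());
    addWSNamedAttrs(newOp, origOp->getAttrs()); // "async_agent", "agent.mutex_role"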

View File

@@ -138,8 +138,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
barSync(*builder, op, barId, 128);
} else {
builder->create<gpu::BarrierOp>(op->getLoc());
blockInfo->sync();
}
blockInfo->sync();
return;
}
@@ -203,8 +203,8 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
barSync(*builder, op, barId, 128);
} else {
builder->create<gpu::BarrierOp>(op->getLoc());
blockInfo->sync();
}
blockInfo->sync();
}
// Update the region info; even if a barrier is inserted, we have to maintain
// the current op's read/write buffers.
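Note on the two Membar hunks above: judging from the hunk line counts, blockInfo->sync() appears to move out of the else branch so that the bookkeeping runs after either barrier form. Roughly (useNamedBarrier stands in for the real condition, which is not shown in the hunk):

    bool useNamedBarrier = /* WS named-barrier path, placeholder */ false;
    if (useNamedBarrier)
      barSync(*builder, op, barId, 128);              // named barrier, 128 threads
    else
      builder->create<gpu::BarrierOp>(op->getLoc());  // plain CTA-wide barrier
    blockInfo->sync();                                // now recorded on both paths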

View File

@@ -49,6 +49,14 @@ namespace ttng = mlir::triton::nvidia_gpu;
namespace {
// Pass WS-related named attrs to the new op.
static void addWSNamedAttrs(Operation *op,
                            ArrayRef<mlir::NamedAttribute> attrs) {
  for (const NamedAttribute attr : attrs)
    if (attr.getName() == "async_agent" ||
        attr.getName() == "agent.mutex_role")
      op->setAttr(attr.getName(), attr.getValue());
}
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx, bool isROCM)
@@ -617,10 +625,13 @@ private:
auto newCvtType = RankedTensorType::get(shape, F16Ty, cvtEncoding);
auto newArg = builder.create<mlir::triton::FpToFpOp>(
cvtOp.getLoc(), newArgType, cvtOp.getOperand());
addWSNamedAttrs(newArg, cvtOp->getAttrs());
auto newCvt = builder.create<mlir::triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), newCvtType, newArg);
addWSNamedAttrs(newCvt, cvtOp->getAttrs());
auto newRet = builder.create<mlir::triton::FpToFpOp>(
cvtOp.getLoc(), cvtOp.getType(), newCvt.getResult());
addWSNamedAttrs(newRet, cvtOp->getAttrs());
cvtOp.replaceAllUsesWith(newRet.getResult());
cvtOp.erase();
});
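Note: this rewrite replaces one convert with a chain of three new ops (FpToFpOp, ConvertLayoutOp, FpToFpOp), and each op in the chain re-attaches the WS attrs of the original cvtOp; otherwise the intermediate ops would not belong to any agent. The three added calls are equivalent to the following loop (purely illustrative; the diff keeps them as separate calls):

    // Copy the WS attrs of the op being rewritten onto every op of the chain.
    for (mlir::Operation *op :
         {newArg.getOperation(), newCvt.getOperation(), newRet.getOperation()})
      addWSNamedAttrs(op, cvtOp->getAttrs());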
@@ -646,8 +657,10 @@ private:
getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
addWSNamedAttrs(tmp, cvtOp->getAttrs());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
@@ -674,8 +687,10 @@ private:
srcType.getElementType()));
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), tmpType, cvtOp.getOperand());
addWSNamedAttrs(tmp, cvtOp->getAttrs());
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
cvtOp.getLoc(), dstType, tmp);
addWSNamedAttrs(newConvert, cvtOp->getAttrs());
cvtOp.replaceAllUsesWith(newConvert.getResult());
cvtOp.erase();
}
@@ -750,6 +765,7 @@ private:
/*boundaryCheck=*/nullptr, /*padding=*/nullptr,
insertSliceAsyncOp.getCache(), insertSliceAsyncOp.getEvict(),
insertSliceAsyncOp.getIsVolatile());
addWSNamedAttrs(loadOp, insertSliceAsyncOp->getAttrs());
// insert_slice
auto axis = insertSliceAsyncOp.getAxis();
@@ -765,6 +781,7 @@ private:
auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
insertSliceAsyncOp.getLoc(), loadOp, insertSliceAsyncOp.getDst(),
offsets, sizes, strides);
addWSNamedAttrs(insertSliceOp, insertSliceAsyncOp->getAttrs());
// Replace
insertSliceAsyncOp.replaceAllUsesWith(insertSliceOp.getResult());
@@ -784,7 +801,9 @@ private:
} else if (decomposed) {
// Wait for all previous async ops
OpBuilder builder(asyncWaitOp);
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
auto newWaitOp =
builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
addWSNamedAttrs(newWaitOp, asyncWaitOp->getAttrs());
asyncWaitOp.erase();
}
});
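Note: this last hunk differs from the others. The old code discarded the result of builder.create<AsyncWaitOp>, so there was no handle to attach the attrs to; the replacement binds the new op first and then copies the attrs from the op being erased:

    // Bind the newly created op so its WS attrs can be set before the old
    // async_wait is erased.
    auto newWaitOp =
        builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
    addWSNamedAttrs(newWaitOp, asyncWaitOp->getAttrs());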

View File

@@ -855,19 +855,11 @@ def full_static_persistent_matmul_kernel(
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
pytest.skip("known failure, will fix it later!!!")
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [
'128-128-128-4-1-256-256-192-none-float32-True-3-True',
]:
pytest.skip('out of resource: shared memory, Required: 263168')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, USE_TMA_STORE, ENABLE_WS])) in ([
'64-16-16-4-1-512-256-256-True-True',
] + [
f'128-128-64-4-1-{m}-{n}-{k}-True-True' for m in range(512, 4096, 360) for n in range(512, 4096, 360) for k in [512, 1024]
]):
pytest.skip('known kernel hang problem when tma store is enabled')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
@@ -876,6 +868,16 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if epilogue == 'chain-dot':
pytest.skip('known failure: Assertion !region.empty() && unexpected empty region.')