mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Backend] Fix CTA->warp ordering for MMAv3 and fix dot-chain scripts in hopper tests (#2041)
Co-authored-by: goostavz <gzhu@nvidia.com> Co-authored-by: Philippe Tillet <phil@openai.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
This commit is contained in:
@@ -984,32 +984,6 @@ private:
|
||||
return ret;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
|
||||
const MmaEncodingAttr &mmaLayout,
|
||||
RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
assert(warpsPerCTA.size() == 2);
|
||||
auto order = triton::gpu::getOrder(mmaLayout);
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(32);
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
|
||||
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
|
||||
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
|
||||
Value offWarp0 = mul(multiDimWarpId[0], i32_val(16));
|
||||
Value offWarp1 = mul(multiDimWarpId[1], i32_val(8));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
|
||||
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
|
||||
return multiDimBase;
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
|
||||
RankedTensorType type) const {
|
||||
@@ -1062,8 +1036,18 @@ private:
|
||||
else
|
||||
warpsN = shape[1] / instrShape[1];
|
||||
|
||||
SmallVector<Value> multiDimWarpId =
|
||||
delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
|
||||
SmallVector<Value> multiDimWarpId(2);
|
||||
if (mmaLayout.isHopper()) {
|
||||
// TODO[goostavz]: the tiling order from CTA->warp level is different for
|
||||
// MMAv2/3. This is a workaround since we don't explicitly have warpGrp
|
||||
// level in the layout definition, and the tiling order of warpGrp->warp
|
||||
// must be fixed to meet the HW's needs. We may need to consider to
|
||||
// explicitly define warpGrpPerCTA for MMAv3 layout.
|
||||
multiDimWarpId[0] = urem(warpId, warpsPerCTA[0]);
|
||||
multiDimWarpId[1] = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
|
||||
} else {
|
||||
multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
|
||||
}
|
||||
Value warpId0 = urem(multiDimWarpId[0], i32_val(warpsM));
|
||||
Value warpId1 = urem(multiDimWarpId[1], i32_val(warpsN));
|
||||
|
||||
|
||||
@@ -149,6 +149,7 @@ def matmul_kernel(
|
||||
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
|
||||
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
|
||||
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
|
||||
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr,
|
||||
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
@@ -167,8 +168,9 @@ def matmul_kernel(
|
||||
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
|
||||
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
|
||||
# for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix
|
||||
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
|
||||
offsets=(0, block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(Z_ORDER_1, Z_ORDER_0))
|
||||
offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1))
|
||||
z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
|
||||
offs_m = block_offset_m + tl.arange(0, BLOCK_M)
|
||||
@@ -213,7 +215,7 @@ def matmul_kernel(
|
||||
tl.store(z_ptrs, z, mask=mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_C,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
|
||||
[(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
|
||||
for shape_w_c in [
|
||||
# badcase from cublas-important-layers
|
||||
@@ -225,7 +227,7 @@ def matmul_kernel(
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for enable_ws in [False, True]
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# softmax works for one CTA
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 64, 64, 64],
|
||||
@@ -239,10 +241,10 @@ def matmul_kernel(
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_c in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 16, 4, 1, 128, 128, 64],
|
||||
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
|
||||
@@ -264,11 +266,11 @@ def matmul_kernel(
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_c in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
|
||||
] + [(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
for shape_w_c in [
|
||||
[64, 64, 32, 4, 1, 128, 256, 64],
|
||||
[128, 128, 16, 4, 4, 512, 256, 64],
|
||||
@@ -287,18 +289,18 @@ def matmul_kernel(
|
||||
for use_tma_store in [False, True]
|
||||
for trans_a in [False, True]
|
||||
for trans_b in [False, True]
|
||||
for trans_c in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for num_stages in [3]
|
||||
for enable_ws in [False, True]
|
||||
] + [(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
] + [(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# loop over instr shapes
|
||||
for n in [16, 32, 64, 128, 256]
|
||||
for trans_c in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [2, 4, 5, 7]
|
||||
for enable_ws in [False, True]
|
||||
] + [(*shape_w_c, *shape, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
] + [(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
|
||||
# irregular shapes
|
||||
for shape_w_c in [
|
||||
[128, 128, 64, 4, 1],
|
||||
@@ -306,7 +308,7 @@ def matmul_kernel(
|
||||
[128, 128, 128, 4, 2],
|
||||
]
|
||||
for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
|
||||
for trans_c in [False, True]
|
||||
for trans_output in [False, True]
|
||||
for out_dtype in ['float16', 'float32']
|
||||
for use_tma_store in [False, True]
|
||||
for num_stages in [2, 3, 4]
|
||||
@@ -314,7 +316,7 @@ def matmul_kernel(
|
||||
])
|
||||
@pytest.mark.skipif(torch.cuda.get_device_capability()
|
||||
[0] < 9, reason="Requires compute capability >= 9")
|
||||
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_C, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
|
||||
'16-32-64-4-4-512-256-64-True-False',
|
||||
'16-32-64-4-4-512-256-64-True-True',
|
||||
@@ -331,13 +333,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
'16-32-64-8-2-256-256-256-False',
|
||||
'16-32-64-8-2-256-256-256-True',
|
||||
]:
|
||||
pytest.skip('illegal memory access.')
|
||||
|
||||
# with ENABLE_TMA=1 and ENABLE_MMA_V3=1
|
||||
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K])) in [
|
||||
'64-64-32-8-1-128-256-64',
|
||||
]:
|
||||
pytest.skip('Tensor-likes are not close!')
|
||||
pytest.skip('Known legacy issue, ldmatrix can only support x4')
|
||||
|
||||
if NUM_CTAS > 1 and NUM_WARPS == 8:
|
||||
pytest.skip('Tensor-likes are not close!')
|
||||
@@ -378,27 +374,23 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
|
||||
# avoid out of memory
|
||||
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
|
||||
if (TRANS_C):
|
||||
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
|
||||
else:
|
||||
if (TRANS_OUTPUT):
|
||||
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
|
||||
else:
|
||||
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
|
||||
else:
|
||||
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
|
||||
|
||||
if epilogue == 'chain-dot':
|
||||
if (TRANS_C):
|
||||
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
|
||||
else:
|
||||
w = torch.randn((M, M), device='cuda', dtype=torch.float16)
|
||||
else:
|
||||
w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T
|
||||
# for chain-dot only
|
||||
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
|
||||
w_order = [0, 1]
|
||||
|
||||
if (TRANS_C):
|
||||
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
|
||||
z_order = [1, 0]
|
||||
else:
|
||||
if (TRANS_OUTPUT):
|
||||
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
|
||||
z_order = [0, 1]
|
||||
else:
|
||||
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
|
||||
z_order = [1, 0]
|
||||
|
||||
# torch result
|
||||
a_f32 = a.to(torch.float32)
|
||||
@@ -442,6 +434,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
|
||||
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
|
||||
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1],
|
||||
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
|
||||
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
|
||||
enable_warp_specialization=ENABLE_WS)
|
||||
|
||||
Reference in New Issue
Block a user