[Backend] Fix CTA->warp ordering for MMAv3 and fix dot-chain scripts in hopper tests (#2041)

Co-authored-by: goostavz <gzhu@nvidia.com>
Co-authored-by: Philippe Tillet <phil@openai.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Author: goostavz
Date: 2023-08-08 14:30:04 +08:00 (committed by GitHub)
Parent: a76ecd74e7
Commit: b525880d8b
2 changed files with 40 additions and 63 deletions

Changed file 1 of 2:

@@ -984,32 +984,6 @@ private:
return ret;
}
SmallVector<Value>
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
RankedTensorType type) const {
auto shape = type.getShape();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(warpsPerCTA.size() == 2);
auto order = triton::gpu::getOrder(mmaLayout);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
Value offWarp0 = mul(multiDimWarpId[0], i32_val(16));
Value offWarp1 = mul(multiDimWarpId[1], i32_val(8));
SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
return multiDimBase;
}
SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
RankedTensorType type) const {
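For context, the helper removed above computed each lane's base coordinate inside its warp's 16x8 MMAv2 tile: row = offWarp0 + laneId/4, column = offWarp1 + 2*(laneId%4). A minimal Python sketch of that arithmetic (the standalone function is illustrative, not part of the codebase):

```python
# Sketch of the per-lane base index from the removed MMAv2 helper: within a
# warp's 16x8 MMA tile, lane L starts at row L // 4 and column 2 * (L % 4),
# shifted by the warp's 16-row / 8-column offset inside the CTA tile.
def mma_v2_lane_base(lane_id, warp_m, warp_n):
    off_warp0 = warp_m * 16          # offWarp0 in the removed code
    off_warp1 = warp_n * 8           # offWarp1 in the removed code
    return off_warp0 + lane_id // 4, off_warp1 + 2 * (lane_id % 4)

# Lanes 0..3 share row 0 and own columns 0, 2, 4, 6 of their warp's tile.
assert [mma_v2_lane_base(l, 0, 0) for l in range(4)] == [(0, 0), (0, 2), (0, 4), (0, 6)]
```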
@@ -1062,8 +1036,18 @@ private:
else
warpsN = shape[1] / instrShape[1];
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
SmallVector<Value> multiDimWarpId(2);
if (mmaLayout.isHopper()) {
// TODO[goostavz]: the CTA->warp tiling order differs between MMAv2 and
// MMAv3. This is a workaround: the layout definition has no explicit
// warp-group level, and the warpGrp->warp tiling order must be fixed to
// match the hardware's requirements. We may want to explicitly define
// warpGrpPerCTA for the MMAv3 layout.
multiDimWarpId[0] = urem(warpId, warpsPerCTA[0]);
multiDimWarpId[1] = urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]);
} else {
multiDimWarpId = delinearize(rewriter, loc, warpId, _warpsPerCTA, order);
}
Value warpId0 = urem(multiDimWarpId[0], i32_val(warpsM));
Value warpId1 = urem(multiDimWarpId[1], i32_val(warpsN));
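The heart of the fix is the branch above: on Hopper (MMAv3) the warp id is decomposed with dimension 0 varying fastest, rather than through the generic delinearize call with the layout's order. A small Python sketch of the two decompositions, assuming for illustration warpsPerCTA = [4, 2] and a generic order of [1, 0] (the actual order value is not visible in this hunk):

```python
# Illustrative comparison of the two CTA->warp decompositions.
def delinearize(linear_id, sizes, order):
    """Generic decomposition: dims listed in `order` vary fastest first."""
    multi = [0] * len(sizes)
    for dim in order:
        multi[dim] = linear_id % sizes[dim]
        linear_id //= sizes[dim]
    return multi

def warp_coords_mma_v3(warp_id, warps_per_cta):
    # Matches the new Hopper branch: dim 0 varies fastest, regardless of order,
    # so e.g. warps 0..3 land on consecutive dim-0 (M) coordinates.
    return [warp_id % warps_per_cta[0],
            (warp_id // warps_per_cta[0]) % warps_per_cta[1]]

warps_per_cta = [4, 2]
for w in range(8):
    print(w, delinearize(w, warps_per_cta, [1, 0]), warp_coords_mma_v3(w, warps_per_cta))
```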

Changed file 2 of 2:

@@ -149,6 +149,7 @@ def matmul_kernel(
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr,
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr,
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr,
W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr,
Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr
):
pid = tl.program_id(axis=0)
@@ -167,8 +168,9 @@ def matmul_kernel(
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1))
# for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix
w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn),
offsets=(0, block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(Z_ORDER_1, Z_ORDER_0))
offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1))
z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
offs_m = block_offset_m + tl.arange(0, BLOCK_M)
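The comment above is the central chain-dot constraint: the first dot produces a (BLOCK_M, BLOCK_N) tile, and the second dot right-multiplies that tile by the full (N, N) matrix W, which only lines up when BLOCK_N == N. A short torch sketch of the shape constraint (sizes are illustrative):

```python
import torch

# Why BLOCK_N must equal N for chain-dot: the first dot yields a
# (BLOCK_M, BLOCK_N) tile, and the second dot right-multiplies it by the
# whole (N, N) matrix w, so the inner dimensions match only if BLOCK_N == N.
BLOCK_M, BLOCK_N, K, N = 64, 64, 32, 64

a_tile = torch.randn(BLOCK_M, K)
b_tile = torch.randn(K, BLOCK_N)
w = torch.randn(N, N)            # loaded in full by every program

acc = a_tile @ b_tile            # (BLOCK_M, BLOCK_N)
z = acc @ w                      # valid only because BLOCK_N == N
print(z.shape)                   # torch.Size([64, 64])
```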
@@ -213,7 +215,7 @@ def matmul_kernel(
tl.store(z_ptrs, z, mask=mask)
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_C,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS',
[(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws)
for shape_w_c in [
# bad case from cublas-important-layers
@@ -225,7 +227,7 @@ def matmul_kernel(
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for enable_ws in [False, True]
] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
] + [(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
# softmax works for one CTA
for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
@@ -239,10 +241,10 @@ def matmul_kernel(
for use_tma_store in [False, True]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_c in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
] + [(*shape_w_c, trans_a, trans_b, trans_c, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
] + [(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
@@ -264,11 +266,11 @@ def matmul_kernel(
for use_tma_store in [False, True]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_c in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1]))
] + [(*shape_w_c, trans_a, trans_b, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
@@ -287,18 +289,18 @@ def matmul_kernel(
for use_tma_store in [False, True]
for trans_a in [False, True]
for trans_b in [False, True]
for trans_c in [False, True]
for trans_output in [False, True]
for num_stages in [3]
for enable_ws in [False, True]
] + [(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
] + [(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
# loop over instr shapes
for n in [16, 32, 64, 128, 256]
for trans_c in [False, True]
for trans_output in [False, True]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for num_stages in [2, 4, 5, 7]
for enable_ws in [False, True]
] + [(*shape_w_c, *shape, False, True, trans_c, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
] + [(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
# irregular shapes
for shape_w_c in [
[128, 128, 64, 4, 1],
@@ -306,7 +308,7 @@ def matmul_kernel(
[128, 128, 128, 4, 2],
]
for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
for trans_c in [False, True]
for trans_output in [False, True]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
for num_stages in [2, 3, 4]
@@ -314,7 +316,7 @@ def matmul_kernel(
])
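For readers tracing the filters in the parametrization above: each shape_w_c entry packs the blocking and problem sizes in the same positional order as the test signature, so shape_w_c[1] is BLOCK_N and shape_w_c[6] is N, the pair the chain-dot filter compares, in line with the kernel comment that chain-dot needs BLOCK_N == N. A small illustrative unpacking:

```python
# Positional layout of a shape_w_c entry in the parametrization above:
#   index:  0        1        2        3          4         5  6  7
#           BLOCK_M  BLOCK_N  BLOCK_K  NUM_WARPS  NUM_CTAS  M  N  K
shape_w_c = [64, 64, 16, 4, 1, 64, 64, 64]
BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K = shape_w_c
print(BLOCK_N == N)   # True -> this shape satisfies the chain-dot constraint
```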
@pytest.mark.skipif(torch.cuda.get_device_capability()
[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_C, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
@@ -331,13 +333,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('illegal memory access.')
# with ENABLE_TMA=1 and ENABLE_MMA_V3=1
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K])) in [
'64-64-32-8-1-128-256-64',
]:
pytest.skip('Tensor-likes are not close!')
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if NUM_CTAS > 1 and NUM_WARPS == 8:
pytest.skip('Tensor-likes are not close!')
@@ -378,27 +374,23 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_C):
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
if epilogue == 'chain-dot':
if (TRANS_C):
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
else:
w = torch.randn((M, M), device='cuda', dtype=torch.float16)
else:
w = torch.randn((1, 1), device='cuda', dtype=torch.float16).T
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_C):
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
else:
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
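The *_order values set above encode which dimension of each tensor is contiguous: a tensor allocated as (M, N) row-major gets order (1, 0), while one allocated as (N, M) and viewed through .T is contiguous along dim 0 and gets order (0, 1). A hedged sketch of deriving the same thing from torch strides (the helper is illustrative, not part of the test):

```python
import torch

def order_from_strides(t):
    # Dimensions sorted from smallest to largest stride, i.e. fastest-varying first.
    return sorted(range(t.dim()), key=lambda d: t.stride(d))

M, N = 8, 4
row_major = torch.full((M, N), 1.)        # contiguous along dim 1
transposed = torch.full((N, M), 1.).T     # (M, N) view, contiguous along dim 0

print(order_from_strides(row_major))      # [1, 0] -> Z_ORDER = (1, 0)
print(order_from_strides(transposed))     # [0, 1] -> Z_ORDER = (0, 1)
```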
@@ -442,6 +434,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A,
CHAIN_DOT=epilogue == 'chain-dot',
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1],
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1],
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1],
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1],
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES,
enable_warp_specialization=ENABLE_WS)