[HOPPER][WS] remove numCTAs = 1 check in guard pass (#2066)

This commit is contained in:
allatit23
2023-08-09 17:07:56 +08:00
committed by GitHub
parent de47bba07d
commit 8a610f7cf7
2 changed files with 10 additions and 15 deletions

View File

@@ -777,20 +777,17 @@ def full_static_persistent_matmul_kernel(
(*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
# TODO: enable when num_ctas != 1 is supported.
# [128, 128, 16, 4, 4, 512, 256, 64],
# [128, 256, 32, 4, 8, 256, 256, 192],
# [512, 256, 32, 4, 8, 1024, 256, 192],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
# TODO: enable when num_ctas != 1 is supported.
# [128, 128, 128, 4, 2, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
# TODO: enable when num_ctas != 1 is supported.
# [16, 32, 64, 4, 4, 512, 256, 64],
[16, 32, 64, 4, 4, 512, 256, 64],
]
for out_dtype in ['float16', 'float32']
for use_tma_store in [False, True]
@@ -801,13 +798,13 @@ def full_static_persistent_matmul_kernel(
] + [(*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws)
for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1]],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1]],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]],
# # TODO: enable when num_warps != 4 is supported.
# # repeat
# # [64, 64, 32, 8, 1, 128, 256, 64],
@@ -836,9 +833,9 @@ def full_static_persistent_matmul_kernel(
(*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws)
# irregular shapes
for shape_w_c in [
[128, 128, 64, 4, 1]
# [256, 128, 64, 4, 2],
# [128, 128, 128, 4, 2],
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2]
]
for shape in list(itertools.product([*range(512, 4096, 360)], [*range(512, 4096, 360)], [512, 1024]))
for out_dtype in ['float16', 'float32']