mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] Support tile size 4x4 version 1 (#413)
This PR enables 4x4 tile size in MFMA based dot operations. Supported tiled dot is (4x64) x (64x4) -> (4x4) in MFMA layout. However, actual dot operation should have at least 64 output elements, this is a limitation of other layouts appearing during result processing (i.e. blocked layout can not handle tensors smaller than wavesize). For example, following dots are supported: (4x64) x (64x16) -> (4x16), (16x64) x (64x4) -> (16x4) or (8x64) x (64x8) -> (8x8) Following dots are not supporter: (4x128) x (128x4) -> (4x4), (4x64) x (64x8) -> (4x8) This is a first version of dot using mfma 4x4 instructions, with redundancy and reductions.
This commit is contained in:
@@ -1644,7 +1644,7 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
('float8e4m3fnuz', 'float32'),
|
||||
('float16', 'float32'),
|
||||
('float32', 'float32')]
|
||||
for non_k_dim in [0, 16, 32]
|
||||
for non_k_dim in [0, 4, 16, 32]
|
||||
if not (allow_tf32 and (in_dtype in ['float16']))] +
|
||||
|
||||
[(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype, non_k_dim)
|
||||
@@ -1670,13 +1670,18 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
[64, 32, 32, 2],
|
||||
[256, 32, 32, 2],
|
||||
[256, 32, 32, 4],
|
||||
[32, 8, 128, 4],
|
||||
[8, 32, 128, 2],
|
||||
[4, 32, 64, 4],
|
||||
[32, 4, 64, 2],
|
||||
[16, 4, 64, 8]
|
||||
]
|
||||
for allow_tf32 in [False, True]
|
||||
for col_a in [True, False]
|
||||
for col_b in [True, False]
|
||||
for in_dtype in ['int8', 'bfloat16', 'float16', 'float32']
|
||||
for out_dtype in [None]
|
||||
for non_k_dim in [0, 16, 32]])
|
||||
for non_k_dim in [0, 4, 16, 32]])
|
||||
def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim, device='cuda'):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
|
||||
@@ -1697,6 +1702,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
out_dtype = "int32"
|
||||
if non_k_dim == 32 and (M < 32 or N < 32):
|
||||
pytest.skip("incompatible non_k_dim == 32 with MN sizes")
|
||||
if non_k_dim == 16 and (M < 16 or N < 16):
|
||||
pytest.skip("incompatible non_k_dim == 16 with MN sizes")
|
||||
if non_k_dim == 4 and (K < 64):
|
||||
pytest.skip("incompatible non_k_dim == 4 with K size")
|
||||
if non_k_dim == 4 and (M > 16 or N > 16):
|
||||
pytest.skip("skipping lage matrices for non_k_dim == 4 to speedup testing")
|
||||
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
@@ -2003,6 +2014,7 @@ def get_variant_golden(a, b):
|
||||
|
||||
@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [
|
||||
[64, 32, 128, 4, 64, 32, 64, 0],
|
||||
[4, 16, 128, 4, 4, 16, 64, 1],
|
||||
[64, 32, 128, 4, 64, 32, 64, 2]
|
||||
])
|
||||
def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES):
|
||||
|
||||
@@ -1290,7 +1290,7 @@ def is_hip():
|
||||
|
||||
def mfma_supported_granularity(m, n, k) -> bool:
|
||||
# todo make this gran_type matrix element type sensitive
|
||||
for gran_type in [(32, 8), (16, 16)]:
|
||||
for gran_type in [(32, 8), (16, 16), (4, 64)]:
|
||||
granularity_mn, granularity_k = gran_type
|
||||
|
||||
if m % granularity_mn != 0 or n % granularity_mn != 0:
|
||||
@@ -1357,9 +1357,14 @@ def dot(lhs: tl.tensor,
|
||||
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
|
||||
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
|
||||
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
if _is_cuda(builder.target):
|
||||
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 16, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
|
||||
if is_hip():
|
||||
assert lhs.shape[0].value >= 4 and lhs.shape[1].value >= 16 \
|
||||
and rhs.shape[1].value >= 4, \
|
||||
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 4!"
|
||||
|
||||
# hip for now converts fp8 to fp16 for mixed input
|
||||
if is_hip():
|
||||
|
||||
Reference in New Issue
Block a user