mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] Support tile size 4x4 version 1 (#413)
This PR enables 4x4 tile size in MFMA based dot operations. Supported tiled dot is (4x64) x (64x4) -> (4x4) in MFMA layout. However, actual dot operation should have at least 64 output elements, this is a limitation of other layouts appearing during result processing (i.e. blocked layout can not handle tensors smaller than wavesize). For example, following dots are supported: (4x64) x (64x16) -> (4x16), (16x64) x (64x4) -> (16x4) or (8x64) x (64x8) -> (8x8) Following dots are not supporter: (4x128) x (128x4) -> (4x4), (4x64) x (64x8) -> (4x8) This is a first version of dot using mfma 4x4 instructions, with redundancy and reductions.
This commit is contained in:
@@ -54,6 +54,12 @@ def prune_configs(M, N, K, configs):
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
# number elemens per thread is less 1
|
||||
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
SPLIT_K = config.get("SPLIT_K")
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
if BLOCK_SIZE_M < mfma or BLOCK_SIZE_N < mfma:
|
||||
|
||||
Reference in New Issue
Block a user