[MFMA] Support tile size 4x4 version 1 (#413)

This PR enables 4x4 tile size in MFMA based dot operations.

Supported tiled dot is (4x64) x (64x4) -> (4x4) in MFMA layout.
However, actual dot operation should have at least 64 output elements, this is a limitation of other layouts appearing during result processing (i.e. blocked layout can not handle tensors smaller than wavesize).

For example, following dots are supported: (4x64) x (64x16) -> (4x16), (16x64) x (64x4) -> (16x4) or (8x64) x (64x8) -> (8x8)
Following dots are not supporter: (4x128) x (128x4) -> (4x4), (4x64) x (64x8) -> (4x8)

This is a first version of dot using mfma 4x4 instructions, with redundancy and reductions.
This commit is contained in:
Alexander Efimov
2023-12-12 18:23:55 +01:00
committed by GitHub
parent a944811b6d
commit 605a90c58e
10 changed files with 332 additions and 129 deletions

View File

@@ -1290,7 +1290,7 @@ def is_hip():
def mfma_supported_granularity(m, n, k) -> bool:
# todo make this gran_type matrix element type sensitive
for gran_type in [(32, 8), (16, 16)]:
for gran_type in [(32, 8), (16, 16), (4, 64)]:
granularity_mn, granularity_k = gran_type
if m % granularity_mn != 0 or n % granularity_mn != 0:
@@ -1357,9 +1357,14 @@ def dot(lhs: tl.tensor,
assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!"
assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!"
assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})"
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if _is_cuda(builder.target):
assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 16, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!"
if is_hip():
assert lhs.shape[0].value >= 4 and lhs.shape[1].value >= 16 \
and rhs.shape[1].value >= 4, \
f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 4!"
# hip for now converts fp8 to fp16 for mixed input
if is_hip():