[BACKEND] Support MMA V3 with float16 accumulator (#2049)

Also fixes a bug exposed in convertLayout lowering for float16: we
shouldn't use cvt.pack.sat.u16.s32 to pack 16-bit values, since that
instruction takes 32-bit source registers. Using it also prevented
optimization at the LLVM IR level.
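
As a minimal Python sketch of the packing the lowering needs (the function
name and values here are illustrative, not from the commit): two values that
already fit in 16 bits can be combined into one 32-bit word with a plain
shift-and-or, with no widening or saturating conversion involved.

def pack_u16_pair(lo: int, hi: int) -> int:
    """Pack two 16-bit values into one 32-bit word, hi in the upper half."""
    return ((hi & 0xFFFF) << 16) | (lo & 0xFFFF)

# f16 bit patterns for 1.0 (0x3C00) and -1.0 (0xBC00) pack losslessly;
# cvt.pack.sat.u16.s32, by contrast, takes .s32 source registers, so
# 16-bit values would first have to be extended to 32 bits.
assert pack_u16_pair(0x3C00, 0xBC00) == 0xBC003C00
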
Thomas authored on 2023-08-07 15:55:44 -07:00, committed by GitHub
parent 521cfae44d
commit 98523bcc48
6 changed files with 79 additions and 72 deletions

@@ -2139,9 +2139,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
         if out_dtype == 'float16':
             # TODO: support out_dtype=float16 for tl.dot on V100
             pytest.skip("Only test out_dtype=float16 on devices with sm >=80")
-    if capability[0] == 9 and out_dtype == 'float16':
-        # TODO: support out_dtype=float16 for tl.dot on H100
-        pytest.skip("Only test out_dtype=float16 on devices with sm<90")
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
@@ -2297,7 +2294,7 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
     elif in_dtype == 'float16' and out_dtype == tl.float32:
         assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
     elif in_dtype == 'float16' and out_dtype == tl.float16:
-        assert 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16' in ptx
+        assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx)
     elif in_dtype == 'int8':
         assert 'wgmma.mma_async.sync.aligned' in ptx or\
                'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
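
For reference, a small Python check of what the loosened assertion accepts;
the two PTX fragments below are illustrative shapes, not output captured from
this test. The pattern spells [mma|wgmma.mma_async] as a character class
rather than an alternation, and its dots are wildcards, but it still matches
both the sm_80 mma and the sm_90 wgmma spellings.

import re

PAT = r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16'

# sm_80-style mma with a float16 accumulator
assert re.search(PAT, 'mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16')
# sm_90-style wgmma (MMA V3) with a float16 accumulator; shape is illustrative
assert re.search(PAT, 'wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16')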