Mirror of https://github.com/ROCm/ROCm.git (synced 2026-04-05 03:01:17 -04:00)
[BACKEND] Fix fma mixed-precision (#2184)
Also expose the allow_tf32 argument to the matmul op. @shunting314
This commit is contained in:
@@ -26,61 +26,61 @@ def f8_to_f16(x, dtype):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE",
|
||||
"BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32",
|
||||
itertools.chain(
|
||||
*[
|
||||
[
|
||||
# 1 warp
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
# 2 warp
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
# 4 warp
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
# 8 warp
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
|
||||
(128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
(256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
|
||||
# variable input
|
||||
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE),
|
||||
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# n-stage
|
||||
*[
|
||||
[
|
||||
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE),
|
||||
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE),
|
||||
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE),
|
||||
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE),
|
||||
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE),
|
||||
(16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True),
|
||||
(64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True),
|
||||
(256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
(128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True),
|
||||
] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
|
||||
],
|
||||
# mixed-precision
|
||||
*[
|
||||
[
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True),
|
||||
] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
|
||||
("float8e4nv", "float8e4nv"),
|
||||
("float8e5", "float8e4nv"),
|
||||
@@ -92,10 +92,23 @@ def f8_to_f16(x, dtype):
|
||||
("bfloat16", "float32"),
|
||||
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
# mixed-precision block layout
|
||||
*[
|
||||
[
|
||||
(32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
|
||||
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
|
||||
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False),
|
||||
] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
|
||||
("float16", "float8e5"),
|
||||
("float16", "float32"),
|
||||
("float32", "float16"),
|
||||
("bfloat16", "float32"),
|
||||
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
|
||||
],
|
||||
*[
|
||||
# float8e4b15 only supports row-col layout
|
||||
[
|
||||
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE),
|
||||
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE, True),
|
||||
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
|
||||
("float8e4b15", "float16"),
|
||||
("float16", "float8e4b15"),
|
||||
@@ -105,7 +118,7 @@ def f8_to_f16(x, dtype):
|
||||
]
|
||||
),
|
||||
)
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE):
|
||||
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 7:
|
||||
pytest.skip("Only test tl.dot() on devices with sm >= 70")
|
||||
@@ -173,7 +186,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
|
||||
a = triton.reinterpret(a, getattr(tl, ADTYPE))
|
||||
if b_fp8:
|
||||
b = triton.reinterpret(b, getattr(tl, BDTYPE))
|
||||
tt_c = triton.ops.matmul(a, b)
|
||||
tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
|
||||
torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
|
||||
except triton.OutOfResources as e:
|
||||
pytest.skip(str(e))
|
||||
|
||||
Reference in New Issue
Block a user