[BACKEND] Fix fma mixed-precision (#2184)

Also expose the allow_tf32 argument to the matmul op.

Keren Zhou authored on 2023-08-26 12:49:58 -04:00 (committed by GitHub)
parent fa03b92109
commit 6e4932cda8
3 changed files with 65 additions and 44 deletions
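
For context, here is a minimal usage sketch of the newly exposed argument, mirroring the positional call exercised by the updated test below. Assumptions: a CUDA device and this commit's signature of triton.ops.matmul, in which the third positional argument is left as None and the fourth is allow_tf32; shapes are arbitrary.

import torch
import triton
import triton.ops

a = torch.randn((512, 512), device="cuda", dtype=torch.float32)
b = torch.randn((512, 512), device="cuda", dtype=torch.float32)

# allow_tf32=True permits TF32 tensor cores on Ampere and newer GPUs;
# allow_tf32=False forces the dot product to lower to full-precision FMA.
c_tf32 = triton.ops.matmul(a, b, None, True)
c_fma = triton.ops.matmul(a, b, None, False)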

@@ -26,61 +26,61 @@ def f8_to_f16(x, dtype):
 @pytest.mark.parametrize(
-    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE",
+    "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32",
     itertools.chain(
         *[
             [
                 # 1 warp
-                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE),
+                (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
                 # 2 warp
-                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE),
+                (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
                 # 4 warp
-                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE),
+                (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
                 # 8 warp
-                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
-                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE),
+                (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
+                (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True),
                 # variable input
-                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE),
-                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE),
-                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE),
-                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE),
+                (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True),
+                (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True),
+                (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True),
+                (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True),
             ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]
         ],
         # n-stage
         *[
             [
-                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE),
-                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE),
-                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE),
-                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE),
-                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE),
+                (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True),
+                (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True),
+                (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True),
+                (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True),
+                (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True),
             ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4]
         ],
         # mixed-precision
         *[
             [
-                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
-                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
-                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
+                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
+                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True),
+                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True),
             ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"),
                                      ("float8e4nv", "float8e4nv"),
                                      ("float8e5", "float8e4nv"),
@@ -92,10 +92,23 @@ def f8_to_f16(x, dtype):
                                      ("bfloat16", "float32"),
                                      ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
         ],
+        # mixed-precision block layout
+        *[
+            [
+                (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
+                (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False),
+                (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False),
+            ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"),
+                                     ("float16", "float8e5"),
+                                     ("float16", "float32"),
+                                     ("float32", "float16"),
+                                     ("bfloat16", "float32"),
+                                     ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
+        ],
         *[
             # float8e4b15 only supports row-col layout
             [
-                (128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE),
+                (128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE, True),
             ] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
                                      ("float8e4b15", "float16"),
                                      ("float16", "float8e4b15"),
@@ -105,7 +118,7 @@ def f8_to_f16(x, dtype):
         ]
     ),
 )
-def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE):
+def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):
     capability = torch.cuda.get_device_capability()
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -173,7 +186,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
             a = triton.reinterpret(a, getattr(tl, ADTYPE))
         if b_fp8:
             b = triton.reinterpret(b, getattr(tl, BDTYPE))
-        tt_c = triton.ops.matmul(a, b)
+        tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32)
         torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0)
     except triton.OutOfResources as e:
         pytest.skip(str(e))
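
The FMA code path this commit fixes is the one a dot takes when TF32 tensor cores are ruled out. Below is a minimal kernel-level sketch of that situation, not taken from this commit: the kernel name and layout are hypothetical, and it assumes a CUDA device plus the allow_tf32 keyword that tl.dot accepted at the time. One operand is float16 and the other float32, so the narrower side is upcast before the dot, and with ALLOW_TF32=False the multiply-accumulate runs as plain float32 FMA.

import torch
import triton
import triton.language as tl

@triton.jit
def mixed_dot_kernel(a_ptr, b_ptr, c_ptr, ALLOW_TF32: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Load one float16 tile and one float32 tile (row-major BLOCK x BLOCK).
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # Upcast the narrower operand so both sides enter the dot as float32;
    # with ALLOW_TF32=False the dot lowers to FMA instead of TF32 MMA.
    acc = tl.dot(a.to(tl.float32), b, allow_tf32=ALLOW_TF32)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], acc)

a = torch.randn((16, 16), device="cuda", dtype=torch.float16)
b = torch.randn((16, 16), device="cuda", dtype=torch.float32)
c = torch.empty((16, 16), device="cuda", dtype=torch.float32)
mixed_dot_kernel[(1,)](a, b, c, ALLOW_TF32=False, BLOCK=16)
# Full-precision FMA accumulation should closely match torch's float32 matmul.
torch.testing.assert_close(c, a.float() @ b)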