[OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)

We currently have a very janky approach to optimizing mixed-precision
matmul workloads: some layout combinations (e.g., NT matmul) are
explicitly pattern-matched to take a more optimized codepath. A previous
attempt at unifying all the codepaths to codegen cp.async failed due to
bugs in SharedToDotOperandMMAv2.cpp.

This PR fixes said bugs, adds assertions for SharedToDotOperandMMAv2
modes that aren't well supported, and greatly simplifies our handling of
element-wise operations between loads and conversions to DotOperand.
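To make the "element-wise operations between load and conversion to DotOperand" pattern concrete, here is a minimal, hypothetical Python sketch (not Triton IR; the function name and shapes are illustrative only). It models a kernel that loads low-precision operands, applies an elementwise op (upcast plus a scale), and only then feeds the result to the dot:

```python
def mixed_precision_dot(a_lo, b, scale=1.0):
    """Sketch of the pattern the optimizer must handle:
    load -> elementwise upcast/scale -> dot operand -> dot."""
    # Elementwise ops sit between the load of `a_lo` and its use in the dot;
    # the optimizer must not break this chain when choosing a codepath.
    a = [float(x) * scale for x in a_lo]
    return sum(x * y for x, y in zip(a, b))
```

The point of the simplification in this PR is that such elementwise chains no longer need per-layout special casing.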
Philippe Tillet
2023-07-28 10:29:42 -07:00
committed by GitHub
parent 2689f4a3b0
commit 52c146f66b
11 changed files with 187 additions and 412 deletions


@@ -89,13 +89,21 @@ def f8_to_f16(x, dtype):
(128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
(32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
(128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE),
-] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
+] for ADTYPE, BDTYPE in [("float8e4", "float8e5"),
("float8e4", "float16"),
("float16", "float8e5"),
("float16", "float32"),
("float32", "float16"),
("bfloat16", "float32"),
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
],
*[
# float8e4b15 only supports row-col layout
[
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE),
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
("float8e4b15", "float16"),
("float16", "float8e4b15")]
]
),
)
@@ -132,7 +140,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
if t:
return init_input(m, n, False, dtype, is_float8).t()
if is_float8:
-        return torch.randint(20, 60, (n, m), device="cuda", dtype=torch.int8)
+        return torch.randint(20, 50, (n, m), device="cuda", dtype=torch.int8)
dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
return .1 * torch.randn((n, m), device="cuda", dtype=dtype)