[OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
We currently have a very janky approach to optimizing mixed-precision matmul workloads: some layout combinations (e.g., NT matmul) are explicitly pattern-matched to take a more optimized codepath. An earlier attempt at unifying all the codepaths to codegen cp.async failed due to bugs in SharedToDotOperandMMAv2.cpp. This PR fixes said bugs, adds assertions for the SharedToDotOperandMMAv2 modes that aren't well supported, and greatly simplifies our handling of element-wise operations between loads and conversions to DotOperand.
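For context, the codepath in question is exercised whenever the two operands of a tl.dot are loaded in different element types, so one of them passes through an element-wise conversion between the load and the DotOperand layout conversion. A minimal sketch of such a mixed-precision kernel, with illustrative names and the simplifying assumption that M, N, K are multiples of the block sizes (this is not code from the PR):

import triton
import triton.language as tl

@triton.jit
def mixed_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                        stride_am, stride_ak, stride_bk, stride_bn,
                        stride_cm, stride_cn,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)  # e.g. float16
        b = tl.load(b_ptrs)  # e.g. float32
        # the element-wise cast sitting between the load and the dot is the
        # pattern whose handling this PR simplifies
        acc += tl.dot(a, b.to(tl.float16))
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)

Per the commit message, the intent is that this pattern is now handled uniformly instead of pattern-matching specific layout pairs such as NT.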
@@ -89,13 +89,21 @@ def f8_to_f16(x, dtype):
             (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE),
             (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE),
             (128, 128, 32, 8, 4, 2, 256, 256, 128, AT, BT, ADTYPE, BDTYPE),
-        ] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
+        ] for ADTYPE, BDTYPE in [("float8e4", "float8e5"),
                                  ("float8e4", "float16"),
                                  ("float16", "float8e5"),
                                  ("float16", "float32"),
                                  ("float32", "float16"),
                                  ("bfloat16", "float32"),
                                  ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
         ],
+        *[
+            # float8e4b15 only supports row-col layout
+            [
+                (128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE),
+            ] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
+                                     ("float8e4b15", "float16"),
+                                     ("float16", "float8e4b15")]
+        ]
     ),
 )
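The hunk header shows these parameters live next to the test's f8_to_f16 helper. Float8 inputs are synthesized as raw int8 bit patterns (see the second hunk below), so a helper like this reinterprets those bits as the requested float8 flavor on device and casts to float16 for the reference matmul. A sketch of that idea, with an assumed signature since the upstream helper may differ in detail:

import torch
import triton
import triton.language as tl

@triton.jit
def _cast_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(src_ptr + offs, mask=mask)  # loaded with the pointer's float8 type
    tl.store(dst_ptr + offs, x.to(tl.float16), mask=mask)

def f8_to_f16(x: torch.Tensor, dtype) -> torch.Tensor:
    # x holds float8 bit patterns in an int8 tensor; `dtype` is the triton
    # float8 flavor to decode them as (e.g. tl.float8e4, tl.float8e5)
    out = torch.empty(x.shape, dtype=torch.float16, device=x.device)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    _cast_kernel[grid](triton.reinterpret(x, dtype), out, n, BLOCK=1024)
    return out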
@@ -132,7 +140,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT,
         if t:
             return init_input(m, n, False, dtype, is_float8).t()
         if is_float8:
-            return torch.randint(20, 60, (n, m), device="cuda", dtype=torch.int8)
+            return torch.randint(20, 50, (n, m), device="cuda", dtype=torch.int8)
         dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype]
         return .1 * torch.randn((n, m), device="cuda", dtype=dtype)
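Why narrow the randint range from [20, 60) to [20, 50)? The int8 values are reinterpreted as float8 bit patterns rather than converted, so the upper bound controls the largest magnitude the test feeds the matmul. A small illustrative decoder for the float8e5 (e5m2, bias 15) flavor, written for this note rather than taken from the test suite:

def decode_e5m2(bits: int) -> float:
    # 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 2) & 0x1F
    man = bits & 0x3
    if exp == 0:  # subnormal
        return sign * (man / 4) * 2.0 ** -14
    return sign * (1 + man / 4) * 2.0 ** (exp - 15)

print(decode_e5m2(20))  # 2**-10 ~= 0.00098, the smallest pattern drawn
print(decode_e5m2(49))  # 0.15625, the largest under the new bound
print(decode_e5m2(59))  # 0.875, the largest under the old bound

Under e5m2 semantics the change caps the largest decoded operand at roughly 0.16 instead of 0.88, presumably keeping the accumulated error of the mixed-precision comparison within test tolerance.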