[BACKEND] Lift restriction for float8e4b15 to only support row-col layout (#2212)

This commit is contained in:
Philippe Tillet
2023-08-30 14:06:31 -07:00
committed by GitHub
parent 3175ee4ce7
commit ec51552fff
2 changed files with 1 additions and 19 deletions

View File

@@ -85,6 +85,7 @@ def f8_to_f16(x, dtype):
("float8e4nv", "float8e4nv"),
("float8e5", "float8e4nv"),
("float8e5", "float8e5"),
("float8e4b15", "float8e4b15"),
("float8e4nv", "float16"),
("float16", "float8e5"),
("float16", "float32"),
@@ -105,17 +106,6 @@ def f8_to_f16(x, dtype):
("bfloat16", "float32"),
("float32", "bfloat16")] for AT in [False, True] for BT in [False, True]
],
*[
# float8e4b15 only supports row-col layout
[
(128, 128, 32, 1, 4, 2, None, None, None, False, True, ADTYPE, BDTYPE, True),
] for ADTYPE, BDTYPE in [("float8e4b15", "float8e5"),
("float8e4b15", "float16"),
("float16", "float8e4b15"),
("float8e5", "float8e5"),
("float8e4nv", "float8e4nv"),
("int8", "int8")]
]
),
)
def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32):