[BACKEND] support of Fp8E4M3Nv to Bf16 conversion (#2415)

This commit is contained in:
Hongtao Yu
2023-09-29 17:29:41 -07:00
committed by GitHub
parent e284112818
commit e0edb70f78
2 changed files with 26 additions and 2 deletions

View File

@@ -1299,12 +1299,12 @@ def test_atomic_cas(sem, num_ctas, device):
] + (([
(dtype_x, dtype_z, False, size)
for dtype_x in torch_float8_dtypes
for dtype_z in ["float16", "float32"]
for dtype_z in ["float16", "float32", "bfloat16"]
for size in [1024, 32]
] + [
(dtype_x, dtype_z, False, size)
for dtype_z in torch_float8_dtypes
for dtype_x in ["float16", "float32"]
for dtype_x in ["float16", "float32", "bfloat16"]
for size in [1024, 32]
]) if torch.__version__ >= "2.1" else []))
@pytest.mark.parametrize("num_ctas", num_ctas_list)