mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] support of Fp8E4M3Nv to Bf16 conversion (#2415)
This commit is contained in:
@@ -1299,12 +1299,12 @@ def test_atomic_cas(sem, num_ctas, device):
|
||||
] + (([
|
||||
(dtype_x, dtype_z, False, size)
|
||||
for dtype_x in torch_float8_dtypes
|
||||
for dtype_z in ["float16", "float32"]
|
||||
for dtype_z in ["float16", "float32", "bfloat16"]
|
||||
for size in [1024, 32]
|
||||
] + [
|
||||
(dtype_x, dtype_z, False, size)
|
||||
for dtype_z in torch_float8_dtypes
|
||||
for dtype_x in ["float16", "float32"]
|
||||
for dtype_x in ["float16", "float32", "bfloat16"]
|
||||
for size in [1024, 32]
|
||||
]) if torch.__version__ >= "2.1" else []))
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
|
||||
Reference in New Issue
Block a user