Temporarily disable F8 tests on ROCm

This commit is contained in:
Jason Furmanek
2023-08-28 17:39:24 +00:00
parent 7954cc8e16
commit 320b1029da

View File

@@ -928,6 +928,8 @@ def test_f8_xf16_roundtrip(in_dtype, out_dtype):
output = input
tl.store(output_ptr + offsets, output, mask=mask)
if torch.version.hip is not None:
pytest.skip("f8 is not supported on ROCm")
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
@@ -959,6 +961,8 @@ def test_f16_to_f8_rounding(in_dtype):
output = input
tl.store(output_ptr + offsets, output, mask=mask)
if torch.version.hip is not None:
pytest.skip("f8 is not supported on ROCm")
# torch.view with a dtype isn't supported in triton's torch yet so use numpy's view
f16_input_np = (
np.array(