diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 33d9c3f153..2aa238da1e 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -164,7 +164,7 @@ class TestSafetensors(unittest.TestCase): def test_save_all_dtypes(self): for dtype in dtypes.fields().values(): if dtype in [dtypes.bfloat16]: continue # not supported in numpy - if dtype in [dtypes.double] and Device.DEFAULT == "METAL": continue # not supported on METAL + if dtype in [dtypes.double, *dtypes.fp8s] and Device.DEFAULT == "METAL": continue # not supported on METAL path = temp(f"ones.{dtype}.safetensors") ones = Tensor(np.random.rand(10,10), dtype=dtype) safe_save(get_state_dict(ones), path)