hotfix: fix test_save_all_dtypes on METAL

This commit is contained in:
George Hotz
2025-04-18 08:42:31 +01:00
parent 16dfe0a902
commit 8919370c76

View File

@@ -164,7 +164,7 @@ class TestSafetensors(unittest.TestCase):
def test_save_all_dtypes(self):
for dtype in dtypes.fields().values():
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
if dtype in [dtypes.double] and Device.DEFAULT == "METAL": continue # not supported on METAL
if dtype in [dtypes.double, *dtypes.fp8s] and Device.DEFAULT == "METAL": continue # not supported on METAL
path = temp(f"ones.{dtype}.safetensors")
ones = Tensor(np.random.rand(10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)