Revert "threefry_2x32 (#2601)" (#3784)

This reverts commit db3de54bc4.
This commit is contained in:
George Hotz
2024-03-17 10:27:20 -07:00
committed by GitHub
parent db3de54bc4
commit 311cf2b7d3
13 changed files with 56 additions and 92 deletions

View File

@@ -120,7 +120,7 @@ class TestSafetensors(unittest.TestCase):
for dtype in dtypes.fields().values():
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
path = temp(f"ones.{dtype}.safetensors")
ones = Tensor(np.random.rand(10,10), dtype=dtype)
ones = Tensor.rand((10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)
np.testing.assert_equal(ones.numpy(), list(safe_load(path).values())[0].numpy())