Extend available types for safe_save (#2720)

* Extend available types to save with

* Linter fix
This commit is contained in:
Guy Leroy
2023-12-11 22:50:35 +00:00
committed by GitHub
parent b5fd160b39
commit ee9e1d3662
2 changed files with 9 additions and 1 deletions

View File

@@ -108,6 +108,14 @@ class TestSafetensors(unittest.TestCase):
import json
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
def test_save_all_dtypes(self):
for dtype in dtypes.fields().values():
if dtype in [dtypes.bfloat16, dtypes._arg_int32]: continue # not supported in numpy
path = temp("ones.safetensors")
ones = Tensor.rand((10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)
assert ones == list(safe_load(path).values())[0]
def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
if tinygrad_fxn is None: tinygrad_fxn = np_fxn
pathlib.Path(temp(fn)).unlink(missing_ok=True)