mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Extend available types for safe_save (#2720)
* Extend available types to save with * Linter fix
This commit is contained in:
@@ -108,6 +108,14 @@ class TestSafetensors(unittest.TestCase):
|
||||
import json
|
||||
assert json.loads(dat[8:8+sz])['__metadata__']['hello'] == 'world'
|
||||
|
||||
def test_save_all_dtypes(self):
|
||||
for dtype in dtypes.fields().values():
|
||||
if dtype in [dtypes.bfloat16, dtypes._arg_int32]: continue # not supported in numpy
|
||||
path = temp("ones.safetensors")
|
||||
ones = Tensor.rand((10,10), dtype=dtype)
|
||||
safe_save(get_state_dict(ones), path)
|
||||
assert ones == list(safe_load(path).values())[0]
|
||||
|
||||
def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
|
||||
if tinygrad_fxn is None: tinygrad_fxn = np_fxn
|
||||
pathlib.Path(temp(fn)).unlink(missing_ok=True)
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
|
||||
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64}
|
||||
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64, "F64": dtypes.double, "B": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong}
|
||||
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
|
||||
|
||||
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
|
||||
|
||||
Reference in New Issue
Block a user