mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
support safe load bf16 (#3310)
* support safe load bf16 * fix lint error E501 * add test for loading safetensors * key should be BOOL * fix lint
This commit is contained in:
@@ -127,6 +127,30 @@ class TestSafetensors(unittest.TestCase):
|
||||
safe_save(get_state_dict(ones), path)
|
||||
assert ones == list(safe_load(path).values())[0]
|
||||
|
||||
def test_load_supported_types(self):
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
torch.manual_seed(1337)
|
||||
tensors = {
|
||||
"weight_F16": torch.randn((2, 2), dtype=torch.float16),
|
||||
"weight_F32": torch.randn((2, 2), dtype=torch.float32),
|
||||
"weight_U8": torch.tensor([1, 2, 3], dtype=torch.uint8),
|
||||
"weight_I8": torch.tensor([-1, 2, 3], dtype=torch.int8),
|
||||
"weight_I32": torch.tensor([-1, 2, 3], dtype=torch.int32),
|
||||
"weight_I64": torch.tensor([-1, 2, 3], dtype=torch.int64),
|
||||
"weight_F64": torch.randn((2, 2), dtype=torch.double),
|
||||
"weight_BOOL": torch.tensor([True, False], dtype=torch.bool),
|
||||
"weight_I16": torch.tensor([127, 64], dtype=torch.short),
|
||||
# pytorch does not support U16, UI, and UL dtypes.
|
||||
"weight_BF16": torch.randn((2, 2), dtype=torch.bfloat16),
|
||||
}
|
||||
save_file(tensors, temp("model.safetensors"))
|
||||
|
||||
try:
|
||||
safe_load(temp("model.safetensors"))
|
||||
except Exception as e:
|
||||
self.fail(f"got error while loading safetensors: {e}")
|
||||
|
||||
def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
|
||||
if tinygrad_fxn is None: tinygrad_fxn = np_fxn
|
||||
pathlib.Path(temp(fn)).unlink(missing_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user