support safe load bf16 (#3310)

* support safe load bf16

* fix lint error E501

* add test for loading safetensors

* key should be BOOL

* fix lint
This commit is contained in:
Yoshinori Sano
2024-02-05 00:08:39 +09:00
committed by GitHub
parent ca7973f61c
commit edb74897b2
2 changed files with 26 additions and 1 deletions

View File

@@ -127,6 +127,30 @@ class TestSafetensors(unittest.TestCase):
safe_save(get_state_dict(ones), path)
assert ones == list(safe_load(path).values())[0]
def test_load_supported_types(self):
  """safe_load must accept a file containing every dtype tinygrad maps from safetensors.

  Writes one tensor per supported dtype with the reference safetensors
  implementation, then checks tinygrad's safe_load reads the file back
  without raising.
  """
  import torch
  from safetensors.torch import save_file
  # deterministic tensor contents so the written file is reproducible
  torch.manual_seed(1337)
  randn, tensor = torch.randn, torch.tensor
  tensors = {
      "weight_F16": randn((2, 2), dtype=torch.float16),
      "weight_F32": randn((2, 2), dtype=torch.float32),
      "weight_U8": tensor([1, 2, 3], dtype=torch.uint8),
      "weight_I8": tensor([-1, 2, 3], dtype=torch.int8),
      "weight_I32": tensor([-1, 2, 3], dtype=torch.int32),
      "weight_I64": tensor([-1, 2, 3], dtype=torch.int64),
      "weight_F64": randn((2, 2), dtype=torch.double),
      "weight_BOOL": tensor([True, False], dtype=torch.bool),
      "weight_I16": tensor([127, 64], dtype=torch.short),
      # U16, UI, and UL have no pytorch equivalent, so they can't be written here.
      "weight_BF16": randn((2, 2), dtype=torch.bfloat16),
  }
  save_file(tensors, temp("model.safetensors"))
  try:
    safe_load(temp("model.safetensors"))
  except Exception as e:
    self.fail(f"got error while loading safetensors: {e}")
def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
if tinygrad_fxn is None: tinygrad_fxn = np_fxn
pathlib.Path(temp(fn)).unlink(missing_ok=True)

View File

@@ -9,7 +9,8 @@ from tinygrad.shape.view import strides_for_shape
from tinygrad.features.multi import MultiLazyBuffer
# Mapping between safetensors header dtype tags and tinygrad dtypes.
# NOTE: the safetensors format spells bool "BOOL" (not "B") and supports "BF16";
# pytorch cannot emit U16/UI/UL, but other writers can, so they stay mapped.
safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64,
               "F64": dtypes.double, "BOOL": dtypes.bool, "I16": dtypes.short, "U16": dtypes.ushort, "UI": dtypes.uint, "UL": dtypes.ulong,
               "BF16": dtypes.bfloat16}
# Reverse lookup used when writing safetensors headers.
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]: