@@ -29,7 +29,7 @@ def convert_f32_to_f16(input_file, output_file):
     rest_float32_values.tofile(f)

 def split_safetensor(fn):
-  _, json_len, metadata = safe_load_metadata(fn)
+  _, data_start, metadata = safe_load_metadata(fn)
   text_model_offset = 3772703308
   chunk_size = 536870912

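For context on the renamed value: a .safetensors file starts with an 8-byte little-endian unsigned integer giving the length of the JSON header, followed by the JSON metadata and then the raw tensor data. The data_start returned by the patched safe_load_metadata is therefore just 8 + json_len, so the offset arithmetic in the hunks below stays equivalent. A minimal standalone sketch of that layout (plain Python, not tinygrad's API; read_safetensors_header is a hypothetical helper name):

import json, struct

def read_safetensors_header(path):
  # .safetensors layout: [8-byte little-endian uint64 json_len][json_len bytes of JSON][tensor data]
  with open(path, "rb") as f:
    json_len = struct.unpack("<Q", f.read(8))[0]
    metadata = json.loads(f.read(json_len))
  data_start = 8 + json_len  # the same quantity the patched safe_load_metadata returns
  return data_start, metadata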
@@ -51,12 +51,12 @@ def split_safetensor(fn):
     part_offset = offset - last_offset

     if (part_offset >= chunk_size):
-      part_end_offsets.append(8+json_len+offset)
+      part_end_offsets.append(data_start+offset)
       last_offset = offset

   text_model_start = int(text_model_offset/2)
   net_bytes = bytes(open(fn, 'rb').read())
-  part_end_offsets.append(text_model_start+8+json_len)
+  part_end_offsets.append(text_model_start+data_start)
   cur_pos = 0

   for i, end_pos in enumerate(part_end_offsets):
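The cut rule in this hunk is greedy: walk the tensor end-offsets in order and close a part whenever the current part has grown past chunk_size (512 MiB), recording the absolute file position as data_start + offset. A rough standalone illustration with made-up offsets (not the real model's numbers):

chunk_size = 536870912           # 512 MiB, as in the script
data_start = 8 + 1024            # assumed header size, just for the example
offsets = [100_000_000, 400_000_000, 700_000_000, 1_200_000_000]

part_end_offsets, last_offset = [], 0
for offset in offsets:
  if offset - last_offset >= chunk_size:
    # absolute position in the file = header size + offset into the tensor data section
    part_end_offsets.append(data_start + offset)
    last_offset = offset
print(part_end_offsets)          # [700001032]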
@@ -65,7 +65,7 @@ def split_safetensor(fn):
     cur_pos = end_pos

   with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
-    f.write(net_bytes[text_model_start+8+json_len:])
+    f.write(net_bytes[text_model_start+data_start:])

   return part_end_offsets

@@ -42,13 +42,12 @@ def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str,
   return wrapper

 @accept_filename
-def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Any]:
+def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Dict[str, Any]]:
   """
   Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
   """
-  json_len = t[0:8].bitcast(dtypes.int64).item()
-  assert isinstance(json_len, int)
-  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
+  data_start = int.from_bytes(t[0:8].data(), "little") + 8
+  return t, data_start, json.loads(t[8:data_start].data().tobytes())

 def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
   """
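The old header read bitcast the first 8 bytes to a signed int64 and asserted the result was an int; the new one goes through int.from_bytes. For any header length that fits in a signed 64-bit value, which covers every realistic .safetensors file, the two reads agree, as this small check illustrates (plain Python, independent of tinygrad):

import struct

header = struct.pack("<Q", 1234)                  # 8-byte little-endian length field
as_int64 = struct.unpack("<q", header)[0]         # what the old bitcast-to-int64 read produced
as_from_bytes = int.from_bytes(header, "little")  # what the new code computes
assert as_int64 == as_from_bytes == 1234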
@@ -58,14 +57,10 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
   state_dict = nn.state.safe_load("test.safetensor")
   ```
   """
-  t, json_len, metadata = safe_load_metadata(fn)
-  ret = {}
-  for k,v in metadata.items():
-    if k == "__metadata__": continue
-    dtype = safe_dtypes[v['dtype']]
-    sz = (v['data_offsets'][1]-v['data_offsets'][0])
-    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
-  return ret
+  t, data_start, metadata = safe_load_metadata(fn)
+  data = t[data_start:]
+  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
+           for k, v in metadata.items() if k != "__metadata__" }

 def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
   """
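The rewritten safe_load is the old loop collapsed into one comprehension over the metadata: each tensor is the slice of the data section between data_offsets[0] and data_offsets[1], bitcast to its dtype and reshaped. A rough standalone sketch of the same idea using numpy instead of tinygrad Tensors (the dtype table here is an assumption and only covers a few common cases):

import json
import numpy as np

_NP_DTYPES = {"F32": np.float32, "F16": np.float16, "I64": np.int64, "I32": np.int32, "U8": np.uint8}

def safe_load_numpy(path):
  raw = np.fromfile(path, dtype=np.uint8)
  data_start = int.from_bytes(raw[:8].tobytes(), "little") + 8   # 8-byte length prefix + JSON header
  metadata = json.loads(raw[8:data_start].tobytes())
  data = raw[data_start:]
  return {k: data[v["data_offsets"][0]:v["data_offsets"][1]].view(_NP_DTYPES[v["dtype"]]).reshape(v["shape"])
          for k, v in metadata.items() if k != "__metadata__"}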