Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00

@@ -29,7 +29,7 @@ def convert_f32_to_f16(input_file, output_file):
     rest_float32_values.tofile(f)
 
 def split_safetensor(fn):
-  _, json_len, metadata = safe_load_metadata(fn)
+  _, data_start, metadata = safe_load_metadata(fn)
   text_model_offset = 3772703308
   chunk_size = 536870912
 
@@ -51,12 +51,12 @@ def split_safetensor(fn):
     part_offset = offset - last_offset
 
     if (part_offset >= chunk_size):
-      part_end_offsets.append(8+json_len+offset)
+      part_end_offsets.append(data_start+offset)
       last_offset = offset
 
   text_model_start = int(text_model_offset/2)
   net_bytes = bytes(open(fn, 'rb').read())
-  part_end_offsets.append(text_model_start+8+json_len)
+  part_end_offsets.append(text_model_start+data_start)
   cur_pos = 0
 
   for i, end_pos in enumerate(part_end_offsets):
@@ -65,7 +65,7 @@ def split_safetensor(fn):
       cur_pos = end_pos
 
   with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
-    f.write(net_bytes[text_model_start+8+json_len:])
+    f.write(net_bytes[text_model_start+data_start:])
 
   return part_end_offsets
 
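For readers following the split logic above, a toy sketch (not part of the diff, values made up) of the offset arithmetic: the new `data_start` name marks the absolute byte where tensor data begins, the same position the old `8+json_len` expression computed, so the chunk boundaries and the text-model cut point are unchanged.

```python
# Toy numbers only, to illustrate the arithmetic used by split_safetensor.
json_len = 120_000                    # length of the JSON header (made up)
data_start = 8 + json_len             # same byte position as the old 8+json_len
offset = 600_000_000                  # a header-relative data_offsets[0] (made up)
absolute_pos = data_start + offset    # what part_end_offsets.append(data_start+offset) records
```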
@@ -42,13 +42,12 @@ def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str,
   return wrapper
 
 @accept_filename
-def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Any]:
+def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Dict[str, Any]]:
   """
   Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
   """
-  json_len = t[0:8].bitcast(dtypes.int64).item()
-  assert isinstance(json_len, int)
-  return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
+  data_start = int.from_bytes(t[0:8].data(), "little") + 8
+  return t, data_start, json.loads(t[8:data_start].data().tobytes())
 
 def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
   """
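As a cross-check on the new header parse in safe_load_metadata, here is an equivalent stand-alone sketch in plain Python (it operates on a bytes object rather than a tinygrad Tensor, and the function name is ours, not tinygrad's):

```python
import json

def parse_safetensors_header(raw: bytes):
  # First 8 bytes: little-endian uint64 giving the JSON header length.
  json_len = int.from_bytes(raw[0:8], "little")
  data_start = json_len + 8                    # tensor data begins right after the header
  metadata = json.loads(raw[8:data_start])     # maps tensor name -> {dtype, shape, data_offsets}
  return data_start, metadata
```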
@@ -58,14 +57,10 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
   state_dict = nn.state.safe_load("test.safetensor")
   ```
   """
-  t, json_len, metadata = safe_load_metadata(fn)
-  ret = {}
-  for k,v in metadata.items():
-    if k == "__metadata__": continue
-    dtype = safe_dtypes[v['dtype']]
-    sz = (v['data_offsets'][1]-v['data_offsets'][0])
-    ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
-  return ret
+  t, data_start, metadata = safe_load_metadata(fn)
+  data = t[data_start:]
+  return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
+           for k, v in metadata.items() if k != "__metadata__" }
 
 def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
   """
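To show what the rewritten safe_load comprehension produces, a small numpy-based sketch (not tinygrad code; the file path and dtype table are illustrative assumptions) that slices each tensor out of the data region using its data_offsets:

```python
import json
import numpy as np

raw = open("net.safetensors", "rb").read()            # placeholder path
data_start = int.from_bytes(raw[0:8], "little") + 8
metadata = json.loads(raw[8:data_start])

np_dtypes = {"F32": np.float32, "F16": np.float16, "I32": np.int32}  # partial map, illustration only
tensors = {}
for name, info in metadata.items():
  if name == "__metadata__": continue
  begin, end = info["data_offsets"]                   # offsets relative to data_start
  buf = raw[data_start+begin:data_start+end]
  tensors[name] = np.frombuffer(buf, dtype=np_dtypes[info["dtype"]]).reshape(info["shape"])
```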