refactor safe_load (#8035)

* refactor safe_load

* cleanup
This commit is contained in:
leopf
2024-12-06 05:08:21 +01:00
committed by GitHub
parent e7d5fe4a32
commit 65b6696f3b
2 changed files with 11 additions and 16 deletions

View File

@@ -29,7 +29,7 @@ def convert_f32_to_f16(input_file, output_file):
rest_float32_values.tofile(f)
def split_safetensor(fn):
_, json_len, metadata = safe_load_metadata(fn)
_, data_start, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
chunk_size = 536870912
@@ -51,12 +51,12 @@ def split_safetensor(fn):
part_offset = offset - last_offset
if (part_offset >= chunk_size):
part_end_offsets.append(8+json_len+offset)
part_end_offsets.append(data_start+offset)
last_offset = offset
text_model_start = int(text_model_offset/2)
net_bytes = bytes(open(fn, 'rb').read())
part_end_offsets.append(text_model_start+8+json_len)
part_end_offsets.append(text_model_start+data_start)
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
@@ -65,7 +65,7 @@ def split_safetensor(fn):
cur_pos = end_pos
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
f.write(net_bytes[text_model_start+8+json_len:])
f.write(net_bytes[text_model_start+data_start:])
return part_end_offsets

View File

@@ -42,13 +42,12 @@ def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str,
return wrapper
@accept_filename
def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Any]:
def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Dict[str, Any]]:
"""
Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
"""
json_len = t[0:8].bitcast(dtypes.int64).item()
assert isinstance(json_len, int)
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
data_start = int.from_bytes(t[0:8].data(), "little") + 8
return t, data_start, json.loads(t[8:data_start].data().tobytes())
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
"""
@@ -58,14 +57,10 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
state_dict = nn.state.safe_load("test.safetensor")
```
"""
t, json_len, metadata = safe_load_metadata(fn)
ret = {}
for k,v in metadata.items():
if k == "__metadata__": continue
dtype = safe_dtypes[v['dtype']]
sz = (v['data_offsets'][1]-v['data_offsets'][0])
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
return ret
t, data_start, metadata = safe_load_metadata(fn)
data = t[data_start:]
return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
for k, v in metadata.items() if k != "__metadata__" }
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
"""