refactor safe_load (#8035)

* refactor safe_load

* cleanup
This commit is contained in:
leopf
2024-12-06 05:08:21 +01:00
committed by GitHub
parent e7d5fe4a32
commit 65b6696f3b
2 changed files with 11 additions and 16 deletions

View File

@@ -29,7 +29,7 @@ def convert_f32_to_f16(input_file, output_file):
rest_float32_values.tofile(f)
def split_safetensor(fn):
_, json_len, metadata = safe_load_metadata(fn)
_, data_start, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
chunk_size = 536870912
@@ -51,12 +51,12 @@ def split_safetensor(fn):
part_offset = offset - last_offset
if (part_offset >= chunk_size):
part_end_offsets.append(8+json_len+offset)
part_end_offsets.append(data_start+offset)
last_offset = offset
text_model_start = int(text_model_offset/2)
net_bytes = bytes(open(fn, 'rb').read())
part_end_offsets.append(text_model_start+8+json_len)
part_end_offsets.append(text_model_start+data_start)
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
@@ -65,7 +65,7 @@ def split_safetensor(fn):
cur_pos = end_pos
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
f.write(net_bytes[text_model_start+8+json_len:])
f.write(net_bytes[text_model_start+data_start:])
return part_end_offsets