mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
@@ -29,7 +29,7 @@ def convert_f32_to_f16(input_file, output_file):
|
||||
rest_float32_values.tofile(f)
|
||||
|
||||
def split_safetensor(fn):
|
||||
_, json_len, metadata = safe_load_metadata(fn)
|
||||
_, data_start, metadata = safe_load_metadata(fn)
|
||||
text_model_offset = 3772703308
|
||||
chunk_size = 536870912
|
||||
|
||||
@@ -51,12 +51,12 @@ def split_safetensor(fn):
|
||||
part_offset = offset - last_offset
|
||||
|
||||
if (part_offset >= chunk_size):
|
||||
part_end_offsets.append(8+json_len+offset)
|
||||
part_end_offsets.append(data_start+offset)
|
||||
last_offset = offset
|
||||
|
||||
text_model_start = int(text_model_offset/2)
|
||||
net_bytes = bytes(open(fn, 'rb').read())
|
||||
part_end_offsets.append(text_model_start+8+json_len)
|
||||
part_end_offsets.append(text_model_start+data_start)
|
||||
cur_pos = 0
|
||||
|
||||
for i, end_pos in enumerate(part_end_offsets):
|
||||
@@ -65,7 +65,7 @@ def split_safetensor(fn):
|
||||
cur_pos = end_pos
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
|
||||
f.write(net_bytes[text_model_start+8+json_len:])
|
||||
f.write(net_bytes[text_model_start+data_start:])
|
||||
|
||||
return part_end_offsets
|
||||
|
||||
|
||||
Reference in New Issue
Block a user