torch_load: fix shared storage slicing (#14771)

* faster zip_extract + usage in torch load

* clean zip in torch load

* working zipextract in torchload

* tar_extract in tar path

* faster tar path

* tests passing, cleanup needed

* faster tar with 1MB buffer

* comments

* unify storage_source with all paths

* use bufferedreader in zip path

* fix ruff

* clean

* removed unnecessary string conversion

* fix for tensors that share storage

* less hacky

* shared storage test

* test comment

* linter

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Bautista Garcia
2026-02-16 03:30:13 -03:00
committed by GitHub
parent dff9cf35c2
commit 0f1ca8eb43
2 changed files with 50 additions and 53 deletions

View File

@@ -28,7 +28,7 @@ def compare_weights_both(url):
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
print(f"compared {len(tg_weights)} weights")
class TestTorchLoad(unittest.TestCase):
class TestTorchLoad(TempDirTestCase):
# pytorch pkl format
def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
# pytorch zip format
@@ -42,6 +42,13 @@ class TestTorchLoad(unittest.TestCase):
# pytorch tar format
def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
# shared storage (mixtral-8x7b-32kseqlen)
def test_shared_storage(self):
import torch
fn = self.tmp("shared_storage.pth")
torch.save({"a": (a := torch.randn(100)), "b": a[5:]}, fn)
compare_weights_both(fn)
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
#test_size = test_fn.stat().st_size
test_size = 1024*1024*1024*2

View File

@@ -3,7 +3,7 @@ from collections import OrderedDict
from typing import Any, Callable, BinaryIO, Iterable, cast
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T, strides_for_shape
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, GlobalCounters, tqdm, round_up, T, strides_for_shape
class TensorIO(io.RawIOBase, BinaryIO):
def __init__(self, t: Tensor):
@@ -165,22 +165,20 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
@accept_filename
def zip_extract(t: Tensor) -> dict[str, Tensor]:
files: dict[str, Tensor] = {}
file_offsets: dict[str, tuple[Tensor, int, int]] = {}
with zipfile.ZipFile(TensorIO(t), "r") as myzip:
for zi in myzip.filelist:
file_offset = zi.header_offset+30+t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to("CPU").sum()
file_offsets[zi.filename] = (file_offset, zi.compress_size, zi.compress_type)
# sadly, the extra length needs to be read from the local header of each file. this is a limitation of the zip file format
Tensor.realize(*[x[0] for x in file_offsets.values()])
for filename, (file_offset, compress_size, compress_type) in file_offsets.items():
# possible to remove this realize/item? it's slow
file_offset_int = int(file_offset.item())
files[filename] = t[file_offset_int:file_offset_int+compress_size]
match compress_type:
case zipfile.ZIP_STORED: pass
# TODO: we need a zlib UOp so this can be lazy
case zipfile.ZIP_DEFLATED: files[filename] = Tensor(zlib.decompress(files[filename].data(), -15))
case _: raise NotImplementedError(f"compression {compress_type} not supported")
# sadly, the extra length needs to be read from the local header of each file.
# this is a limitation of the zip file format
header_contents = [t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to('CPU') for zi in myzip.filelist]
Tensor.realize(*header_contents)
for zi, header_content in zip(myzip.filelist, header_contents):
# header_offset + sizeFileHeader + File name length + Extra field length
file_offset = zi.header_offset + 30 + sum(cast(list[int], header_content.tolist()))
files[zi.filename] = t[file_offset:file_offset+zi.compress_size]
match zi.compress_type:
case zipfile.ZIP_STORED: pass
# TODO: we need a zlib UOp so this can be lazy
case zipfile.ZIP_DEFLATED: files[zi.filename] = Tensor(zlib.decompress(files[zi.filename].data(), -15))
case _: raise NotImplementedError(f"compression {zi.compress_type} not supported")
return files
@accept_filename
@@ -201,7 +199,6 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
# torch support!
# TODO: this should use tar_extract and zip_extract
@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
"""
@@ -215,7 +212,7 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
state_dict = nn.state.torch_load("test.pth")
```
"""
offsets: dict[str|int, int] = {}
storage_source: dict[str|int, Tensor] = {}
lens: dict[str|int, int] = {}
def _rebuild_tensor(storage, storage_offset, size, stride):
@@ -224,9 +221,9 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
if storage[2] not in storage_source: return None
byte_start, byte_end = storage_offset*storage[1].itemsize, (storage_offset + prod(size))*storage[1].itemsize
ret = storage_source[storage[2]][byte_start:byte_end].bitcast(storage[1])
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
@@ -262,43 +259,36 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
fobj = io.BufferedReader(TensorIO(t))
def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
myzip = zipfile.ZipFile(fobj, 'r')
base_name = None
header_offsets = {}
for zi in myzip.filelist:
if base_name is None: base_name = zi.filename.split('/', 1)[0]
if zi.filename.startswith(f'{base_name}/data/'): header_offsets[zi.filename.split("/")[-1]] = zi.header_offset
# sadly there's no way to get the start of the file in the zip without reading the header
# at least here we read them in parallel
header_contents = [t[v+26:v+30].bitcast(dtypes.uint16).to('CPU') for v in header_offsets.values()]
Tensor.realize(*header_contents)
for (n,o),c in zip(header_offsets.items(), header_contents):
# header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format)
offsets[n] = o+30+sum(cast(list[int], c.tolist()))
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
files = zip_extract(t)
base_name = next(iter(files)).split('/', 1)[0]
# keyed by persistent_id in pickle file
storage_source = {fn.split("/")[-1]: data for fn, data in files.items() if fn.startswith(f"{base_name}/data/") and not fn.endswith(".pkl")}
return TorchPickle(io.BufferedReader(TensorIO(files[f"{base_name}/data.pkl"]), 1_000_000)).load()
elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
with tarfile.open(fileobj=fobj, mode="r") as tar:
storages_offset = tar.getmember('storages').offset_data
f = unwrap(tar.extractfile('storages'))
for i in range(TorchPickle(f).load()): # num_storages
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
offsets[key] = storages_offset + f.tell()
f.seek(sz*storage_type.itemsize, 1)
f = unwrap(tar.extractfile('tensors'))
for _ in range(TorchPickle(f).load()): # num_tensors
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset = struct.unpack('<q', f.read(8))[0]
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
files = tar_extract(t)
f = io.BufferedReader(TensorIO(files["storages"]), 1_000_000)
# slice source tensor t
for _ in range(TorchPickle(f).load()):
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
byte_offset = f.tell()
storage_source[key] = files["storages"][byte_offset:byte_offset + sz * storage_type.itemsize]
f.seek(sz * storage_type.itemsize, 1)
f = io.BufferedReader(TensorIO(files["tensors"]), 1_000_000)
# get tensor metadata
for _ in range(TorchPickle(f).load()):
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset = struct.unpack('<q', f.read(8))[0]
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
pkl_data = TorchPickle(io.BufferedReader(TensorIO(files["pickle"]), 1_000_000)).load()
return {k: v.tensor if isinstance(v, Parameter) else v for k, v in pkl_data.items()}
else:
pkl = TorchPickle(fobj)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
# slice source tensor t
for i in ids:
offsets[i] = base_offset + 8
storage_source[i] = t[base_offset + 8:base_offset + 8 + lens[i]]
base_offset += 8 + lens[i]
fobj.seek(rwd)
return TorchPickle(fobj).load()