mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
torch_load: fix shared storage slicing (#14771)
* faster zip_extract + usage in torch load
* clean zip in torch load
* working zipextract in torchload
* tar_extract in tar path
* faster tar path
* tests passing, cleanup needed
* faster tar with 1MB buffer
* comments
* unify storage_source with all paths
* use bufferedreader in zip path
* fix ruff
* clean
* removed unnecessary string conversion
* fix for tensors that share storage
* less hacky
* shared storage test
* test comment
* linter

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,7 @@ def compare_weights_both(url):
|
||||
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
|
||||
print(f"compared {len(tg_weights)} weights")
|
||||
|
||||
class TestTorchLoad(unittest.TestCase):
|
||||
class TestTorchLoad(TempDirTestCase):
|
||||
# pytorch pkl format
|
||||
def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
# pytorch zip format
|
||||
@@ -42,6 +42,13 @@ class TestTorchLoad(unittest.TestCase):
|
||||
# pytorch tar format
|
||||
def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
||||
|
||||
# shared storage (mixtral-8x7b-32kseqlen)
|
||||
def test_shared_storage(self):
|
||||
import torch
|
||||
fn = self.tmp("shared_storage.pth")
|
||||
torch.save({"a": (a := torch.randn(100)), "b": a[5:]}, fn)
|
||||
compare_weights_both(fn)
|
||||
|
||||
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
|
||||
#test_size = test_fn.stat().st_size
|
||||
test_size = 1024*1024*1024*2
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import OrderedDict
|
||||
from typing import Any, Callable, BinaryIO, Iterable, cast
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T, strides_for_shape
|
||||
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, GlobalCounters, tqdm, round_up, T, strides_for_shape
|
||||
|
||||
class TensorIO(io.RawIOBase, BinaryIO):
|
||||
def __init__(self, t: Tensor):
|
||||
@@ -165,22 +165,20 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
|
||||
@accept_filename
|
||||
def zip_extract(t: Tensor) -> dict[str, Tensor]:
|
||||
files: dict[str, Tensor] = {}
|
||||
file_offsets: dict[str, tuple[Tensor, int, int]] = {}
|
||||
with zipfile.ZipFile(TensorIO(t), "r") as myzip:
|
||||
for zi in myzip.filelist:
|
||||
file_offset = zi.header_offset+30+t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to("CPU").sum()
|
||||
file_offsets[zi.filename] = (file_offset, zi.compress_size, zi.compress_type)
|
||||
# sadly, the extra length needs to be read from the local header of each file. this is a limitation of the zip file format
|
||||
Tensor.realize(*[x[0] for x in file_offsets.values()])
|
||||
for filename, (file_offset, compress_size, compress_type) in file_offsets.items():
|
||||
# possible to remove this realize/item? it's slow
|
||||
file_offset_int = int(file_offset.item())
|
||||
files[filename] = t[file_offset_int:file_offset_int+compress_size]
|
||||
match compress_type:
|
||||
case zipfile.ZIP_STORED: pass
|
||||
# TODO: we need a zlib UOp so this can be lazy
|
||||
case zipfile.ZIP_DEFLATED: files[filename] = Tensor(zlib.decompress(files[filename].data(), -15))
|
||||
case _: raise NotImplementedError(f"compression {compress_type} not supported")
|
||||
# sadly, the extra length needs to be read from the local header of each file.
|
||||
# this is a limitation of the zip file format
|
||||
header_contents = [t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to('CPU') for zi in myzip.filelist]
|
||||
Tensor.realize(*header_contents)
|
||||
for zi, header_content in zip(myzip.filelist, header_contents):
|
||||
# header_offset + sizeFileHeader + File name length + Extra field length
|
||||
file_offset = zi.header_offset + 30 + sum(cast(list[int], header_content.tolist()))
|
||||
files[zi.filename] = t[file_offset:file_offset+zi.compress_size]
|
||||
match zi.compress_type:
|
||||
case zipfile.ZIP_STORED: pass
|
||||
# TODO: we need a zlib UOp so this can be lazy
|
||||
case zipfile.ZIP_DEFLATED: files[zi.filename] = Tensor(zlib.decompress(files[zi.filename].data(), -15))
|
||||
case _: raise NotImplementedError(f"compression {zi.compress_type} not supported")
|
||||
return files
|
||||
|
||||
@accept_filename
|
||||
@@ -201,7 +199,6 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
|
||||
|
||||
# torch support!
|
||||
|
||||
# TODO: this should use tar_extract and zip_extract
|
||||
@accept_filename
|
||||
def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
"""
|
||||
@@ -215,7 +212,7 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
state_dict = nn.state.torch_load("test.pth")
|
||||
```
|
||||
"""
|
||||
offsets: dict[str|int, int] = {}
|
||||
storage_source: dict[str|int, Tensor] = {}
|
||||
lens: dict[str|int, int] = {}
|
||||
|
||||
def _rebuild_tensor(storage, storage_offset, size, stride):
|
||||
@@ -224,9 +221,9 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
|
||||
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
||||
lens[storage[2]] = storage[4] * storage[1].itemsize
|
||||
if storage[2] not in offsets: return None
|
||||
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
|
||||
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
|
||||
if storage[2] not in storage_source: return None
|
||||
byte_start, byte_end = storage_offset*storage[1].itemsize, (storage_offset + prod(size))*storage[1].itemsize
|
||||
ret = storage_source[storage[2]][byte_start:byte_end].bitcast(storage[1])
|
||||
|
||||
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
|
||||
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
|
||||
@@ -262,43 +259,36 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
|
||||
fobj = io.BufferedReader(TensorIO(t))
|
||||
def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
|
||||
|
||||
if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
|
||||
myzip = zipfile.ZipFile(fobj, 'r')
|
||||
base_name = None
|
||||
header_offsets = {}
|
||||
for zi in myzip.filelist:
|
||||
if base_name is None: base_name = zi.filename.split('/', 1)[0]
|
||||
if zi.filename.startswith(f'{base_name}/data/'): header_offsets[zi.filename.split("/")[-1]] = zi.header_offset
|
||||
# sadly there's no way to get the start of the file in the zip without reading the header
|
||||
# at least here we read them in parallel
|
||||
header_contents = [t[v+26:v+30].bitcast(dtypes.uint16).to('CPU') for v in header_offsets.values()]
|
||||
Tensor.realize(*header_contents)
|
||||
for (n,o),c in zip(header_offsets.items(), header_contents):
|
||||
# header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format)
|
||||
offsets[n] = o+30+sum(cast(list[int], c.tolist()))
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
return TorchPickle(myfile).load()
|
||||
files = zip_extract(t)
|
||||
base_name = next(iter(files)).split('/', 1)[0]
|
||||
# keyed by persistent_id in pickle file
|
||||
storage_source = {fn.split("/")[-1]: data for fn, data in files.items() if fn.startswith(f"{base_name}/data/") and not fn.endswith(".pkl")}
|
||||
return TorchPickle(io.BufferedReader(TensorIO(files[f"{base_name}/data.pkl"]), 1_000_000)).load()
|
||||
elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
|
||||
with tarfile.open(fileobj=fobj, mode="r") as tar:
|
||||
storages_offset = tar.getmember('storages').offset_data
|
||||
f = unwrap(tar.extractfile('storages'))
|
||||
for i in range(TorchPickle(f).load()): # num_storages
|
||||
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
|
||||
offsets[key] = storages_offset + f.tell()
|
||||
f.seek(sz*storage_type.itemsize, 1)
|
||||
f = unwrap(tar.extractfile('tensors'))
|
||||
for _ in range(TorchPickle(f).load()): # num_tensors
|
||||
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
|
||||
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
|
||||
storage_offset = struct.unpack('<q', f.read(8))[0]
|
||||
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
|
||||
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
|
||||
files = tar_extract(t)
|
||||
f = io.BufferedReader(TensorIO(files["storages"]), 1_000_000)
|
||||
# slice source tensor t
|
||||
for _ in range(TorchPickle(f).load()):
|
||||
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
|
||||
byte_offset = f.tell()
|
||||
storage_source[key] = files["storages"][byte_offset:byte_offset + sz * storage_type.itemsize]
|
||||
f.seek(sz * storage_type.itemsize, 1)
|
||||
f = io.BufferedReader(TensorIO(files["tensors"]), 1_000_000)
|
||||
# get tensor metadata
|
||||
for _ in range(TorchPickle(f).load()):
|
||||
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
|
||||
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
|
||||
storage_offset = struct.unpack('<q', f.read(8))[0]
|
||||
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
|
||||
pkl_data = TorchPickle(io.BufferedReader(TensorIO(files["pickle"]), 1_000_000)).load()
|
||||
return {k: v.tensor if isinstance(v, Parameter) else v for k, v in pkl_data.items()}
|
||||
else:
|
||||
pkl = TorchPickle(fobj)
|
||||
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
|
||||
# slice source tensor t
|
||||
for i in ids:
|
||||
offsets[i] = base_offset + 8
|
||||
storage_source[i] = t[base_offset + 8:base_offset + 8 + lens[i]]
|
||||
base_offset += 8 + lens[i]
|
||||
fobj.seek(rwd)
|
||||
return TorchPickle(fobj).load()
|
||||
|
||||
Reference in New Issue
Block a user