diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index e9262b9a31..ae98a19fc8 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -28,7 +28,7 @@ def compare_weights_both(url): np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}") print(f"compared {len(tg_weights)} weights") -class TestTorchLoad(unittest.TestCase): +class TestTorchLoad(TempDirTestCase): # pytorch pkl format def test_load_enet(self): compare_weights_both("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth") # pytorch zip format @@ -42,6 +42,13 @@ class TestTorchLoad(unittest.TestCase): # pytorch tar format def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth') + # shared storage (mixtral-8x7b-32kseqlen) + def test_shared_storage(self): + import torch + fn = self.tmp("shared_storage.pth") + torch.save({"a": (a := torch.randn(100)), "b": a[5:]}, fn) + compare_weights_both(fn) + test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth" #test_size = test_fn.stat().st_size test_size = 1024*1024*1024*2 diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 3d078674a7..b844872eee 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -3,7 +3,7 @@ from collections import OrderedDict from typing import Any, Callable, BinaryIO, Iterable, cast from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes -from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T, strides_for_shape +from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, GlobalCounters, tqdm, round_up, T, strides_for_shape class TensorIO(io.RawIOBase, BinaryIO): def __init__(self, t: Tensor): @@ -165,22 +165,20 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr @accept_filename def zip_extract(t: Tensor) -> dict[str, Tensor]: files: dict[str, Tensor] = {} - file_offsets: dict[str, tuple[Tensor, int, int]] = {} with zipfile.ZipFile(TensorIO(t), "r") as myzip: - for zi in myzip.filelist: - file_offset = zi.header_offset+30+t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to("CPU").sum() - file_offsets[zi.filename] = (file_offset, zi.compress_size, zi.compress_type) - # sadly, the extra length needs to be read from the local header of each file. this is a limitation of the zip file format - Tensor.realize(*[x[0] for x in file_offsets.values()]) - for filename, (file_offset, compress_size, compress_type) in file_offsets.items(): - # possible to remove this realize/item? it's slow - file_offset_int = int(file_offset.item()) - files[filename] = t[file_offset_int:file_offset_int+compress_size] - match compress_type: - case zipfile.ZIP_STORED: pass - # TODO: we need a zlib UOp so this can be lazy - case zipfile.ZIP_DEFLATED: files[filename] = Tensor(zlib.decompress(files[filename].data(), -15)) - case _: raise NotImplementedError(f"compression {compress_type} not supported") + # sadly, the extra length needs to be read from the local header of each file. + # this is a limitation of the zip file format + header_contents = [t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to('CPU') for zi in myzip.filelist] + Tensor.realize(*header_contents) + for zi, header_content in zip(myzip.filelist, header_contents): + # header_offset + sizeFileHeader + File name length + Extra field length + file_offset = zi.header_offset + 30 + sum(cast(list[int], header_content.tolist())) + files[zi.filename] = t[file_offset:file_offset+zi.compress_size] + match zi.compress_type: + case zipfile.ZIP_STORED: pass + # TODO: we need a zlib UOp so this can be lazy + case zipfile.ZIP_DEFLATED: files[zi.filename] = Tensor(zlib.decompress(files[zi.filename].data(), -15)) + case _: raise NotImplementedError(f"compression {zi.compress_type} not supported") return files @accept_filename @@ -201,7 +199,6 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]: # torch support! -# TODO: this should use tar_extract and zip_extract @accept_filename def torch_load(t:Tensor) -> dict[str, Tensor]: """ @@ -215,7 +212,7 @@ def torch_load(t:Tensor) -> dict[str, Tensor]: state_dict = nn.state.torch_load("test.pth") ``` """ - offsets: dict[str|int, int] = {} + storage_source: dict[str|int, Tensor] = {} lens: dict[str|int, int] = {} def _rebuild_tensor(storage, storage_offset, size, stride): @@ -224,9 +221,9 @@ def torch_load(t:Tensor) -> dict[str, Tensor]: def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None): #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata) lens[storage[2]] = storage[4] * storage[1].itemsize - if storage[2] not in offsets: return None - byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize - ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1]) + if storage[2] not in storage_source: return None + byte_start, byte_end = storage_offset*storage[1].itemsize, (storage_offset + prod(size))*storage[1].itemsize + ret = storage_source[storage[2]][byte_start:byte_end].bitcast(storage[1]) # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1] @@ -262,43 +259,36 @@ def torch_load(t:Tensor) -> dict[str, Tensor]: fobj = io.BufferedReader(TensorIO(t)) def passthrough_reset(v: bool): return fobj.seek(0, 0) or v - if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14 - myzip = zipfile.ZipFile(fobj, 'r') - base_name = None - header_offsets = {} - for zi in myzip.filelist: - if base_name is None: base_name = zi.filename.split('/', 1)[0] - if zi.filename.startswith(f'{base_name}/data/'): header_offsets[zi.filename.split("/")[-1]] = zi.header_offset - # sadly there's no way to get the start of the file in the zip without reading the header - # at least here we read them in parallel - header_contents = [t[v+26:v+30].bitcast(dtypes.uint16).to('CPU') for v in header_offsets.values()] - Tensor.realize(*header_contents) - for (n,o),c in zip(header_offsets.items(), header_contents): - # header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format) - offsets[n] = o+30+sum(cast(list[int], c.tolist())) - with myzip.open(f'{base_name}/data.pkl') as myfile: - return TorchPickle(myfile).load() + files = zip_extract(t) + base_name = next(iter(files)).split('/', 1)[0] + # keyed by persistent_id in pickle file + storage_source = {fn.split("/")[-1]: data for fn, data in files.items() if fn.startswith(f"{base_name}/data/") and not fn.endswith(".pkl")} + return TorchPickle(io.BufferedReader(TensorIO(files[f"{base_name}/data.pkl"]), 1_000_000)).load() elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11 - with tarfile.open(fileobj=fobj, mode="r") as tar: - storages_offset = tar.getmember('storages').offset_data - f = unwrap(tar.extractfile('storages')) - for i in range(TorchPickle(f).load()): # num_storages - (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('