From e35bd960e8512dbc139522502b89b455173df876 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 14 Feb 2026 13:24:01 +0800
Subject: [PATCH] Revert "use zip_extract and tar_extract in torch load (#14734)" (#14745)

This reverts commit 9d9ef816081cdd4a2a944bfe55a096efc19430ed.
---
 tinygrad/nn/state.py | 93 +++++++++++++++++++++++++-------------------
 1 file changed, 52 insertions(+), 41 deletions(-)

diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index f8c08c806a..3d078674a7 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -3,7 +3,7 @@ from collections import OrderedDict
 from typing import Any, Callable, BinaryIO, Iterable, cast
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, GlobalCounters, tqdm, round_up, T, strides_for_shape
+from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T, strides_for_shape
 
 class TensorIO(io.RawIOBase, BinaryIO):
   def __init__(self, t: Tensor):
@@ -165,20 +165,22 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
 @accept_filename
 def zip_extract(t: Tensor) -> dict[str, Tensor]:
   files: dict[str, Tensor] = {}
+  file_offsets: dict[str, tuple[Tensor, int, int]] = {}
   with zipfile.ZipFile(TensorIO(t), "r") as myzip:
-    # sadly, the extra length needs to be read from the local header of each file.
-    # this is a limitation of the zip file format
-    header_contents = [t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to('CPU') for zi in myzip.filelist]
-    Tensor.realize(*header_contents)
-    for zi, header_content in zip(myzip.filelist, header_contents):
-      # header_offset + sizeFileHeader + File name length + Extra field length
-      file_offset = zi.header_offset + 30 + sum(cast(list[int], header_content.tolist()))
-      files[zi.filename] = t[file_offset:file_offset+zi.compress_size]
-      match zi.compress_type:
-        case zipfile.ZIP_STORED: pass
-        # TODO: we need a zlib UOp so this can be lazy
-        case zipfile.ZIP_DEFLATED: files[zi.filename] = Tensor(zlib.decompress(files[zi.filename].data(), -15))
-        case _: raise NotImplementedError(f"compression {zi.compress_type} not supported")
+    for zi in myzip.filelist:
+      file_offset = zi.header_offset+30+t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to("CPU").sum()
+      file_offsets[zi.filename] = (file_offset, zi.compress_size, zi.compress_type)
+    # sadly, the extra length needs to be read from the local header of each file. this is a limitation of the zip file format
+    Tensor.realize(*[x[0] for x in file_offsets.values()])
+    for filename, (file_offset, compress_size, compress_type) in file_offsets.items():
+      # possible to remove this realize/item? it's slow
+      file_offset_int = int(file_offset.item())
+      files[filename] = t[file_offset_int:file_offset_int+compress_size]
+      match compress_type:
+        case zipfile.ZIP_STORED: pass
+        # TODO: we need a zlib UOp so this can be lazy
+        case zipfile.ZIP_DEFLATED: files[filename] = Tensor(zlib.decompress(files[filename].data(), -15))
+        case _: raise NotImplementedError(f"compression {compress_type} not supported")
   return files
 
 @accept_filename
@@ -199,6 +201,7 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
 
 # torch support!
 
+# TODO: this should use tar_extract and zip_extract
 @accept_filename
 def torch_load(t:Tensor) -> dict[str, Tensor]:
   """
@@ -212,7 +215,7 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
   state_dict = nn.state.torch_load("test.pth")
   ```
   """
-  storage_source: dict[str|int, Tensor] = {}
+  offsets: dict[str|int, int] = {}
   lens: dict[str|int, int] = {}
 
   def _rebuild_tensor(storage, storage_offset, size, stride):
@@ -221,8 +224,9 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
   def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
     #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
     lens[storage[2]] = storage[4] * storage[1].itemsize
-    if storage[2] not in storage_source: return None
-    ret = storage_source[storage[2]].bitcast(storage[1])
+    if storage[2] not in offsets: return None
+    byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
+    ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
 
     # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
     shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
@@ -258,36 +262,43 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
 
   fobj = io.BufferedReader(TensorIO(t))
   def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
+
   if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
-    files = zip_extract(t)
-    base_name = next(iter(files)).split('/', 1)[0]
-    # keyed by persistent_id in pickle file
-    storage_source = {fn.split("/")[-1]: data for fn, data in files.items() if fn.startswith(f"{base_name}/data/") and not fn.endswith(".pkl")}
-    return TorchPickle(io.BufferedReader(TensorIO(files[f"{base_name}/data.pkl"]), 1_000_000)).load()
+    myzip = zipfile.ZipFile(fobj, 'r')
+    base_name = None
+    header_offsets = {}
+    for zi in myzip.filelist:
+      if base_name is None: base_name = zi.filename.split('/', 1)[0]
+      if zi.filename.startswith(f'{base_name}/data/'): header_offsets[zi.filename.split("/")[-1]] = zi.header_offset
+    # sadly there's no way to get the start of the file in the zip without reading the header
+    # at least here we read them in parallel
+    header_contents = [t[v+26:v+30].bitcast(dtypes.uint16).to('CPU') for v in header_offsets.values()]
+    Tensor.realize(*header_contents)
+    for (n,o),c in zip(header_offsets.items(), header_contents):
+      # header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format)
+      offsets[n] = o+30+sum(cast(list[int], c.tolist()))
+    with myzip.open(f'{base_name}/data.pkl') as myfile:
+      return TorchPickle(myfile).load()
   elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
-    files = tar_extract(t)
-    f = io.BufferedReader(TensorIO(files["storages"]), 1_000_000)
-    # slice source tensor t
-    for _ in range(TorchPickle(f).load()):
-      (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('