Revert "use zip_extract and tar_extract in torch load (#14734)" (#14745)

This reverts commit 9d9ef81608.
This commit is contained in:
George Hotz
2026-02-14 13:24:01 +08:00
committed by GitHub
parent eaa9506a00
commit e35bd960e8

View File

@@ -3,7 +3,7 @@ from collections import OrderedDict
from typing import Any, Callable, BinaryIO, Iterable, cast
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, GlobalCounters, tqdm, round_up, T, strides_for_shape
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T, strides_for_shape
class TensorIO(io.RawIOBase, BinaryIO):
def __init__(self, t: Tensor):
@@ -165,20 +165,22 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
@accept_filename
def zip_extract(t: Tensor) -> dict[str, Tensor]:
# Extract the members of a zip archive held in tensor `t`.
# Returns a dict mapping each member's filename to a Tensor with that member's contents:
# a slice of `t` for stored members, a new Tensor of inflated bytes for deflated members.
# Raises NotImplementedError for any compression type other than ZIP_STORED / ZIP_DEFLATED.
# (assumes `t` is a flat byte tensor, since it is sliced with byte offsets -- TODO confirm)
#
# NOTE(review): this span is a rendered diff hunk with the +/- markers and the indentation
# stripped, so two implementations of the body appear back to back. Judging by the commit
# message (a revert) and the hunk header (-165,20 +165,22), the `header_contents` variant
# looks like the removed side and the `file_offsets` variant the restored side -- confirm
# against the repository before relying on this.
files: dict[str, Tensor] = {}
# filename -> (data offset, still a Tensor at this point; compressed size; compression type)
file_offsets: dict[str, tuple[Tensor, int, int]] = {}
with zipfile.ZipFile(TensorIO(t), "r") as myzip:
# --- variant A: per-file header reads collected in a list, offsets computed as Python ints ---
# sadly, the extra length needs to be read from the local header of each file.
# this is a limitation of the zip file format
# bytes 26..30 of each local header are viewed as two uint16 values
# (file name length and extra field length, per the comment further down)
header_contents = [t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to('CPU') for zi in myzip.filelist]
# realize every header read in a single call instead of one at a time
Tensor.realize(*header_contents)
for zi, header_content in zip(myzip.filelist, header_contents):
# header_offset + sizeFileHeader + File name length + Extra field length
file_offset = zi.header_offset + 30 + sum(cast(list[int], header_content.tolist()))
# take the compressed payload as a slice of the source tensor
files[zi.filename] = t[file_offset:file_offset+zi.compress_size]
match zi.compress_type:
case zipfile.ZIP_STORED: pass
# TODO: we need a zlib UOp so this can be lazy
# wbits=-15 tells zlib the input is a raw deflate stream (no zlib header/trailer)
case zipfile.ZIP_DEFLATED: files[zi.filename] = Tensor(zlib.decompress(files[zi.filename].data(), -15))
case _: raise NotImplementedError(f"compression {zi.compress_type} not supported")
# --- variant B: offsets kept as Tensors first, then converted with .item() per file ---
for zi in myzip.filelist:
# same arithmetic as above (header_offset + 30 + the two uint16 lengths), but the sum stays a Tensor
file_offset = zi.header_offset+30+t[zi.header_offset+26:zi.header_offset+30].bitcast(dtypes.uint16).to("CPU").sum()
file_offsets[zi.filename] = (file_offset, zi.compress_size, zi.compress_type)
# sadly, the extra length needs to be read from the local header of each file. this is a limitation of the zip file format
# realize all offset tensors in a single call
Tensor.realize(*[x[0] for x in file_offsets.values()])
for filename, (file_offset, compress_size, compress_type) in file_offsets.items():
# possible to remove this realize/item? it's slow
# the scalar offset is pulled out as a Python int so it can be used as slice bounds
file_offset_int = int(file_offset.item())
files[filename] = t[file_offset_int:file_offset_int+compress_size]
match compress_type:
case zipfile.ZIP_STORED: pass
# TODO: we need a zlib UOp so this can be lazy
# wbits=-15 tells zlib the input is a raw deflate stream (no zlib header/trailer)
case zipfile.ZIP_DEFLATED: files[filename] = Tensor(zlib.decompress(files[filename].data(), -15))
case _: raise NotImplementedError(f"compression {compress_type} not supported")
return files
@accept_filename
@@ -199,6 +201,7 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
# torch support!
# TODO: this should use tar_extract and zip_extract
@accept_filename
def torch_load(t:Tensor) -> dict[str, Tensor]:
"""
@@ -212,7 +215,7 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
state_dict = nn.state.torch_load("test.pth")
```
"""
storage_source: dict[str|int, Tensor] = {}
offsets: dict[str|int, int] = {}
lens: dict[str|int, int] = {}
def _rebuild_tensor(storage, storage_offset, size, stride):
@@ -221,8 +224,9 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
lens[storage[2]] = storage[4] * storage[1].itemsize
if storage[2] not in storage_source: return None
ret = storage_source[storage[2]].bitcast(storage[1])
if storage[2] not in offsets: return None
byte_offset = offsets[storage[2]]+storage_offset*storage[1].itemsize
ret = t[byte_offset:byte_offset+prod(size)*storage[1].itemsize].bitcast(storage[1])
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s,st in zip(size, stride) if s != 1]
@@ -258,36 +262,43 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
fobj = io.BufferedReader(TensorIO(t))
def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
files = zip_extract(t)
base_name = next(iter(files)).split('/', 1)[0]
# keyed by persistent_id in pickle file
storage_source = {fn.split("/")[-1]: data for fn, data in files.items() if fn.startswith(f"{base_name}/data/") and not fn.endswith(".pkl")}
return TorchPickle(io.BufferedReader(TensorIO(files[f"{base_name}/data.pkl"]), 1_000_000)).load()
myzip = zipfile.ZipFile(fobj, 'r')
base_name = None
header_offsets = {}
for zi in myzip.filelist:
if base_name is None: base_name = zi.filename.split('/', 1)[0]
if zi.filename.startswith(f'{base_name}/data/'): header_offsets[zi.filename.split("/")[-1]] = zi.header_offset
# sadly there's no way to get the start of the file in the zip without reading the header
# at least here we read them in parallel
header_contents = [t[v+26:v+30].bitcast(dtypes.uint16).to('CPU') for v in header_offsets.values()]
Tensor.realize(*header_contents)
for (n,o),c in zip(header_offsets.items(), header_contents):
# header_offset + sizeFileHeader + File name length + Extra field length : https://en.wikipedia.org/wiki/ZIP_(file_format)
offsets[n] = o+30+sum(cast(list[int], c.tolist()))
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
files = tar_extract(t)
f = io.BufferedReader(TensorIO(files["storages"]), 1_000_000)
# slice source tensor t
for _ in range(TorchPickle(f).load()):
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
byte_offset = f.tell()
storage_source[key] = files["storages"][byte_offset:byte_offset + sz * storage_type.itemsize]
f.seek(sz * storage_type.itemsize, 1)
f = io.BufferedReader(TensorIO(files["tensors"]), 1_000_000)
# get tensor metadata
for _ in range(TorchPickle(f).load()):
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset = struct.unpack('<q', f.read(8))[0]
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
pkl_data = TorchPickle(io.BufferedReader(TensorIO(files["pickle"]), 1_000_000)).load()
return {k: v.tensor if isinstance(v, Parameter) else v for k, v in pkl_data.items()}
with tarfile.open(fileobj=fobj, mode="r") as tar:
storages_offset = tar.getmember('storages').offset_data
f = unwrap(tar.extractfile('storages'))
for i in range(TorchPickle(f).load()): # num_storages
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
offsets[key] = storages_offset + f.tell()
f.seek(sz*storage_type.itemsize, 1)
f = unwrap(tar.extractfile('tensors'))
for _ in range(TorchPickle(f).load()): # num_tensors
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
size, stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset = struct.unpack('<q', f.read(8))[0]
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
else:
pkl = TorchPickle(fobj)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
# slice source tensor t
for i in ids:
storage_source[i] = t[base_offset + 8:base_offset + 8 + lens[i]]
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
fobj.seek(rwd)
return TorchPickle(fobj).load()