minor cleanups in state.py [pr] (#8438)

This commit is contained in:
chenyu
2024-12-28 20:13:25 -05:00
committed by GitHub
parent da2fa0b37f
commit 0fd6d7482b

View File

@@ -1,4 +1,5 @@
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from collections import OrderedDict
from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
@@ -8,12 +9,12 @@ from tinygrad.multi import MultiLazyBuffer
class TensorIO(io.RawIOBase, BinaryIO):
def __init__(self, t: Tensor):
if len(t.shape) != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
self._position, self._tensor = 0, t
def readable(self) -> bool: return True
def read(self, size: int = -1) -> bytes:
if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens, if readinto returns None (never)
if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
return buf
def readinto(self, buffer: Any) -> int:
data = self._tensor[self._position:self._position+len(buffer)].data()
@@ -76,7 +77,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
offset += v.nbytes()
j = json.dumps(headers, separators=(',', ':'))
j += "\x20"*((8-len(j)%8)%8)
j += "\x20"*(round_up(len(j),8)-len(j))
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:8].bitcast(dtypes.int64).assign([len(j)])
@@ -85,7 +86,6 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
# state dict
from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
"""
Returns a state_dict of the object, with optional prefix.
@@ -110,6 +110,7 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
elif isinstance(obj, dict):
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
def get_parameters(obj) -> list[Tensor]:
"""
```python exec="true" source="above" session="tensor" result="python"