mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
minor cleanups in state.py [pr] (#8438)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
|
||||
+ from collections import OrderedDict
|
||||
from typing import Union, Optional, Any, Callable, BinaryIO, Iterable
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -8,12 +9,12 @@ from tinygrad.multi import MultiLazyBuffer
|
||||
|
||||
class TensorIO(io.RawIOBase, BinaryIO):
|
||||
def __init__(self, t: Tensor):
|
||||
- if len(t.shape) != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
|
||||
+ if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!")
|
||||
self._position, self._tensor = 0, t
|
||||
|
||||
def readable(self) -> bool: return True
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
- if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens, if readinto returns None (never)
|
||||
+ if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never)
|
||||
return buf
|
||||
def readinto(self, buffer: Any) -> int:
|
||||
data = self._tensor[self._position:self._position+len(buffer)].data()
|
||||
@@ -76,7 +77,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
|
||||
headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]}
|
||||
offset += v.nbytes()
|
||||
j = json.dumps(headers, separators=(',', ':'))
|
||||
- j += "\x20"*((8-len(j)%8)%8)
|
||||
+ j += "\x20"*(round_up(len(j),8)-len(j))
|
||||
pathlib.Path(fn).unlink(missing_ok=True)
|
||||
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
t[0:8].bitcast(dtypes.int64).assign([len(j)])
|
||||
@@ -85,7 +86,6 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
|
||||
|
||||
# state dict
|
||||
|
||||
- from collections import OrderedDict
|
||||
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
|
||||
"""
|
||||
Returns a state_dict of the object, with optional prefix.
|
||||
@@ -110,6 +110,7 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
|
||||
elif isinstance(obj, dict):
|
||||
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
|
||||
return state_dict
|
||||
|
||||
def get_parameters(obj) -> list[Tensor]:
|
||||
"""
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
|
||||
Reference in New Issue
Block a user