From 0fd6d7482be884b5b6d79faa253a06ad2086d8d2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 28 Dec 2024 20:13:25 -0500 Subject: [PATCH] minor cleanups in state.py [pr] (#8438) --- tinygrad/nn/state.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 0830d5049c..9400dd2e82 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -1,4 +1,5 @@ import json, pathlib, zipfile, pickle, tarfile, struct, functools, io +from collections import OrderedDict from typing import Union, Optional, Any, Callable, BinaryIO, Iterable from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes @@ -8,12 +9,12 @@ from tinygrad.multi import MultiLazyBuffer class TensorIO(io.RawIOBase, BinaryIO): def __init__(self, t: Tensor): - if len(t.shape) != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!") + if t.ndim != 1 or t.dtype != dtypes.uint8: raise ValueError("Tensor must be 1d and of dtype uint8!") self._position, self._tensor = 0, t def readable(self) -> bool: return True def read(self, size: int = -1) -> bytes: - if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens, if readinto returns None (never) + if (buf:=super().read(size)) is None: raise ValueError("io.RawIOBase.read returned None") # only happens if readinto returns None (never) return buf def readinto(self, buffer: Any) -> int: data = self._tensor[self._position:self._position+len(buffer)].data() @@ -76,7 +77,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any headers[k] = {'dtype': inverse_safe_dtypes[v.dtype], 'shape': list(v.shape), 'data_offsets':[offset, offset+v.nbytes()]} offset += v.nbytes() j = json.dumps(headers, separators=(',', ':')) - j += "\x20"*((8-len(j)%8)%8) + j += "\x20"*(round_up(len(j),8)-len(j)) pathlib.Path(fn).unlink(missing_ok=True) t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}") t[0:8].bitcast(dtypes.int64).assign([len(j)]) @@ -85,7 +86,6 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any # state dict -from collections import OrderedDict def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]: """ Returns a state_dict of the object, with optional prefix. @@ -110,6 +110,7 @@ def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]: elif isinstance(obj, dict): for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type)) return state_dict + def get_parameters(obj) -> list[Tensor]: """ ```python exec="true" source="above" session="tensor" result="python"