Fast DiskTensor to other Tensor (#916)

* make disktensors fast

* loading

* loader for Stable Diffusion and LLaMA
This commit is contained in:
George Hotz
2023-06-03 12:25:41 -07:00
committed by GitHub
parent 791530045d
commit ed1963b899
11 changed files with 109 additions and 76 deletions

View File

@@ -1,5 +1,5 @@
# sorted in order of increasing complexity
from typing import List, Dict
from typing import List
from tinygrad.tensor import Tensor
class Optimizer:
@@ -67,15 +67,6 @@ class LAMB(Optimizer):
t.assign(t.detach() - self.lr * r * up)
self.realize([self.t] + self.m + self.v)
from collections import OrderedDict
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
  """Recursively flatten `obj` into a dict mapping dotted key paths to tensors.

  Walks tensor leaves, OrderedDicts, objects exposing `__dict__`, lists/tuples
  (keyed by index) and dicts (keyed by key); anything else contributes nothing.
  Keys look like "layers.0.weight"; `tensor_type` selects what counts as a leaf.
  """
  # leaf: drop the trailing dot accumulated by the recursive calls
  if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
  if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
  if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
  state_dict = {}
  if isinstance(obj, (list, tuple)):
    # sequence entries are keyed by their position
    for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{i}.", tensor_type))
  elif isinstance(obj, dict):
    # `!s` forces str(k) so non-string keys format exactly as before
    for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{k!s}.", tensor_type))
  return state_dict
def get_parameters(obj) -> List[Tensor]:
  """Return every tensor reachable from `obj` as a flat list (key paths dropped)."""
  state = get_state_dict(obj)
  return list(state.values())
# TODO: remove this
from tinygrad.state import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401