mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-17 18:11:49 -05:00
Fast DiskTensor to other Tensor (#916)
* make disktensors fast * loading * loader for sd and llama
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# sorted in order of increasing complexity
|
||||
from typing import List, Dict
|
||||
from typing import List
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class Optimizer:
|
||||
@@ -67,15 +67,6 @@ class LAMB(Optimizer):
|
||||
t.assign(t.detach() - self.lr * r * up)
|
||||
self.realize([self.t] + self.m + self.v)
|
||||
|
||||
from collections import OrderedDict
|
||||
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> Dict[str, Tensor]:
|
||||
if isinstance(obj, tensor_type): return {prefix.strip('.'):obj}
|
||||
if isinstance(obj, OrderedDict): return get_state_dict(dict(obj), prefix, tensor_type)
|
||||
if hasattr(obj, '__dict__'): return get_state_dict(obj.__dict__, prefix, tensor_type)
|
||||
state_dict = {}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i,x in enumerate(obj): state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
|
||||
elif isinstance(obj, dict):
|
||||
for k,v in obj.items(): state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
|
||||
return state_dict
|
||||
def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
|
||||
# TODO: remove this
|
||||
from tinygrad.state import get_state_dict, get_parameters # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
|
||||
Reference in New Issue
Block a user