From 046b3952c361f5152e3b2fd8ad95dfb113b95f37 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sat, 11 Mar 2023 23:46:53 -0800
Subject: [PATCH] get_state_dict

---
 examples/llama.py    |  7 +++++--
 tinygrad/nn/optim.py | 19 ++++++++++---------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/examples/llama.py b/examples/llama.py
index 0dbe8f254a..f02e19f23a 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -22,6 +22,7 @@ from extra.helpers import Timing
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Linear
 from tinygrad.ops import GlobalCounters
+from tinygrad.nn.optim import get_state_dict
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@@ -197,7 +198,7 @@ if __name__ == "__main__":
   chatbot = args.prompt == None
 
   # load model (you have to find the weights yourself)
-  from extra.utils import fake_torch_load_zipped, get_child
+  from extra.utils import fake_torch_load_zipped
 
   if args.large:
     raise RuntimeError("large model is broken")
@@ -247,10 +248,12 @@ if __name__ == "__main__":
   with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
     weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated")
 
+  state_dict = get_state_dict(model)
+
   # assign weights (should be free)
   for k,v in weights.items():
     if '.inner_attention.rope.freqs' in k: continue # no rope today
-    mv = get_child(model, k)
+    mv = state_dict[k]
     assert mv.shape == v.shape, f"shape mismatch in {k}, {mv.shape} != {v.shape}"
     mv.assign(v).realize()
 
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index b92389e815..cd37f9fab7 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -1,5 +1,5 @@
 # sorted in order of increasing complexity
-from typing import List
+from typing import List, Dict, Optional
 from tinygrad.tensor import Tensor
 
 class Optimizer:
@@ -79,12 +79,13 @@ class Adam(Optimizer):
       t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
     self.realize([self.t] + self.m + self.v)
 
-def get_parameters(obj) -> List[Tensor]:
-  parameters: List[Tensor] = []
-  if isinstance(obj, Tensor):
-    parameters.append(obj)
+def get_state_dict(obj, arg:Optional[List[str]]=None, _params:Optional[Dict[str, Tensor]]=None) -> Dict[str, Tensor]:
+  if arg is None or _params is None: arg, _params = [], {}
+  if isinstance(obj, Tensor): _params['.'.join(arg)] = obj
+  elif hasattr(obj, '__dict__'): get_state_dict(obj.__dict__, arg, _params)
   elif isinstance(obj, (list, tuple)):
-    for x in obj: parameters.extend(get_parameters(x))
-  elif hasattr(obj, '__dict__'):
-    for v in obj.__dict__.values(): parameters.extend(get_parameters(v))
-  return parameters
+    for i,x in enumerate(obj): get_state_dict(x, arg+[str(i)], _params)
+  elif isinstance(obj, dict):
+    for k,v in obj.items(): get_state_dict(v, arg+[k], _params)
+  return _params
+def get_parameters(obj) -> List[Tensor]: return list(get_state_dict(obj).values())
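
Usage note (commentary, not part of the patch): the new get_state_dict recursively walks an object's __dict__, any lists/tuples, and any dicts, joining the path segments with dots to build PyTorch-style keys, and get_parameters is now just the values of that dict. A minimal sketch of the resulting key names follows; TinyBlock and TinyModel are hypothetical stand-ins for any tinygrad module, while Tensor, get_state_dict, and get_parameters are the real names from the diff above.

  # illustration only: shows the dotted-key naming produced by get_state_dict
  from tinygrad.tensor import Tensor
  from tinygrad.nn.optim import get_state_dict, get_parameters

  class TinyBlock:
    def __init__(self):
      self.weight = Tensor.ones(4, 4)   # tensors found via obj.__dict__ keep the attribute name
      self.bias = Tensor.zeros(4)

  class TinyModel:
    def __init__(self):
      self.blocks = [TinyBlock(), TinyBlock()]  # lists contribute numeric key segments

  sd = get_state_dict(TinyModel())
  print(sorted(sd.keys()))
  # ['blocks.0.bias', 'blocks.0.weight', 'blocks.1.bias', 'blocks.1.weight']
  print(len(get_parameters(TinyModel())))  # 4: same tensors, without names

This is what lets llama.py replace the old get_child(model, k) lookup with a plain dict index, since the checkpoint's consolidated.* keys use the same dotted naming.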