mirror of https://github.com/tinygrad/tinygrad.git
get_state_dict
@@ -22,6 +22,7 @@ from extra.helpers import Timing
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Linear
 from tinygrad.ops import GlobalCounters
+from tinygrad.nn.optim import get_state_dict
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@@ -197,7 +198,7 @@ if __name__ == "__main__":
   chatbot = args.prompt == None
 
   # load model (you have to find the weights yourself)
-  from extra.utils import fake_torch_load_zipped, get_child
+  from extra.utils import fake_torch_load_zipped
 
   if args.large:
     raise RuntimeError("large model is broken")
@@ -247,10 +248,12 @@ if __name__ == "__main__":
   with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
     weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated")
 
+  state_dict = get_state_dict(model)
+
   # assign weights (should be free)
   for k,v in weights.items():
     if '.inner_attention.rope.freqs' in k: continue # no rope today
-    mv = get_child(model, k)
+    mv = state_dict[k]
     assert mv.shape == v.shape, f"shape mismatch in {k}, {mv.shape} != {v.shape}"
     mv.assign(v).realize()
 
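For context, here is a minimal, self-contained sketch of what a get_state_dict-style helper does: it walks a model's attributes and collects its tensors under dotted names, so each checkpoint key can be resolved with a plain dict lookup (state_dict[k]) instead of a recursive get_child(model, k) traversal. Everything below (FakeTensor, get_state_dict_sketch, FakeLinear, FakeModel) is hypothetical illustration code under that assumption, not tinygrad's actual implementation.

class FakeTensor:
  # stand-in for a tensor; only the shape matters for this illustration
  def __init__(self, shape): self.shape = shape

def get_state_dict_sketch(obj, prefix=""):
  # recursively collect FakeTensor attributes under dotted names,
  # e.g. {"layers.0.weight": <FakeTensor>, "layers.0.bias": <FakeTensor>, ...}
  state = {}
  for name, value in vars(obj).items():
    key = f"{prefix}{name}"
    if isinstance(value, FakeTensor):
      state[key] = value                                       # leaf: a weight
    elif isinstance(value, (list, tuple)):
      for i, item in enumerate(value):                         # e.g. a list of transformer blocks
        state.update(get_state_dict_sketch(item, f"{key}.{i}."))
    elif hasattr(value, "__dict__"):
      state.update(get_state_dict_sketch(value, f"{key}."))    # nested module
  return state

class FakeLinear:
  def __init__(self, in_features, out_features):
    self.weight = FakeTensor((out_features, in_features))
    self.bias = FakeTensor((out_features,))

class FakeModel:
  def __init__(self):
    self.layers = [FakeLinear(4, 4), FakeLinear(4, 2)]

print(sorted(get_state_dict_sketch(FakeModel()).keys()))
# -> ['layers.0.bias', 'layers.0.weight', 'layers.1.bias', 'layers.1.weight']

With such a name-to-tensor mapping built once up front, the assignment loop in the diff above reduces to one dict lookup per checkpoint entry.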