mirror of https://github.com/tinygrad/tinygrad.git
Fix llama 13B weights loading (#700)
* Fix llama 13B weights loading
* refactor more
* add test
* test storage offset
* fix spacing
* fix strides
* llama 13B working?
* yolo?
* better test for seeks
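
The "test storage offset" and "better test for seeks" bullets point at the underlying issue: every tensor in a torch-format .pth checkpoint lives at some byte offset inside a zip archive, and the loader has to seek to exactly that position to read the raw storage bytes. As a rough sketch of the mechanics only (this is not tinygrad's fake_torch_load_zipped, and the file path and member name below are made up), here is how a stored, uncompressed zip member can be located and read with plain seeks:

import struct, zipfile
import numpy as np

def stored_member_data_offset(f, info):
  # A ZIP local file header is 30 bytes; the filename and extra-field lengths
  # sit at bytes 26-29, and the raw data starts right after both variable parts.
  f.seek(info.header_offset)
  local_header = f.read(30)
  name_len, extra_len = struct.unpack("<HH", local_header[26:30])
  return info.header_offset + 30 + name_len + extra_len

with open("consolidated.00.pth", "rb") as f:          # hypothetical local path
  zf = zipfile.ZipFile(f)
  info = zf.getinfo("consolidated/data/0")            # hypothetical storage entry name
  f.seek(stored_member_data_offset(f, info))
  raw = f.read(info.file_size)                        # stored (method 0): raw bytes are the tensor bytes
  weight = np.frombuffer(raw, dtype=np.float16)       # flat storage; shapes/strides live in data.pkl
  print(weight.size)

In the diff below, the callers stop passing base_name, which suggests the loader now works out the archive layout, and hence each storage's offset, on its own.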
@@ -223,16 +223,16 @@ if __name__ == "__main__":
   from extra.utils import fake_torch_load_zipped, get_child

   if args.large:
-    raise RuntimeError("large model is broken")
     model = Transformer(**args_13B)
     with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
-      weights0 = fake_torch_load_zipped(open(WEIGHTS0_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated.00")
-      weights1 = fake_torch_load_zipped(open(WEIGHTS1_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated.01")
+      weights0 = fake_torch_load_zipped(open(WEIGHTS0_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1))
+      weights1 = fake_torch_load_zipped(open(WEIGHTS1_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1))
     # eww, this makes a copy
     print("concatenating weights")
     from tqdm import tqdm
     assert set(weights0.keys()) == set(weights1.keys())
     for k,v in (t := tqdm(weights0.items())):
-      assert GlobalCounters.mem_used/1e9 < 28, "used over 28 GB"
+      # assert GlobalCounters.mem_used/1e9 < 28, "used over 28 GB"
       t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
       if 'rope.freqs' in k: continue # no rope today
       mv = get_child(model, k)
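
The hunk above drops the base_name argument from both fake_torch_load_zipped calls. A torch-format checkpoint is a zip archive whose members are "<prefix>/data.pkl" plus one "<prefix>/data/<key>" blob per tensor storage, and the old code passed the prefix in by hand per shard ("consolidated.00", "consolidated.01"), presumably so the loader could build those member names itself. A quick way to inspect the layout of a shard, assuming a local copy (the path is illustrative):

import zipfile

# List a few members of one checkpoint shard: the internal prefix and the
# per-storage data blobs are what the loader has to locate.
with zipfile.ZipFile("consolidated.00.pth") as zf:
  for name in zf.namelist()[:8]:
    print(f"{zf.getinfo(name).file_size:>12}  {name}")

The other change in this hunk just comments out the hard 28 GB assert, leaving the tqdm description as the way to watch memory use during the load.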
@@ -241,21 +241,12 @@ if __name__ == "__main__":
       # if the weight is copied across models, it's simple
       # TODO: assert they are the same
       if w0.shape == mv.shape:
-        mv.lazydata.realized = w0
-        w0._buf = None
+        mv.assign(w0)
+        mv.realize()
         continue

       # we have to concatenate them, create tensors
-      w0t = Tensor.empty(*w0.shape)
-      w1t = Tensor.empty(*w1.shape)
-      w0t.lazydata.realized = w0
-      w1t.lazydata.realized = w1
-
-      # terrible hacks. force create the output buffer as float16
-      mv.lazydata.realized = Device._buffers[Device.DEFAULT].empty(mv.shape, dtype=w0.dtype)
-
-      if w0.shape[0] != mv.shape[0]: mv.assign(w0t.cat(w1t, dim=0))
-      elif w0.shape[1] != mv.shape[1]: mv.assign(w0t.cat(w1t, dim=1))
+      if w0.shape[0] != mv.shape[0]: mv.assign(w0.cat(w1, dim=0))
+      elif w0.shape[1] != mv.shape[1]: mv.assign(w0.cat(w1, dim=1))
       else: raise RuntimeError("what axis mismatch?")
       mv.realize()
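
This second hunk is the actual merge logic: every key exists in both shards, and each weight is either replicated (same shape as the model parameter) or split along one axis, so the loop compares shapes and concatenates along whichever axis disagrees. The old code went through Tensor.empty placeholders, hand-patched lazydata.realized buffers, and force-created the output buffer (the "terrible hacks" comment); the new code simply assigns w0.cat(w1, dim=...) and realizes it. A self-contained numpy sketch of the same shape-driven dispatch, with illustrative shapes rather than the real 13B dimensions:

import numpy as np

def merge_shards(w0, w1, target_shape):
  # Replicated weight: both shards hold the same tensor, keep one copy.
  if w0.shape == tuple(target_shape): return w0
  # Row-split weight: the first axis disagrees with the model's parameter.
  if w0.shape[0] != target_shape[0]: return np.concatenate([w0, w1], axis=0)
  # Column-split weight: the second axis disagrees.
  if w0.shape[1] != target_shape[1]: return np.concatenate([w0, w1], axis=1)
  raise RuntimeError("what axis mismatch?")

a = np.zeros((8, 4), dtype=np.float16)
b = np.zeros((8, 4), dtype=np.float16)
assert merge_shards(a, b, (16, 4)).shape == (16, 4)   # row-parallel shard
assert merge_shards(a, b, (8, 8)).shape == (8, 8)     # column-parallel shard
assert merge_shards(a, b, (8, 4)).shape == (8, 4)     # replicated weight

The "eww, this makes a copy" note from the first hunk still applies: concatenation materializes the merged weight rather than viewing the two shards in place.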
@@ -268,7 +259,7 @@ if __name__ == "__main__":
   else:
     model = Transformer(**args_7B)
     with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
-      weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated")
+      weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1))

   #from tinygrad.nn.optim import get_state_dict
   #state_dict = get_state_dict(model)