Fix llama 13B weights loading (#700)

* Fix llama 13B weights loading

* refactor more

* add test

* test storage offset

* fix spacing

* fix strides

* llama 13B working?

* yolo?

* better test for seeks
Author: Kirill
Date: 2023-03-15 18:59:52 +03:00
Committed by: GitHub
Parent: df48753692
Commit: 0532025b04
3 changed files with 87 additions and 59 deletions

@@ -223,16 +223,16 @@ if __name__ == "__main__":
from extra.utils import fake_torch_load_zipped, get_child
if args.large:
raise RuntimeError("large model is broken")
model = Transformer(**args_13B)
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
- weights0 = fake_torch_load_zipped(open(WEIGHTS0_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated.00")
- weights1 = fake_torch_load_zipped(open(WEIGHTS1_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated.01")
+ weights0 = fake_torch_load_zipped(open(WEIGHTS0_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1))
+ weights1 = fake_torch_load_zipped(open(WEIGHTS1_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1))
# eww, this makes a copy
print("concatenating weights")
from tqdm import tqdm
assert set(weights0.keys()) == set(weights1.keys())
for k,v in (t := tqdm(weights0.items())):
- assert GlobalCounters.mem_used/1e9 < 28, "used over 28 GB"
+ # assert GlobalCounters.mem_used/1e9 < 28, "used over 28 GB"
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
if 'rope.freqs' in k: continue # no rope today
mv = get_child(model, k)
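
Both call sites drop the base_name argument, so fake_torch_load_zipped presumably derives the archive's internal prefix itself now (that change lives in extra/utils.py, which is not shown in this hunk). A rough sketch of how such a lookup could work with the standard zipfile module — the helper name and layout assumptions are illustrative, not the actual extra/utils code:

import zipfile

def guess_base_name(path):
  # a torch zip checkpoint stores "<base>/data.pkl" plus "<base>/data/<n>" blobs,
  # so the prefix can be read off the archive listing instead of being passed in
  with zipfile.ZipFile(path) as z:
    pkl = next(n for n in z.namelist() if n.endswith("/data.pkl"))
  return pkl.split("/", 1)[0]
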
@@ -241,21 +241,12 @@ if __name__ == "__main__":
# if the weight is copied across models, it's simple
# TODO: assert they are the same
if w0.shape == mv.shape:
- mv.lazydata.realized = w0
- w0._buf = None
+ mv.assign(w0)
+ mv.realize()
continue
# we have to concatenate them, create tensors
- w0t = Tensor.empty(*w0.shape)
- w1t = Tensor.empty(*w1.shape)
- w0t.lazydata.realized = w0
- w1t.lazydata.realized = w1
- # terrible hacks. force create the output buffer as float16
- mv.lazydata.realized = Device._buffers[Device.DEFAULT].empty(mv.shape, dtype=w0.dtype)
- if w0.shape[0] != mv.shape[0]: mv.assign(w0t.cat(w1t, dim=0))
- elif w0.shape[1] != mv.shape[1]: mv.assign(w0t.cat(w1t, dim=1))
+ if w0.shape[0] != mv.shape[0]: mv.assign(w0.cat(w1, dim=0))
+ elif w0.shape[1] != mv.shape[1]: mv.assign(w0.cat(w1, dim=1))
else: raise RuntimeError("what axis mismatch?")
mv.realize()
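
The refactor drops the Tensor.empty / lazydata.realized buffer hacks: the shards are plain tensors now, concatenated along whichever axis disagrees with the model parameter's shape (dim 0 for one kind of split, dim 1 for the other). A small numpy sketch of the same shape-based decision, with illustrative shapes rather than the real 13B ones:

import numpy as np

def merge_shards(w0, w1, target_shape):
  # weights replicated across both checkpoint files are taken as-is
  if w0.shape == target_shape: return w0
  # otherwise join along the axis that was split across the two files
  if w0.shape[0] != target_shape[0]: return np.concatenate([w0, w1], axis=0)
  elif w0.shape[1] != target_shape[1]: return np.concatenate([w0, w1], axis=1)
  else: raise RuntimeError("what axis mismatch?")

# merge_shards(np.zeros((16, 32)), np.zeros((16, 32)), (32, 32)).shape -> (32, 32)
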
@@ -268,7 +259,7 @@ if __name__ == "__main__":
else:
model = Transformer(**args_7B)
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
- weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated")
+ weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1))
#from tinygrad.nn.optim import get_state_dict
#state_dict = get_state_dict(model)
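
The "test storage offset", "fix strides", and "better test for seeks" items refer to the fact that a tensor inside a torch zip checkpoint is only a view into a shared storage blob: loading it means seeking to the right element offset and applying the saved strides. A self-contained numpy illustration of that idea — this is not the fake_torch_load_zipped implementation, and the float16 dtype is just an assumption for the example:

import numpy as np

def view_from_storage(blob, storage_offset, shape, strides, dtype=np.float16):
  # torch counts storage_offset and strides in elements; numpy strides are in bytes
  storage = np.frombuffer(blob, dtype=dtype)
  return np.lib.stride_tricks.as_strided(
    storage[storage_offset:], shape=shape,
    strides=tuple(s * storage.itemsize for s in strides))

# two views over one storage: the second tensor starts 4096 elements into the blob
# a = view_from_storage(blob, 0, (64, 64), (64, 1))
# b = view_from_storage(blob, 4096, (64, 64), (64, 1))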