diff --git a/examples/llama.py b/examples/llama.py index 8d1b970bc2..52a8a49bc8 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -223,16 +223,16 @@ if __name__ == "__main__": from extra.utils import fake_torch_load_zipped, get_child if args.large: - raise RuntimeError("large model is broken") model = Transformer(**args_13B) with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"): - weights0 = fake_torch_load_zipped(open(WEIGHTS0_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated.00") - weights1 = fake_torch_load_zipped(open(WEIGHTS1_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated.01") + weights0 = fake_torch_load_zipped(open(WEIGHTS0_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1)) + weights1 = fake_torch_load_zipped(open(WEIGHTS1_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1)) # eww, this makes a copy print("concatenating weights") from tqdm import tqdm + assert set(weights0.keys()) == set(weights1.keys()) for k,v in (t := tqdm(weights0.items())): - assert GlobalCounters.mem_used/1e9 < 28, "used over 28 GB" + # assert GlobalCounters.mem_used/1e9 < 28, "used over 28 GB" t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB") if 'rope.freqs' in k: continue # no rope today mv = get_child(model, k) @@ -241,21 +241,12 @@ if __name__ == "__main__": # if the weight is copied across models, it's simple # TODO: assert they are the same if w0.shape == mv.shape: - mv.lazydata.realized = w0 - w0._buf = None + mv.assign(w0) + mv.realize() continue - # we have to concatenate them, create tensors - w0t = Tensor.empty(*w0.shape) - w1t = Tensor.empty(*w1.shape) - w0t.lazydata.realized = w0 - w1t.lazydata.realized = w1 - - # terrible hacks. force create the output buffer as float16 - mv.lazydata.realized = Device._buffers[Device.DEFAULT].empty(mv.shape, dtype=w0.dtype) - - if w0.shape[0] != mv.shape[0]: mv.assign(w0t.cat(w1t, dim=0)) - elif w0.shape[1] != mv.shape[1]: mv.assign(w0t.cat(w1t, dim=1)) + if w0.shape[0] != mv.shape[0]: mv.assign(w0.cat(w1, dim=0)) + elif w0.shape[1] != mv.shape[1]: mv.assign(w0.cat(w1, dim=1)) else: raise RuntimeError("what axis mismatch?") mv.realize() @@ -268,7 +259,7 @@ if __name__ == "__main__": else: model = Transformer(**args_7B) with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"): - weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated") + weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1)) #from tinygrad.nn.optim import get_state_dict #state_dict = get_state_dict(model) diff --git a/extra/utils.py b/extra/utils.py index 6abeadda6c..a91cea4343 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -35,45 +35,36 @@ def download_file(url, fp, skip_if_exists=False): def my_unpickle(fb0): key_prelookup = defaultdict(list) - class HackTensor: - def __new__(cls, *args): - #print(args) - ident, storage_type, obj_key, location, obj_size = args[0][0:5] - assert ident == 'storage' - assert prod(args[2]) == obj_size + def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): + #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata) + ident, storage_type, obj_key, location, obj_size = storage[0:5] + assert ident == 'storage' + assert prod(size) <= (obj_size - storage_offset) - if storage_type not in [np.float16, np.float32]: - if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {args[2]}") - ret = None - else: - ret = Tensor(LazyNumpyArray(lambda lst: np.zeros(lst.shape, dtype=lst.dtype), tuple(args[2]), storage_type)) - key_prelookup[obj_key].append((storage_type, obj_size, ret, args[2], args[3])) - return ret + if storage_type not in [np.float16, np.float32]: + if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}") + ret = None + else: + ret = Tensor(LazyNumpyArray(lambda lst: np.zeros(lst.shape, dtype=lst.dtype), tuple(size), storage_type)) + key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset)) + return ret - class HackParameter: - def __new__(cls, *args): - #print(args) - pass - - class Dummy: + def _rebuild_parameter(*args): + #print(args) pass + class Dummy: pass + class MyPickle(pickle.Unpickler): def find_class(self, module, name): #print(module, name) - if name == 'FloatStorage': - return np.float32 - if name == 'LongStorage': - return np.int64 - if name == 'IntStorage': - return np.int32 - if name == 'HalfStorage': - return np.float16 + if name == 'FloatStorage': return np.float32 + if name == 'LongStorage': return np.int64 + if name == 'IntStorage': return np.int32 + if name == 'HalfStorage': return np.float16 if module == "torch._utils": - if name == "_rebuild_tensor_v2": - return HackTensor - elif name == "_rebuild_parameter": - return HackParameter + if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2 + if name == "_rebuild_parameter": return _rebuild_parameter else: if module.startswith('pytorch_lightning'): return Dummy try: @@ -86,19 +77,27 @@ def my_unpickle(fb0): return MyPickle(fb0).load(), key_prelookup -def load_single_weight(t:Tensor, myfile, shape, strides, dtype, mmap_allowed=False): +def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False): bytes_size = np.dtype(dtype).itemsize if t is None: myfile.seek(prod(shape) * bytes_size, 1) return + bytes_offset = 0 + if storage_offset is not None: + bytes_offset = storage_offset * bytes_size + myfile.seek(bytes_offset) + assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}" assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)): # slow path - np_array = np.frombuffer(myfile.read(prod(t.shape) * t.dtype.itemsize), t.dtype.np).reshape(t.shape) - real_strides = tuple([x*t.dtype.itemsize for x in strides]) # numpy stores its strides in bytes - np_array.strides = real_strides + buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape))) + buffer_size += t.dtype.itemsize + np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np) + + np_array = np.lib.stride_tricks.as_strided( + np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides]) lna = t.lazydata.op.arg lna.fxn = lambda _: np_array @@ -115,7 +114,7 @@ def load_single_weight(t:Tensor, myfile, shape, strides, dtype, mmap_allowed=Fal else: def _mmap(lna): assert myfile._compress_type == 0, "compressed data can't be mmaped" - return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start, shape=lna.shape) + return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape) def _read(lna): ret = np.empty(lna.shape, dtype=lna.dtype) myfile.readinto(ret.data) @@ -124,18 +123,19 @@ def load_single_weight(t:Tensor, myfile, shape, strides, dtype, mmap_allowed=Fal else: t.lazydata.op.arg.fxn = _read t.realize() -def fake_torch_load_zipped(fb0, load_weights=True, base_name="archive", multithreaded=True): +def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True): if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap import zipfile with zipfile.ZipFile(fb0, 'r') as myzip: + base_name = myzip.namelist()[0].split('/', 1)[0] with myzip.open(f'{base_name}/data.pkl') as myfile: ret = my_unpickle(myfile) if load_weights: def load_weight(k, vv): with myzip.open(f'{base_name}/data/{k}') as myfile: for v in vv: - load_single_weight(v[2], myfile, v[3], v[4], v[0], mmap_allowed=True) + load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True) if multithreaded: import concurrent.futures # 2 seems fastest @@ -176,10 +176,11 @@ def fake_torch_load(b0): key_real[key_lookup.index(k)] = v[0] # read in the actual data - for storage_type, obj_size, tensor, np_shape, np_strides in key_real: + for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real: ll = struct.unpack("Q", fb0.read(8))[0] assert ll == obj_size, f"size mismatch {ll} != {obj_size}" - load_single_weight(tensor, fb0, np_shape, np_strides, storage_type) + assert storage_offset == 0, "not implemented" + load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None) return ret diff --git a/test/extra/test_utils.py b/test/extra/test_utils.py index 4464421611..042bcf3c88 100644 --- a/test/extra/test_utils.py +++ b/test/extra/test_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import io import unittest -from extra.utils import fetch +from extra.utils import fetch, fake_torch_load_zipped from PIL import Image class TestUtils(unittest.TestCase): @@ -18,5 +18,41 @@ class TestUtils(unittest.TestCase): pimg = Image.open(io.BytesIO(img)) assert pimg.size == (705, 1024) + def test_fake_torch_load_zipped(self): + import torch + import numpy as np + import tempfile + class LayerWithOffset(torch.nn.Module): + def __init__(self): + super(LayerWithOffset, self).__init__() + d = torch.randn(16) + self.param1 = torch.nn.Parameter( + d.as_strided([2, 2], [2, 3], storage_offset=5) + ) + self.param2 = torch.nn.Parameter( + d.as_strided([2, 2], [2, 3], storage_offset=4) + ) + + for isfloat16 in [True, False]: + model = torch.nn.Sequential( + torch.nn.Linear(4, 8), + torch.nn.Linear(8, 3), + LayerWithOffset() + ) + if isfloat16: model = model.half() + + with tempfile.TemporaryDirectory() as tmpdirname: + path = tmpdirname + '/testloadmodel.pth' + torch.save(model.state_dict(), path) + model2 = fake_torch_load_zipped(path) + + for name, a in model.state_dict().items(): + b = model2[name] + a, b = a.numpy(), b.numpy() + assert a.shape == b.shape + assert a.dtype == b.dtype + assert np.array_equal(a, b) + + if __name__ == '__main__': unittest.main() \ No newline at end of file