third try at torch loading (#677)

* third try at torch loading

* numpy fixed

* fix enet compile

* load_single_weight supports empty weights

* oops, CPU wasn't the default

* so many bugs
George Hotz
2023-03-10 19:11:29 -08:00
committed by GitHub
parent 8b7a16cf85
commit b1206bcb18
7 changed files with 74 additions and 38 deletions

View File

@@ -243,14 +243,12 @@ if __name__ == "__main__":
with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated")
# assign weights
for k,v in (t := tqdm(weights.items())):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB assigning {k}")
if '.inner_attention.rope.freqs' in k: continue # no rope today
mv = get_child(model, k)
assert mv.shape == v.shape, f"shape mismatch in {k}, {mv.shape} != {v.shape}"
mv.assign(v).realize()
#mv.lazydata.realized = v
# assign weights (should be free)
for k,v in weights.items():
if '.inner_attention.rope.freqs' in k: continue # no rope today
mv = get_child(model, k)
assert mv.shape == v.shape, f"shape mismatch in {k}, {mv.shape} != {v.shape}"
mv.assign(v).realize()
del weights
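For context on the assign loop: dotted checkpoint keys like `layers.0.attention.wq.weight` are resolved onto the model with `get_child`. A minimal sketch of that kind of helper (hypothetical name and logic, not the exact `extra.utils.get_child`):

```python
def get_child_sketch(parent, key):
  obj = parent
  for part in key.split('.'):
    if part.isnumeric():
      obj = obj[int(part)]        # index into a list of blocks/layers
    elif isinstance(obj, dict):
      obj = obj[part]             # plain dict of parameters
    else:
      obj = getattr(obj, part)    # attribute on a module-like object
  return obj

# get_child_sketch(model, "layers.0.attention.wq.weight") -> the Tensor to assign into
```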

View File

@@ -619,7 +619,7 @@ if __name__ == "__main__":
skip_if_exists=True
)
dat = fake_torch_load_zipped(open(FILENAME, "rb"))
for k,v in tqdm(dat['state_dict'].items()):
for k,v in dat['state_dict'].items():
try:
w = get_child(model, k)
except (AttributeError, KeyError, IndexError):
@@ -627,7 +627,7 @@ if __name__ == "__main__":
w = None
#print(f"{str(v.shape):30s}" if v is not None else v, w.shape if w is not None else w, k)
if w is not None:
assert w.shape == v.shape and w.dtype.np == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype.np} != {v.dtype}"
assert w.shape == v.shape and w.dtype == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype} != {v.dtype}"
w.assign(v)
# run through CLIP to get context
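A note on the assert change above: fake_torch_load_zipped now hands back tinygrad Tensors instead of numpy arrays, so both sides of the comparison carry a tinygrad DType and compare directly. A rough stand-in (not tinygrad's actual DType definition) showing the old check versus the new one:

```python
import numpy as np
from collections import namedtuple

DType = namedtuple("DType", ["itemsize", "name", "np"])   # rough stand-in only
float32 = DType(4, "float", np.float32)

w_dtype = float32            # dtype of the model parameter
v_old   = np.float32         # what v.dtype used to be (numpy array)
v_new   = float32            # what v.dtype is now (tinygrad Tensor)
assert w_dtype.np == v_old   # old check: tinygrad DType vs numpy dtype
assert w_dtype == v_new      # new check: DType vs DType, no .np needed
```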

View File

@@ -2,7 +2,12 @@ import pickle
import numpy as np
from tqdm import tqdm
import tempfile
from tinygrad.helpers import prod, getenv
from collections import defaultdict
from tinygrad.helpers import prod, getenv, DEBUG
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyNumpyArray
from tinygrad.shape import strides_for_shape
def fetch(url):
if url.startswith("/"):
@@ -28,16 +33,20 @@ def download_file(url, fp, skip_if_exists=False):
os.rename(f.name, fp)
def my_unpickle(fb0):
key_prelookup = {}
key_prelookup = defaultdict(list)
class HackTensor:
def __new__(cls, *args):
#print(args)
ident, storage_type, obj_key, location, obj_size = args[0][0:5]
assert ident == 'storage'
assert prod(args[2]) == obj_size
ret = np.zeros(args[2], dtype=storage_type)
key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
if storage_type not in [np.float16, np.float32]:
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {args[2]}")
ret = None
else:
ret = Tensor(LazyNumpyArray(None, tuple(args[2]), storage_type))
key_prelookup[obj_key].append((storage_type, obj_size, ret, args[2], args[3]))
return ret
class HackParameter:
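key_prelookup becomes a defaultdict(list) because several tensors in a checkpoint can be views of one underlying storage; each storage key now maps to a list of (storage_type, obj_size, tensor, shape, strides) entries, and the loader fills them all from a single read of that storage. A toy illustration with made-up values:

```python
from collections import defaultdict
import numpy as np

key_prelookup = defaultdict(list)
# two pickled tensors that happen to share storage key "0"
key_prelookup["0"].append((np.float32, 12, None, (4, 3), (3, 1)))   # the weight itself
key_prelookup["0"].append((np.float32, 12, None, (3, 4), (1, 3)))   # a transposed view of it
for obj_key, entries in key_prelookup.items():
  print(obj_key, len(entries))   # prints: 0 2 -> one zip entry read fills both tensors
```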
@@ -55,6 +64,8 @@ def my_unpickle(fb0):
return np.float32
if name == 'LongStorage':
return np.int64
if name == 'IntStorage':
return np.int32
if name == 'HalfStorage':
return np.float16
if module == "torch._utils":
@@ -73,18 +84,49 @@ def my_unpickle(fb0):
return MyPickle(fb0).load(), key_prelookup
def load_single_weight(t:Tensor, myfile, shape, strides, dtype):
bytes_size = np.dtype(dtype).itemsize
if t is None:
myfile.seek(prod(shape) * bytes_size, 1)
return
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
# slow path
np_array = np.frombuffer(myfile.read(prod(t.shape) * t.dtype.itemsize), t.dtype.np).reshape(t.shape)
real_strides = tuple([x*t.dtype.itemsize for x in strides]) # numpy stores its strides in bytes
np_array.strides = real_strides
lna = t.lazydata.op.arg
lna.fxn = lambda _: np_array
t.realize()
return
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
# this needs real APIs
if t.device in ["METAL", "CLANG", "LLVM"]:
del t.lazydata.op
t.lazydata.realized = t.lazydata.dbuffer(t.shape, dtype=t.dtype)
myfile.readinto(t.lazydata.realized.raw()._buffer())
else:
lna = t.lazydata.op.arg
lna.fxn = lambda lna: np.frombuffer(myfile.read(prod(t.shape) * t.dtype.itemsize), lna.dtype).reshape(lna.shape)
t.realize()
def fake_torch_load_zipped(fb0, load_weights=True, base_name="archive"):
import zipfile
with zipfile.ZipFile(fb0, 'r') as myzip:
with myzip.open(f'{base_name}/data.pkl') as myfile:
ret = my_unpickle(myfile)
if load_weights:
for k,v in tqdm(ret[1].items()):
def load_weight(k, vv):
with myzip.open(f'{base_name}/data/{k}') as myfile:
if v[2].dtype == "object":
print(f"issue assigning object on {k}")
continue
np.copyto(v[2], np.frombuffer(myfile.read(), v[2].dtype).reshape(v[3]))
for v in vv:
load_single_weight(v[2], myfile, v[3], v[4], v[0])
for k,v in (t := tqdm(ret[1].items())):
t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
load_weight(k,v)
return ret[0]
def fake_torch_load(b0):
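In load_single_weight above, entries whose tensor came back as None (unsupported storage types) are skipped by seeking past their bytes. For the rest, a slow path is taken when the strides recorded in the checkpoint are not the plain C-contiguous strides for the shape (size-1 dims ignored): the raw bytes are wrapped with numpy and the saved strides are reapplied in bytes. A self-contained sketch of that stride logic; contiguous_strides is a local stand-in approximating tinygrad.shape.strides_for_shape:

```python
import numpy as np

def contiguous_strides(shape):
  strides = [1] * len(shape)
  for i in range(len(shape) - 2, -1, -1):
    strides[i] = strides[i + 1] * shape[i + 1]
  return tuple(strides)

shape, saved_strides = (4, 3), (1, 4)    # e.g. a transposed 3x4 storage
needs_slow_path = any(s != 1 and st1 != st2 for s, st1, st2 in
                      zip(shape, contiguous_strides(shape), saved_strides))
assert needs_slow_path

raw = np.arange(12, dtype=np.float32).tobytes()          # bytes from the zip entry
flat = np.frombuffer(raw, np.float32).reshape(shape)     # contiguous interpretation
byte_strides = tuple(x * np.dtype(np.float32).itemsize for x in saved_strides)
arr = np.lib.stride_tricks.as_strided(flat, shape=shape, strides=byte_strides)
assert arr[1, 0] == flat[0, 1] == 1.0    # the strided view walks the storage transposed
```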
@@ -108,19 +150,14 @@ def fake_torch_load(b0):
key_lookup = pickle.load(fb0)
key_real = [None] * len(key_lookup)
for k,v in key_prelookup.items():
key_real[key_lookup.index(k)] = v
assert len(v) == 1
key_real[key_lookup.index(k)] = v[0]
# read in the actual data
for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
for storage_type, obj_size, tensor, np_shape, np_strides in key_real:
ll = struct.unpack("Q", fb0.read(8))[0]
assert ll == obj_size
bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
mydat = fb0.read(ll * bytes_size)
np.copyto(np_array, np.frombuffer(mydat, storage_type).reshape(np_shape))
# numpy stores its strides in bytes
real_strides = tuple([x*bytes_size for x in np_strides])
np_array.strides = real_strides
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type)
return ret
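On the fast path, devices whose buffers live in host memory (the METAL/CLANG/LLVM branch of load_single_weight) get the bytes via readinto straight into the realized raw buffer, avoiding the intermediate bytes object and numpy copy of the generic path. A small illustration with io.BytesIO standing in for the zip entry:

```python
import io
import numpy as np

f = io.BytesIO(np.arange(6, dtype=np.float32).tobytes())

# generic path: read() allocates a bytes object, frombuffer wraps it, and the
# data still has to be copied into the destination buffer afterwards
a = np.frombuffer(f.read(), np.float32).copy()

# readinto path: the file fills memory we already own, no intermediate copy
f.seek(0)
b = np.empty(6, dtype=np.float32)
f.readinto(memoryview(b).cast("B"))
assert (a == b).all()
```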

View File

@@ -146,6 +146,7 @@ class EfficientNet:
b0 = fake_torch_load(fetch(model_urls[self.number]))
for k,v in b0.items():
if k.endswith("num_batches_tracked"): continue
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
if cat in k:
k = k.replace('.bias', '_bias')
@@ -153,9 +154,9 @@ class EfficientNet:
#print(k, v.shape)
mv = get_child(self, k)
vnp = v.astype(np.float32)
vnp = vnp if k != '_fc' else vnp.T
vnp = vnp if vnp.shape != () else np.array([vnp])
vnp = v #.astype(np.float32)
vnp = vnp if k != '_fc' else vnp.transpose()
#vnp = vnp if vnp.shape != () else np.array([vnp])
if mv.shape == vnp.shape:
mv.assign(vnp)
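The _fc weight is still transposed on load, most likely because torch's nn.Linear keeps its weight as (out_features, in_features) while the matmul here wants (in_features, out_features); the .astype(np.float32) drops out because the loaded value is already a float32 Tensor. A plain-numpy illustration of the orientation (shapes borrowed from the 1280 -> 1000 EfficientNet-B0 head):

```python
import numpy as np

w_torch = np.zeros((1000, 1280), dtype=np.float32)   # nn.Linear layout: (out, in)
x = np.zeros((1, 1280), dtype=np.float32)
y = x @ w_torch.T                                    # needs (in, out) for x @ w
print(y.shape)   # (1, 1000)
```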

View File

@@ -16,7 +16,7 @@ LAZY = getenv("LAZY", 1)
class _Device:
def __init__(self) -> None:
self._buffers = {y.upper():y for y in [os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")]}
self.DEFAULT : str = functools.reduce(lambda val, ele: val if getenv(val) == 1 else ele, self._buffers, "CPU")
self.DEFAULT : str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, "CPU")
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Type[DeviceBuffer]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{self._buffers[x]}'), inspect.isclass) if (cname.lower() == self._buffers[x] + "buffer")][0]
Device = _Device()
@@ -73,7 +73,7 @@ def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tu
class LazyNumpyArray:
def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
def __call__(self): return self.fxn(self.shape, self.dtype)
def __call__(self): return self.fxn(self)
def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape, self.dtype)
def copy(self): return self
def astype(self, typ): return self
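The DEFAULT fix is the "CPU wasn't the default" bullet: the old reduce kept the accumulator only when the accumulator's own env var was set, so with nothing set it fell through to whichever backend happened to come last instead of CPU. The new version replaces the CPU fallback only when a backend's env var is explicitly 1. A stand-alone sketch of the selection (the backend list and getenv here are simplified stand-ins):

```python
import functools, os

def getenv(key, default=0): return int(os.getenv(key, default))

def pick_default(backends):   # e.g. ["CPU", "GPU", "CLANG", "METAL", "TORCH"]
  return functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, backends, "CPU")

print(pick_default(["CPU", "GPU", "CLANG", "METAL", "TORCH"]))  # "CPU" unless e.g. GPU=1 is set
```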

View File

@@ -44,7 +44,7 @@ class RawBuffer(Copyable): # pylint: disable=abstract-method
self.dtype : DType = dtype
self._memsz : int = size*dtype.itemsize
GlobalCounters.mem_used += self._memsz
def __del__(self): GlobalCounters.mem_used -= self.size*self._memsz
def __del__(self): GlobalCounters.mem_used -= self._memsz
class RawBufferCopyIn(RawBuffer):
def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

View File

@@ -133,11 +133,11 @@ class Tensor:
def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)
@staticmethod
def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda shape, dtype: Tensor._rng.random(size=shape, dtype=dtype), shape, np.float32), **kwargs)
def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.random(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)
# TODO: replace with a transformation from uniform -> gaussian
@staticmethod
def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda shape, dtype: Tensor._rng.standard_normal(size=shape, dtype=dtype), shape, np.float32), **kwargs)
def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.standard_normal(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)
# ***** rng hlops *****
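rand/randn follow the new LazyNumpyArray.__call__ signature: the generator function now receives the LazyNumpyArray itself, so a reshape (which carries the same fxn over to a new shape) still generates data of the right shape when realized. A self-contained sketch mirroring the class from the lazy.py diff above:

```python
import numpy as np

class LazyNumpyArraySketch:   # mirrors the LazyNumpyArray shown above
  def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
  def __call__(self): return self.fxn(self)
  def reshape(self, new_shape): return LazyNumpyArraySketch(self.fxn, new_shape, self.dtype)

lna = LazyNumpyArraySketch(lambda lna: np.zeros(lna.shape, dtype=lna.dtype), (2, 3), np.float32)
assert lna().shape == (2, 3)
assert lna.reshape((3, 2))().shape == (3, 2)   # same fxn, follows the new shape
```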