third try at torch loading (#677)

* third try at torch loading
* numpy fixed
* fix enet compile
* load_single_weight supports empty weights
* oops, CPU wasn't the default
* so many bugs

@@ -243,14 +243,12 @@ if __name__ == "__main__":
   with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"):
     weights = fake_torch_load_zipped(open(WEIGHTS_FILENAME, "rb"), load_weights=getenv("WEIGHTS", 1), base_name="consolidated")

-  # assign weights
-  for k,v in (t := tqdm(weights.items())):
-    t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB assigning {k}")
-    if '.inner_attention.rope.freqs' in k: continue # no rope today
-    mv = get_child(model, k)
-    assert mv.shape == v.shape, f"shape mismatch in {k}, {mv.shape} != {v.shape}"
-    mv.assign(v).realize()
-    #mv.lazydata.realized = v
+  # assign weights (should be free)
+  for k,v in weights.items():
+    if '.inner_attention.rope.freqs' in k: continue # no rope today
+    mv = get_child(model, k)
+    assert mv.shape == v.shape, f"shape mismatch in {k}, {mv.shape} != {v.shape}"
+    mv.assign(v).realize()

   del weights

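get_child is referenced by this loop but not shown in the diff; it walks a dotted checkpoint key like "layers.0.attention.wq.weight" down to the matching parameter on the model. A minimal, hypothetical sketch of such a traversal (illustration only, not the tinygrad helper):

from types import SimpleNamespace

def get_child_sketch(obj, key):
  # walk a dotted key like "layers.0.attention.wq.weight" down the object tree
  for part in key.split('.'):
    if part.isdigit(): obj = obj[int(part)]      # numeric parts index into lists
    elif isinstance(obj, dict): obj = obj[part]  # dict-style containers
    else: obj = getattr(obj, part)               # everything else is an attribute
  return obj

model = SimpleNamespace(layers=[SimpleNamespace(wq=SimpleNamespace(weight="W"))])
print(get_child_sketch(model, "layers.0.wq.weight"))  # "W"
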
@@ -619,7 +619,7 @@ if __name__ == "__main__":
     skip_if_exists=True
   )
   dat = fake_torch_load_zipped(open(FILENAME, "rb"))
-  for k,v in tqdm(dat['state_dict'].items()):
+  for k,v in dat['state_dict'].items():
     try:
       w = get_child(model, k)
     except (AttributeError, KeyError, IndexError):

@@ -627,7 +627,7 @@ if __name__ == "__main__":
       w = None
     #print(f"{str(v.shape):30s}" if v is not None else v, w.shape if w is not None else w, k)
     if w is not None:
-      assert w.shape == v.shape and w.dtype.np == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype.np} != {v.dtype}"
+      assert w.shape == v.shape and w.dtype == v.dtype, f"shape or dtype mismatch. {w.shape} != {v.shape} or {w.dtype} != {v.dtype}"
       w.assign(v)

   # run through CLIP to get context

@@ -2,7 +2,12 @@ import pickle
 import numpy as np
 from tqdm import tqdm
 import tempfile
-from tinygrad.helpers import prod, getenv
+from collections import defaultdict
+from tinygrad.helpers import prod, getenv, DEBUG
+from tinygrad.ops import GlobalCounters
+from tinygrad.tensor import Tensor
+from tinygrad.lazy import LazyNumpyArray
+from tinygrad.shape import strides_for_shape

 def fetch(url):
   if url.startswith("/"):

@@ -28,16 +33,20 @@ def download_file(url, fp, skip_if_exists=False):
   os.rename(f.name, fp)

 def my_unpickle(fb0):
-  key_prelookup = {}
+  key_prelookup = defaultdict(list)
   class HackTensor:
     def __new__(cls, *args):
       #print(args)
       ident, storage_type, obj_key, location, obj_size = args[0][0:5]
       assert ident == 'storage'

-      assert prod(args[2]) == obj_size
-      ret = np.zeros(args[2], dtype=storage_type)
-      key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
+      if storage_type not in [np.float16, np.float32]:
+        if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {args[2]}")
+        ret = None
+      else:
+        ret = Tensor(LazyNumpyArray(None, tuple(args[2]), storage_type))
+
+      key_prelookup[obj_key].append((storage_type, obj_size, ret, args[2], args[3]))
       return ret

   class HackParameter:

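key_prelookup becomes a defaultdict(list) and entries are append-ed, presumably because several tensors in a checkpoint can refer to the same storage key, and a plain dict would keep only the last one. A toy illustration of the grouping (values here are made up):

from collections import defaultdict

key_prelookup = defaultdict(list)
# two "tensors" backed by the same storage key "0"
key_prelookup["0"].append(("float32", 6, "weight",   (2, 3), (3, 1)))
key_prelookup["0"].append(("float32", 6, "weight_t", (3, 2), (1, 3)))
print(len(key_prelookup["0"]))  # 2 -> both entries survive, nothing is overwritten
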
@@ -55,6 +64,8 @@ def my_unpickle(fb0):
         return np.float32
       if name == 'LongStorage':
         return np.int64
+      if name == 'IntStorage':
+        return np.int32
       if name == 'HalfStorage':
         return np.float16
       if module == "torch._utils":

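With IntStorage added, the storage-name checks above amount to this lookup table (a sketch for reference; the real code keeps the explicit if-chain):

import numpy as np

# torch storage class name -> numpy dtype, as handled above
STORAGE_DTYPES = {
  "FloatStorage": np.float32,
  "LongStorage":  np.int64,
  "IntStorage":   np.int32,
  "HalfStorage":  np.float16,
}
assert STORAGE_DTYPES["IntStorage"] == np.int32
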
@@ -73,18 +84,49 @@ def my_unpickle(fb0):

   return MyPickle(fb0).load(), key_prelookup

+def load_single_weight(t:Tensor, myfile, shape, strides, dtype):
+  bytes_size = np.dtype(dtype).itemsize
+  if t is None:
+    myfile.seek(prod(shape) * bytes_size, 1)
+    return
+
+  assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
+  assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
+  if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
+    # slow path
+    np_array = np.frombuffer(myfile.read(prod(t.shape) * t.dtype.itemsize), t.dtype.np).reshape(t.shape)
+    real_strides = tuple([x*t.dtype.itemsize for x in strides]) # numpy stores its strides in bytes
+    np_array.strides = real_strides
+
+    lna = t.lazydata.op.arg
+    lna.fxn = lambda _: np_array
+    t.realize()
+    return
+
+  # ["METAL", "CLANG", "LLVM"] support readinto for more speed
+  # this needs real APIs
+  if t.device in ["METAL", "CLANG", "LLVM"]:
+    del t.lazydata.op
+    t.lazydata.realized = t.lazydata.dbuffer(t.shape, dtype=t.dtype)
+    myfile.readinto(t.lazydata.realized.raw()._buffer())
+  else:
+    lna = t.lazydata.op.arg
+    lna.fxn = lambda lna: np.frombuffer(myfile.read(prod(t.shape) * t.dtype.itemsize), lna.dtype).reshape(lna.shape)
+    t.realize()
+
 def fake_torch_load_zipped(fb0, load_weights=True, base_name="archive"):
   import zipfile
   with zipfile.ZipFile(fb0, 'r') as myzip:
     with myzip.open(f'{base_name}/data.pkl') as myfile:
       ret = my_unpickle(myfile)
     if load_weights:
-      for k,v in tqdm(ret[1].items()):
+      def load_weight(k, vv):
         with myzip.open(f'{base_name}/data/{k}') as myfile:
-          if v[2].dtype == "object":
-            print(f"issue assigning object on {k}")
-            continue
-          np.copyto(v[2], np.frombuffer(myfile.read(), v[2].dtype).reshape(v[3]))
+          for v in vv:
+            load_single_weight(v[2], myfile, v[3], v[4], v[0])
+      for k,v in (t := tqdm(ret[1].items())):
+        t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
+        load_weight(k,v)
   return ret[0]

 def fake_torch_load(b0):

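load_single_weight only falls back to the slow numpy path when the stored strides disagree with the contiguous strides for that shape, ignoring size-1 axes. A self-contained sketch of that check, with a local strides_for_shape standing in for tinygrad.shape.strides_for_shape:

import numpy as np

def strides_for_shape(shape):
  # row-major (C-contiguous) strides, in elements
  strides = [1] * len(shape)
  for i in range(len(shape) - 2, -1, -1):
    strides[i] = strides[i + 1] * shape[i + 1]
  return tuple(strides)

def needs_slow_path(shape, stored_strides):
  return any(s != 1 and st1 != st2
             for s, st1, st2 in zip(shape, strides_for_shape(shape), stored_strides))

print(needs_slow_path((4, 3), (3, 1)))  # False: contiguous, fast readinto/frombuffer path
print(needs_slow_path((4, 3), (1, 4)))  # True: transposed storage, rebuild via numpy strides
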
@@ -108,19 +150,14 @@ def fake_torch_load(b0):
   key_lookup = pickle.load(fb0)
   key_real = [None] * len(key_lookup)
   for k,v in key_prelookup.items():
-    key_real[key_lookup.index(k)] = v
+    assert len(v) == 1
+    key_real[key_lookup.index(k)] = v[0]

   # read in the actual data
-  for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
+  for storage_type, obj_size, tensor, np_shape, np_strides in key_real:
     ll = struct.unpack("Q", fb0.read(8))[0]
-    assert ll == obj_size
-    bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
-    mydat = fb0.read(ll * bytes_size)
-    np.copyto(np_array, np.frombuffer(mydat, storage_type).reshape(np_shape))
-
-    # numpy stores its strides in bytes
-    real_strides = tuple([x*bytes_size for x in np_strides])
-    np_array.strides = real_strides
+    assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
+    load_single_weight(tensor, fb0, np_shape, np_strides, storage_type)

   return ret

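In the legacy (non-zip) format, each storage in the file is prefixed by an 8-byte unsigned element count (the struct.unpack("Q", ...) above) followed by the raw data, which is why the same file object can now be handed straight to load_single_weight. A round-trip sketch of that record layout (illustrative only, not a full torch file):

import struct, io
import numpy as np

data = np.arange(6, dtype=np.float32)

# write: 8-byte element count, then the raw buffer
buf = io.BytesIO()
buf.write(struct.pack("Q", data.size))
buf.write(data.tobytes())

# read it back the same way fake_torch_load walks the file
buf.seek(0)
ll = struct.unpack("Q", buf.read(8))[0]
restored = np.frombuffer(buf.read(ll * 4), np.float32)  # 4 = float32 itemsize
assert ll == data.size and np.array_equal(restored, data)
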
@@ -146,6 +146,7 @@ class EfficientNet:

     b0 = fake_torch_load(fetch(model_urls[self.number]))
     for k,v in b0.items():
+      if k.endswith("num_batches_tracked"): continue
       for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
         if cat in k:
           k = k.replace('.bias', '_bias')

@@ -153,9 +154,9 @@ class EfficientNet:

       #print(k, v.shape)
       mv = get_child(self, k)
-      vnp = v.astype(np.float32)
-      vnp = vnp if k != '_fc' else vnp.T
-      vnp = vnp if vnp.shape != () else np.array([vnp])
+      vnp = v #.astype(np.float32)
+      vnp = vnp if k != '_fc' else vnp.transpose()
+      #vnp = vnp if vnp.shape != () else np.array([vnp])

       if mv.shape == vnp.shape:
         mv.assign(vnp)

@@ -16,7 +16,7 @@ LAZY = getenv("LAZY", 1)
 class _Device:
   def __init__(self) -> None:
     self._buffers = {y.upper():y for y in [os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")]}
-    self.DEFAULT : str = functools.reduce(lambda val, ele: val if getenv(val) == 1 else ele, self._buffers, "CPU")
+    self.DEFAULT : str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, "CPU")
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __getitem__(self, x:str) -> Type[DeviceBuffer]: return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{self._buffers[x]}'), inspect.isclass) if (cname.lower() == self._buffers[x] + "buffer")][0]
 Device = _Device()

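The DEFAULT change flips which side of the reduce is tested: the old lambda called getenv(val) on the running result, so with no backend variable set it drifted to the last entry in _buffers instead of staying on "CPU"; the new lambda calls getenv(ele) and only leaves "CPU" when a backend is explicitly requested (e.g. METAL=1). A standalone sketch with a fake environment (backend names are illustrative):

import functools

backends = ["CLANG", "GPU", "LLVM", "METAL", "TORCH"]  # illustrative backend names
env = {}                                               # nothing exported
getenv = lambda k, default=0: env.get(k, default)

old = functools.reduce(lambda val, ele: val if getenv(val) == 1 else ele, backends, "CPU")
new = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, backends, "CPU")

print(old)  # "TORCH" - with no env vars set the old reduce drifts to the last backend
print(new)  # "CPU"   - the new reduce only leaves CPU when some backend's env var is 1

env["METAL"] = 1  # now pretend METAL=1 was exported
print(functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, backends, "CPU"))  # "METAL"
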
@@ -73,7 +73,7 @@ def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tu

 class LazyNumpyArray:
   def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
-  def __call__(self): return self.fxn(self.shape, self.dtype)
+  def __call__(self): return self.fxn(self)
   def reshape(self, new_shape): return LazyNumpyArray(self.fxn, new_shape, self.dtype)
   def copy(self): return self
   def astype(self, typ): return self

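__call__ now passes the whole LazyNumpyArray to fxn instead of (shape, dtype), so the creator can read shape/dtype off the object (which reshape may have changed) and callers like load_single_weight can swap fxn in after construction. A self-contained sketch of the new contract (this mirrors the class above; it is not an import from tinygrad):

import numpy as np

class LazyNumpyArraySketch:
  def __init__(self, fxn, shape, dtype): self.fxn, self.shape, self.dtype = fxn, shape, dtype
  def __call__(self): return self.fxn(self)  # new contract: pass the object itself
  def reshape(self, new_shape): return LazyNumpyArraySketch(self.fxn, new_shape, self.dtype)

# the creator reads shape/dtype off the object, so it keeps working after reshape
lna = LazyNumpyArraySketch(lambda lna: np.zeros(lna.shape, dtype=lna.dtype), (2, 3), np.float32)
print(lna().shape)                   # (2, 3)
print(lna.reshape((3, 2))().shape)   # (3, 2)

# and fxn can be replaced later, the way load_single_weight sets lna.fxn = ...
lna.fxn = lambda _: np.arange(6, dtype=np.float32).reshape(2, 3)
print(lna())
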
@@ -44,7 +44,7 @@ class RawBuffer(Copyable): # pylint: disable=abstract-method
     self.dtype : DType = dtype
     self._memsz : int = size*dtype.itemsize
     GlobalCounters.mem_used += self._memsz
-  def __del__(self): GlobalCounters.mem_used -= self.size*self._memsz
+  def __del__(self): GlobalCounters.mem_used -= self._memsz

 class RawBufferCopyIn(RawBuffer):
   def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")

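The __del__ fix makes deallocation subtract exactly what allocation added: _memsz is already size*itemsize, so the old self.size*self._memsz multiplied the element count in a second time and drove mem_used negative as buffers were freed. A toy counter showing the difference (not the real RawBuffer):

class Counters: mem_used = 0

class Buf:
  def __init__(self, size, itemsize):
    self.size, self._memsz = size, size * itemsize   # _memsz is already in bytes
    Counters.mem_used += self._memsz
  def free_old(self): Counters.mem_used -= self.size * self._memsz  # over-subtracts
  def free_new(self): Counters.mem_used -= self._memsz              # balanced

b = Buf(1024, 4)
print(Counters.mem_used)  # 4096
b.free_new()
print(Counters.mem_used)  # 0 with the fix; free_old would have left -4190208
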
@@ -133,11 +133,11 @@ class Tensor:
   def manual_seed(seed=None): Tensor._rng = np.random.default_rng(seed=seed)

   @staticmethod
-  def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda shape, dtype: Tensor._rng.random(size=shape, dtype=dtype), shape, np.float32), **kwargs)
+  def rand(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.random(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)

   # TODO: replace with a transformation from uniform -> gaussian
   @staticmethod
-  def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda shape, dtype: Tensor._rng.standard_normal(size=shape, dtype=dtype), shape, np.float32), **kwargs)
+  def randn(*shape, **kwargs) -> Tensor: return Tensor(LazyNumpyArray(lambda lna: Tensor._rng.standard_normal(size=lna.shape, dtype=lna.dtype), shape, np.float32), **kwargs)

   # ***** rng hlops *****
