load torch files without torch
@@ -4,6 +4,71 @@ from tinygrad.tensor import Tensor
from tinygrad.utils import fetch
from tinygrad.nn import BatchNorm2D

USE_TORCH = False

def fake_torch_load(b0):
  import io
  import pickle
  import struct

  # convert the raw bytes to a file-like object
  fb0 = io.BytesIO(b0)

  # skip three junk pickles: magic number, protocol version, system info
  pickle.load(fb0)
  pickle.load(fb0)
  pickle.load(fb0)

  key_prelookup = {}

  class HackTensor:
    def __new__(cls, *args):
      #print(args)
      # args[0] is the storage's persistent id tuple
      ident, storage_type, obj_key, location, obj_size, view_metadata = args[0]
      assert ident == 'storage'

      # allocate an empty numpy array now, fill in the data later
      ret = np.zeros(obj_size, dtype=storage_type)
      key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
      return ret

  class MyPickle(pickle.Unpickler):
    def find_class(self, module, name):
      #print(module, name)
      if name == 'FloatStorage':
        return np.float32
      if name == 'LongStorage':
        return np.int64
      if module == "torch._utils" or module == "torch":
        return HackTensor
      else:
        return pickle.Unpickler.find_class(self, module, name)

    def persistent_load(self, pid):
      return pid

  ret = MyPickle(fb0).load()

  # create key_lookup
  key_lookup = pickle.load(fb0)
  key_real = [None] * len(key_lookup)
  for k,v in key_prelookup.items():
    key_real[key_lookup.index(k)] = v

  # read in the actual data: each storage is a uint64 element count
  # followed by the raw buffer
  for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
    ll = struct.unpack("Q", fb0.read(8))[0]
    assert ll == obj_size
    bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
    mydat = fb0.read(ll * bytes_size)
    np_array[:] = np.frombuffer(mydat, storage_type)
    np_array.shape = np_shape

    # numpy stores its strides in bytes, torch in elements
    real_strides = tuple([x*bytes_size for x in np_strides])
    np_array.strides = real_strides

  return ret

class MBConvBlock:
  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
    oup = expand_ratio * input_filters
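The whole trick rests on pickle's persistent-ID hook: torch's legacy format pickles each storage as a persistent ID rather than inline bytes, and MyPickle.persistent_load simply hands that ID tuple back. A minimal, self-contained sketch of the mechanism (class names here are illustrative, not from the file):

import io, pickle

class Saver(pickle.Pickler):
  def persistent_id(self, obj):
    # pretend bytes objects live out-of-band, the way torch treats storages
    return ('storage', 'key0') if isinstance(obj, bytes) else None

class Loader(pickle.Unpickler):
  def persistent_load(self, pid):
    return pid  # hand the ID tuple back, exactly as MyPickle does

buf = io.BytesIO()
Saver(buf).dump({'w': b'\x00\x00\x80?'})
buf.seek(0)
print(Loader(buf).load())  # {'w': ('storage', 'key0')}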
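On the stride fix-up at the end of the loader: torch records strides in elements, numpy in bytes, hence the multiply by bytes_size. A quick numpy check (float32, so bytes_size is 4):

import numpy as np

a = np.zeros(6, dtype=np.float32)
a.shape = (2, 3)
print(a.strides)  # (12, 4): element strides (3, 1) times 4 bytes per float32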
@@ -123,8 +188,6 @@ class EfficientNet:

  def load_weights_from_torch(self, gpu):
    # load b0
    import io
    import torch
    # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
    if self.number == 0:
      b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
@@ -136,7 +199,13 @@ class EfficientNet:
      b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth")
    else:
      raise Exception("no pretrained weights")
    b0 = torch.load(io.BytesIO(b0))

    if USE_TORCH:
      import io
      import torch
      b0 = torch.load(io.BytesIO(b0))
    else:
      b0 = fake_torch_load(b0)

    for k,v in b0.items():
      if '_blocks.' in k:
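With USE_TORCH = False the download-and-load path needs no torch at all. A usage sketch, assuming fetch returns the raw bytes of the .pth file (as the code above does):

b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
weights = fake_torch_load(b0)  # dict of numpy arrays keyed by parameter name
print(weights['_fc.weight'].shape)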
@@ -150,7 +219,7 @@ class EfficientNet:
        mv = eval(mk.replace(".weight", ""))
      except AttributeError:
        mv = eval(mk.replace(".bias", "_bias"))
      vnp = v.numpy().astype(np.float32)
      vnp = v.numpy().astype(np.float32) if USE_TORCH else v  # fake_torch_load already returns numpy arrays
      mv.data[:] = vnp if k != '_fc.weight' else vnp.T
      if gpu:
        mv.cuda_()
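One way to sanity-check fake_torch_load against the real thing, assuming a torch build that can still write the legacy non-zipfile format (and noting the loader only handles float32/int64 storages):

import io
import numpy as np
import torch

buf = io.BytesIO()
torch.save({'w': torch.arange(6, dtype=torch.float32).reshape(2, 3)},
           buf, _use_new_zipfile_serialization=False)
dat = buf.getvalue()

ref = torch.load(io.BytesIO(dat))
out = fake_torch_load(dat)
assert np.allclose(ref['w'].numpy(), out['w'])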