From 03994e0011f22dc0a4497e3e171160c2f8b5f600 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sat, 21 Nov 2020 13:43:53 -0800
Subject: [PATCH] load torch files without torch

---
 extra/efficientnet.py | 77 ++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 73 insertions(+), 4 deletions(-)

diff --git a/extra/efficientnet.py b/extra/efficientnet.py
index ef3372317d..5dafa19678 100644
--- a/extra/efficientnet.py
+++ b/extra/efficientnet.py
@@ -4,6 +4,71 @@ from tinygrad.tensor import Tensor
 from tinygrad.utils import fetch
 from tinygrad.nn import BatchNorm2D
 
+USE_TORCH = False
+
+def fake_torch_load(b0):
+  import io
+  import pickle
+  import struct
+
+  # convert it to a file
+  fb0 = io.BytesIO(b0)
+
+  # skip three junk pickles
+  pickle.load(fb0)
+  pickle.load(fb0)
+  pickle.load(fb0)
+
+  key_prelookup = {}
+
+  class HackTensor:
+    def __new__(cls, *args):
+      #print(args)
+      ident, storage_type, obj_key, location, obj_size, view_metadata = args[0]
+      assert ident == 'storage'
+
+      ret = np.zeros(obj_size, dtype=storage_type)
+      key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
+      return ret
+
+  class MyPickle(pickle.Unpickler):
+    def find_class(self, module, name):
+      #print(module, name)
+      if name == 'FloatStorage':
+        return np.float32
+      if name == 'LongStorage':
+        return np.int64
+      if module == "torch._utils" or module == "torch":
+        return HackTensor
+      else:
+        return pickle.Unpickler.find_class(self, module, name)
+
+    def persistent_load(self, pid):
+      return pid
+
+  ret = MyPickle(fb0).load()
+
+  # create key_lookup
+  key_lookup = pickle.load(fb0)
+  key_real = [None] * len(key_lookup)
+  for k,v in key_prelookup.items():
+    key_real[key_lookup.index(k)] = v
+
+  # read in the actual data
+  for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
+    ll = struct.unpack("Q", fb0.read(8))[0]
+    assert ll == obj_size
+    bytes_size = {np.float32: 4, np.int64: 8}[storage_type]
+    mydat = fb0.read(ll * bytes_size)
+    np_array[:] = np.frombuffer(mydat, storage_type)
+    np_array.shape = np_shape
+
+    # numpy stores its strides in bytes
+    real_strides = tuple([x*bytes_size for x in np_strides])
+    np_array.strides = real_strides
+
+  return ret
+
 class MBConvBlock:
   def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
     oup = expand_ratio * input_filters
@@ -123,8 +188,6 @@ class EfficientNet:
 
   def load_weights_from_torch(self, gpu):
     # load b0
-    import io
-    import torch
     # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
     if self.number == 0:
       b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
@@ -136,7 +199,13 @@ class EfficientNet:
       b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth")
     else:
       raise Exception("no pretrained weights")
-    b0 = torch.load(io.BytesIO(b0))
+
+    if USE_TORCH:
+      import io
+      import torch
+      b0 = torch.load(io.BytesIO(b0))
+    else:
+      b0 = fake_torch_load(b0)
 
     for k,v in b0.items():
       if '_blocks.' in k:
@@ -150,7 +219,7 @@ class EfficientNet:
           mv = eval(mk.replace(".weight", ""))
        except AttributeError:
          mv = eval(mk.replace(".bias", "_bias"))
-      vnp = v.numpy().astype(np.float32)
+      vnp = v.numpy().astype(np.float32) if USE_TORCH else v
      mv.data[:] = vnp if k != '_fc.weight' else vnp.T
      if gpu:
        mv.cuda_()
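
A minimal usage sketch of the loader this patch adds (not part of the patch itself; it assumes the patched extra/efficientnet.py is importable from the repo root and that the checkpoint URL referenced in the patch is still reachable):

  from tinygrad.utils import fetch
  from extra.efficientnet import fake_torch_load

  # fetch the raw legacy-format .pth bytes and decode them without torch
  b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
  weights = fake_torch_load(b0)   # dict mapping parameter name -> numpy array
  for k, v in list(weights.items())[:5]:
    print(k, v.shape, v.dtype)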