From 2db670ef26394e778cabda55df8e3dc36897ba4a Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 29 Oct 2020 08:08:21 -0700 Subject: [PATCH] enet refactor + no sigmoid warning --- examples/efficientnet.py | 50 ++++++++++++++++++++++------------------ tinygrad/ops.py | 5 +++- 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/examples/efficientnet.py b/examples/efficientnet.py index 93921ac6a9..9591532108 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -3,8 +3,14 @@ # https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth # a rough copy of # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py +import io +import numpy as np +np.set_printoptions(suppress=True) + from tinygrad.tensor import Tensor from tinygrad.utils import fetch + +# BatchNorm2D and swish from tinygrad.nn import * class MBConvBlock: @@ -84,30 +90,30 @@ class EfficientNet: #x = x.dropout(0.2) return swish(x.dot(self._fc).add(self._fc_bias)) + def load_weights_from_torch(self): + # load b0 + import torch + b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth") + b0 = torch.load(io.BytesIO(b0)) + + for k,v in b0.items(): + if '_blocks.' in k: + k = "%s[%s].%s" % tuple(k.split(".", 2)) + mk = "self."+k + #print(k, v.shape) + try: + mv = eval(mk) + except AttributeError: + try: + mv = eval(mk.replace(".weight", "")) + except AttributeError: + mv = eval(mk.replace(".bias", "_bias")) + mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T + if __name__ == "__main__": - import numpy as np - np.set_printoptions(suppress=True) # instantiate my net model = EfficientNet() - - # load b0 - import io, torch - b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth") - b0 = torch.load(io.BytesIO(b0)) - - for k,v in b0.items(): - if '_blocks.' in k: - k = "%s[%s].%s" % tuple(k.split(".", 2)) - mk = "model."+k - #print(k, v.shape) - try: - mv = eval(mk) - except AttributeError: - try: - mv = eval(mk.replace(".weight", "")) - except AttributeError: - mv = eval(mk.replace(".bias", "_bias")) - mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T + model.load_weights_from_torch() # load cat image from PIL import Image @@ -117,7 +123,7 @@ if __name__ == "__main__": img = img.astype(np.float32).reshape(1,3,224,224) print(img.shape) - #b0 = pickle.loads(b0) + # run the net out = model.forward(Tensor(img)) print(np.argmax(out.data), np.max(out.data)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fc1da63d48..8201f635e7 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -104,7 +104,10 @@ register('relu', ReLU) class Sigmoid(Function): @staticmethod def forward(ctx, input): - ret = 1/(1 + np.exp(-input)) + # TODO: stable sigmoid? does the overflow matter? + with np.warnings.catch_warnings(): + np.warnings.filterwarnings('ignore') + ret = 1/(1 + np.exp(-input)) ctx.save_for_backward(ret) return ret