enet refactor + no sigmoid warning

This commit is contained in:
George Hotz
2020-10-29 08:08:21 -07:00
parent 17fa74c15b
commit 2db670ef26
2 changed files with 32 additions and 23 deletions

View File

@@ -3,8 +3,14 @@
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
import io
import numpy as np
np.set_printoptions(suppress=True)
from tinygrad.tensor import Tensor
from tinygrad.utils import fetch
# BatchNorm2D and swish
from tinygrad.nn import *
class MBConvBlock:
@@ -84,30 +90,30 @@ class EfficientNet:
#x = x.dropout(0.2)
return swish(x.dot(self._fc).add(self._fc_bias))
def load_weights_from_torch(self):
  """Download the pretrained EfficientNet-b0 checkpoint and copy it into this model.

  The checkpoint is fetched from the EfficientNet-PyTorch release URL, decoded
  with torch.load, and every (name, tensor) entry is written in place into the
  matching attribute of this object via ``<attribute>.data[:] = ...``.

  Attribute lookup order for each checkpoint key:
    1. the key as-is ("self.<key>"),
    2. the key with ".weight" removed,
    3. the key with ".bias" rewritten to "_bias" (an AttributeError here
       propagates to the caller).

  NOTE(review): both torch.load and eval act on content that comes from the
  downloaded file, so this is only appropriate while the URL is trusted.
  """
  import torch

  url = "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth"
  state_dict = torch.load(io.BytesIO(fetch(url)))

  for name, weight in state_dict.items():
    # turn the dotted block index into a subscript,
    # e.g. "_blocks.3.rest" -> "_blocks[3].rest"
    if '_blocks.' in name:
      name = "%s[%s].%s" % tuple(name.split(".", 2))
    expr = "self." + name
    #print(name, weight.shape)

    # eval stays in this function's scope so that `self` resolves
    for candidate in (expr, expr.replace(".weight", "")):
      try:
        target = eval(candidate)
      except AttributeError:
        continue
      break
    else:
      # last resort; deliberately not guarded so a genuine mismatch is raised
      target = eval(expr.replace(".bias", "_bias"))

    # the fully connected weight is stored transposed relative to the checkpoint
    target.data[:] = weight.numpy().T if name == '_fc.weight' else weight.numpy()
if __name__ == "__main__":
import numpy as np
np.set_printoptions(suppress=True)
# instantiate my net
model = EfficientNet()
# load b0
import io, torch
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
b0 = torch.load(io.BytesIO(b0))
for k,v in b0.items():
if '_blocks.' in k:
k = "%s[%s].%s" % tuple(k.split(".", 2))
mk = "model."+k
#print(k, v.shape)
try:
mv = eval(mk)
except AttributeError:
try:
mv = eval(mk.replace(".weight", ""))
except AttributeError:
mv = eval(mk.replace(".bias", "_bias"))
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
model.load_weights_from_torch()
# load cat image
from PIL import Image
@@ -117,7 +123,7 @@ if __name__ == "__main__":
img = img.astype(np.float32).reshape(1,3,224,224)
print(img.shape)
#b0 = pickle.loads(b0)
# run the net
out = model.forward(Tensor(img))
print(np.argmax(out.data), np.max(out.data))

View File

@@ -104,7 +104,10 @@ register('relu', ReLU)
class Sigmoid(Function):
@staticmethod
def forward(ctx, input):
ret = 1/(1 + np.exp(-input))
# TODO: stable sigmoid? does the overflow matter?
with np.warnings.catch_warnings():
np.warnings.filterwarnings('ignore')
ret = 1/(1 + np.exp(-input))
ctx.save_for_backward(ret)
return ret