mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
enet refactor + no sigmoid warning
This commit is contained in:
@@ -3,8 +3,14 @@
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
import io
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import fetch
|
||||
|
||||
# BatchNorm2D and swish
|
||||
from tinygrad.nn import *
|
||||
|
||||
class MBConvBlock:
|
||||
@@ -84,30 +90,30 @@ class EfficientNet:
|
||||
#x = x.dropout(0.2)
|
||||
return swish(x.dot(self._fc).add(self._fc_bias))
|
||||
|
||||
def load_weights_from_torch(self):
  """Download the pretrained EfficientNet-b0 checkpoint and copy it into this model.

  Each checkpoint entry name is turned into an attribute expression rooted at
  ``self`` and resolved with ``eval``. Three spellings are tried in order:
  the name as-is, the name with ``.weight`` removed, and the name with
  ``.bias`` rewritten to ``_bias``. The resolved object's ``data`` buffer is
  then overwritten in place with the checkpoint values.

  Raises:
    AttributeError: if none of the three spellings resolves to an attribute.
  """
  import torch

  # NOTE(security): torch.load unpickles the downloaded bytes and the loop below
  # eval()s strings built from the checkpoint's keys. Only point this at a
  # trusted checkpoint.
  url = "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth"
  state = torch.load(io.BytesIO(fetch(url)))

  for name, value in state.items():
    # "_blocks.3.rest" -> "_blocks[3].rest" so the block list gets indexed
    if '_blocks.' in name:
      name = "%s[%s].%s" % tuple(name.split(".", 2))
    path = "self."+name

    # first two spellings may miss; the last one is allowed to raise
    for expr in (path, path.replace(".weight", "")):
      try:
        target = eval(expr)
      except AttributeError:
        continue
      break
    else:
      target = eval(path.replace(".bias", "_bias"))

    # the final fully-connected weight is stored transposed in the checkpoint
    values = value.numpy()
    if name == '_fc.weight':
      values = values.T
    target.data[:] = values
|
||||
|
||||
if __name__ == "__main__":
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
# instantiate my net
|
||||
model = EfficientNet()
|
||||
|
||||
# load b0
|
||||
import io, torch
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
b0 = torch.load(io.BytesIO(b0))
|
||||
|
||||
for k,v in b0.items():
|
||||
if '_blocks.' in k:
|
||||
k = "%s[%s].%s" % tuple(k.split(".", 2))
|
||||
mk = "model."+k
|
||||
#print(k, v.shape)
|
||||
try:
|
||||
mv = eval(mk)
|
||||
except AttributeError:
|
||||
try:
|
||||
mv = eval(mk.replace(".weight", ""))
|
||||
except AttributeError:
|
||||
mv = eval(mk.replace(".bias", "_bias"))
|
||||
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
|
||||
model.load_weights_from_torch()
|
||||
|
||||
# load cat image
|
||||
from PIL import Image
|
||||
@@ -117,7 +123,7 @@ if __name__ == "__main__":
|
||||
img = img.astype(np.float32).reshape(1,3,224,224)
|
||||
print(img.shape)
|
||||
|
||||
#b0 = pickle.loads(b0)
|
||||
# run the net
|
||||
out = model.forward(Tensor(img))
|
||||
print(np.argmax(out.data), np.max(out.data))
|
||||
|
||||
|
||||
@@ -104,7 +104,10 @@ register('relu', ReLU)
|
||||
class Sigmoid(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
  """Compute the element-wise logistic sigmoid ``1 / (1 + exp(-input))``.

  Args:
    ctx: object exposing ``save_for_backward``; the result is stored on it.
    input: NumPy array (or scalar) to apply the sigmoid to.

  Returns:
    The sigmoid of ``input``; the same object that was handed to
    ``ctx.save_for_backward``.
  """
  # exp(-input) overflows to inf for very negative inputs, but 1/(1+inf) is the
  # correct limit 0.0, so the overflow is harmless and only its warning is
  # silenced. np.errstate is used instead of np.warnings, which is just an
  # accidental re-export of the stdlib module and is missing in recent NumPy.
  with np.errstate(over='ignore'):
    ret = 1/(1 + np.exp(-input))
  ctx.save_for_backward(ret)
  return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user