mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
enet weight loading
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import fetch
|
||||
|
||||
def swish(x):
|
||||
return x.mul(x.sigmoid())
|
||||
@@ -13,6 +14,9 @@ class BatchNorm2D:
|
||||
self.weight = Tensor.zeros(sz)
|
||||
self.bias = Tensor.zeros(sz)
|
||||
# TODO: need running_mean and running_var
|
||||
self.running_mean = Tensor.zeros(sz)
|
||||
self.running_var = Tensor.zeros(sz)
|
||||
self.num_batches_tracked = Tensor.zeros(0)
|
||||
|
||||
def __call__(self, x):
|
||||
# this work at inference?
|
||||
@@ -84,19 +88,43 @@ class EfficientNet:
|
||||
self._conv_head = Tensor.zeros(1280, 320, 1, 1)
|
||||
self._bn1 = BatchNorm2D(1280)
|
||||
self._fc = Tensor.zeros(1280, 1000)
|
||||
self._fc_bias = Tensor.zeros(1000)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.pad2d(padding=(0,1,0,1))
|
||||
x = self._bn0(x.conv2d(self._conv_stem, stride=2))
|
||||
for b in self._blocks:
|
||||
print(x.shape)
|
||||
x = b(x)
|
||||
x = self._bn1(x.conv2d(self._conv_head))
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
|
||||
#x = x.dropout(0.2)
|
||||
return swish(x.dot(self._fc))
|
||||
return swish(x.dot(self._fc).add(self._fc_bias))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
model = EfficientNet()
|
||||
|
||||
# load b0
|
||||
import io, torch
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
b0 = torch.load(io.BytesIO(b0))
|
||||
|
||||
for k,v in b0.items():
|
||||
if '_blocks.' in k:
|
||||
k = "%s[%s].%s" % tuple(k.split(".", 2))
|
||||
mk = "model."+k
|
||||
print(k, v.shape)
|
||||
try:
|
||||
mv = eval(mk)
|
||||
except AttributeError:
|
||||
try:
|
||||
mv = eval(mk.replace(".weight", ""))
|
||||
except AttributeError:
|
||||
mv = eval(mk.replace(".bias", "_bias"))
|
||||
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
|
||||
|
||||
#b0 = pickle.loads(b0)
|
||||
out = model.forward(Tensor.zeros(1, 3, 224, 224))
|
||||
print(out)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user