mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
enet work
This commit is contained in:
@@ -10,9 +10,11 @@ def swish(x):
|
||||
return x.mul(x.sigmoid())
|
||||
|
||||
class BatchNorm2D:
|
||||
def __init__(self, sz):
|
||||
def __init__(self, sz, eps=0.001):
|
||||
self.eps = eps
|
||||
self.weight = Tensor.zeros(sz)
|
||||
self.bias = Tensor.zeros(sz)
|
||||
|
||||
# TODO: need running_mean and running_var
|
||||
self.running_mean = Tensor.zeros(sz)
|
||||
self.running_var = Tensor.zeros(sz)
|
||||
@@ -20,7 +22,9 @@ class BatchNorm2D:
|
||||
|
||||
def __call__(self, x):
|
||||
# this work at inference?
|
||||
x = x.sub(self.running_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1]))
|
||||
x = x.div(self.running_var.add(Tensor([self.eps])).reshape(shape=[1, -1, 1, 1]).sqrt())
|
||||
x = x.add(self.bias.reshape(shape=[1, -1, 1, 1]))
|
||||
return x
|
||||
|
||||
@@ -102,6 +106,8 @@ class EfficientNet:
|
||||
return swish(x.dot(self._fc).add(self._fc_bias))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
# instantiate my net
|
||||
model = EfficientNet()
|
||||
|
||||
@@ -114,7 +120,7 @@ if __name__ == "__main__":
|
||||
if '_blocks.' in k:
|
||||
k = "%s[%s].%s" % tuple(k.split(".", 2))
|
||||
mk = "model."+k
|
||||
print(k, v.shape)
|
||||
#print(k, v.shape)
|
||||
try:
|
||||
mv = eval(mk)
|
||||
except AttributeError:
|
||||
@@ -125,6 +131,7 @@ if __name__ == "__main__":
|
||||
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
|
||||
|
||||
#b0 = pickle.loads(b0)
|
||||
out = model.forward(Tensor.zeros(1, 3, 224, 224))
|
||||
print(out)
|
||||
img = np.zeros((1, 3, 224, 224), np.float32) + 0.5
|
||||
out = model.forward(Tensor(img))
|
||||
print(out.data[:, 0:10])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user