diff --git a/examples/efficientnet.py b/examples/efficientnet.py index 5cb24bffa0..de76c3c8f1 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -10,9 +10,11 @@ def swish(x): return x.mul(x.sigmoid()) class BatchNorm2D: - def __init__(self, sz): + def __init__(self, sz, eps=0.001): + self.eps = eps self.weight = Tensor.zeros(sz) self.bias = Tensor.zeros(sz) + # TODO: need running_mean and running_var self.running_mean = Tensor.zeros(sz) self.running_var = Tensor.zeros(sz) @@ -20,7 +22,9 @@ class BatchNorm2D: def __call__(self, x): # this work at inference? + x = x.sub(self.running_mean.reshape(shape=[1, -1, 1, 1])) x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1])) + x = x.div(self.running_var.add(Tensor([self.eps])).reshape(shape=[1, -1, 1, 1]).sqrt()) x = x.add(self.bias.reshape(shape=[1, -1, 1, 1])) return x @@ -102,6 +106,8 @@ class EfficientNet: return swish(x.dot(self._fc).add(self._fc_bias)) if __name__ == "__main__": + import numpy as np + np.set_printoptions(suppress=True) # instantiate my net model = EfficientNet() @@ -114,7 +120,7 @@ if __name__ == "__main__": if '_blocks.' in k: k = "%s[%s].%s" % tuple(k.split(".", 2)) mk = "model."+k - print(k, v.shape) + #print(k, v.shape) try: mv = eval(mk) except AttributeError: @@ -125,6 +131,7 @@ if __name__ == "__main__": mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T #b0 = pickle.loads(b0) - out = model.forward(Tensor.zeros(1, 3, 224, 224)) - print(out) + img = np.zeros((1, 3, 224, 224), np.float32) + 0.5 + out = model.forward(Tensor(img)) + print(out.data[:, 0:10]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 57fa1b532b..4a8cf4682b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -24,6 +24,41 @@ class Add(Function): def backward(ctx, grad_output): return grad_output, grad_output register('add', Add) + +class Sub(Function): + @staticmethod + def forward(ctx, x, y): + return x-y + + @staticmethod + def backward(ctx, grad_output): + # this right? + return grad_output, -grad_output +register('sub', Sub) + +class Div(Function): + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + return x/y + + @staticmethod + def backward(ctx, grad_output): + # this right? + x,y = ctx.saved_tensors + return y/grad_output, x/grad_output +register('div', Div) + +class Sqrt(Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return np.sqrt(x) + + @staticmethod + def backward(ctx, grad_output): + raise Exception("write this") +register('sqrt', Sqrt) class Dot(Function): @staticmethod