factor out a couple nn ops

This commit is contained in:
George Hotz
2020-10-29 08:01:12 -07:00
parent f84f6c1edd
commit 17fa74c15b
3 changed files with 26 additions and 22 deletions

View File

@@ -5,28 +5,7 @@
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
from tinygrad.tensor import Tensor
from tinygrad.utils import fetch
def swish(x):
  # Swish activation: elementwise x * sigmoid(x).
  # (Multiplication is commutative, so the operand order is irrelevant.)
  return x.sigmoid().mul(x)
class BatchNorm2D:
  """Inference-only 2D batch normalization over NCHW tensors.

  Normalizes each channel with the stored running statistics and applies
  the learned affine transform: weight * (x - mean) / sqrt(var + eps) + bias.
  Running statistics are never updated here (see the TODO below).
  """
  def __init__(self, sz, eps=0.001):
    self.eps = eps
    # Affine parameters; zero-initialized, presumably overwritten when
    # pretrained weights are loaded.
    self.weight = Tensor.zeros(sz)
    self.bias = Tensor.zeros(sz)
    # TODO: need running_mean and running_var
    self.running_mean = Tensor.zeros(sz)
    self.running_var = Tensor.zeros(sz)
    self.num_batches_tracked = Tensor.zeros(0)

  def __call__(self, x):
    # NOTE(review): original comment asked "this work at inference?" —
    # behavior at inference time is unverified from this file alone.
    # Reshape per-channel vectors to [1, C, 1, 1] so they broadcast over NCHW.
    centered = x.sub(self.running_mean.reshape(shape=[1, -1, 1, 1]))
    scaled = centered.mul(self.weight.reshape(shape=[1, -1, 1, 1]))
    denom = self.running_var.add(Tensor([self.eps])).reshape(shape=[1, -1, 1, 1]).sqrt()
    return scaled.div(denom).add(self.bias.reshape(shape=[1, -1, 1, 1]))
from tinygrad.nn import *
class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):