mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
factor out a couple nn ops
This commit is contained in:
@@ -5,28 +5,7 @@
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import fetch
|
||||
|
||||
def swish(x):
  """Swish activation: elementwise x * sigmoid(x)."""
  gate = x.sigmoid()
  return x.mul(gate)
|
||||
|
||||
class BatchNorm2D:
  """Minimal 2D batch normalization that normalizes with stored running
  statistics: out = (x - running_mean) * weight / sqrt(running_var + eps) + bias.

  Parameters are held per-channel and broadcast over an (N, C, H, W) input
  via reshape to [1, -1, 1, 1].
  """

  def __init__(self, sz, eps=0.001):
    # eps keeps the variance denominator strictly positive.
    self.eps = eps
    self.weight = Tensor.zeros(sz)
    self.bias = Tensor.zeros(sz)

    # TODO: need running_mean and running_var
    self.running_mean = Tensor.zeros(sz)
    self.running_var = Tensor.zeros(sz)
    self.num_batches_tracked = Tensor.zeros(0)

  def __call__(self, x):
    # this work at inference?
    # Broadcast a per-channel vector across the (N, C, H, W) layout.
    def per_channel(t):
      return t.reshape(shape=[1, -1, 1, 1])

    centered = x.sub(per_channel(self.running_mean))
    scaled = centered.mul(per_channel(self.weight))
    # Note: reshape happens before sqrt, matching the original op order.
    denom = per_channel(self.running_var.add(Tensor([self.eps]))).sqrt()
    return scaled.div(denom).add(per_channel(self.bias))
|
||||
from tinygrad.nn import *
|
||||
|
||||
class MBConvBlock:
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
|
||||
|
||||
Reference in New Issue
Block a user