diff --git a/examples/efficientnet.py b/examples/efficientnet.py
index a68089a488..8f0815285e 100644
--- a/examples/efficientnet.py
+++ b/examples/efficientnet.py
@@ -15,23 +15,33 @@ class BatchNorm2D:
     return x * self.weight + self.bias
 
 class MBConvBlock:
-  def __init__(self, d0, d1, d2, d3):
-    self._expand_conv = Tensor.zeros(d1, d0, 1, 1)
-    self._bn0 = BatchNorm2D(d1)
-    self._depthwise_conv = Tensor.zeros(d1, 1, 3, 3)
-    self._bn1 = BatchNorm2D(d1)
-    self._se_reduce = Tensor.zeros(d2, d1, 1, 1)
-    self._se_reduce_bias = Tensor.zeros(d2)
-    self._se_expand = Tensor.zeros(d1, d2, 1, 1)
-    self._se_expand_bias = Tensor.zeros(d1)
-    self._project_conv = Tensor.zeros(d3, d2, 1, 1)
-    self._bn2 = BatchNorm2D(d3)
+  def __init__(self, input_filters, expand_ratio, se_ratio, output_filters):
+    oup = expand_ratio * input_filters
+    if expand_ratio != 1:
+      self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
+      self._bn0 = BatchNorm2D(oup)
+    self._depthwise_conv = Tensor.zeros(oup, 1, 3, 3)
+    self._bn1 = BatchNorm2D(oup)
+
+    num_squeezed_channels = max(1, int(input_filters * se_ratio))
+    self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
+    self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
+    self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
+    self._se_expand_bias = Tensor.zeros(oup)
+
+    self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
+    self._bn2 = BatchNorm2D(output_filters)
 
   def __call__(self, x):
-    x = self._bn0(x.conv2d(self._expand_conv))
-    x = self._bn1(x.conv2d(self._depthwise_conv)) # TODO: repeat on axis 1
-    x = x.conv2d(self._se_reduce) + self._se_reduce_bias
-    x = x.conv2d(self._se_expand) + self._se_expand_bias
+    x = self._bn0(x.conv2d(self._expand_conv)).swish()
+    x = self._bn1(x.conv2d(self._depthwise_conv)).swish() # TODO: repeat on axis 1
+
+    # has_se
+    x_squeezed = x.avg_pool2d()
+    x_squeezed = (x_squeezed.conv2d(self._se_reduce) + self._se_reduce_bias).swish()
+    x_squeezed = x_squeezed.conv2d(self._se_expand) + self._se_expand_bias
+    x = x * x_squeezed.sigmoid()
+
     x = self._bn2(x.conv2d(self._project_conv))
     return x.swish()
 
@@ -51,7 +61,7 @@ class EfficientNet:
     for b in self._blocks:
       x = b(x)
     x = self._bn1(x.conv2d(self._conv_head))
-    x = x.avg_pool2d() # wrong
+    x = x.avg_pool2d() # wrong?
     x = x.dropout(0.2)
     return x.dot(self._fc).swish()
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 1ef53358e3..2b7c0c57f0 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -68,6 +68,20 @@ class ReLU(Function):
     return grad_input
 register('relu', ReLU)
 
+class Sigmoid(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = 1/(1 + np.exp(-input))
+    ctx.save_for_backward(ret)
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    ret, = ctx.saved_tensors
+    grad_input = grad_output * (ret * (1 - ret))
+    return grad_input
+register('sigmoid', Sigmoid)
+
 class Reshape(Function):
   @staticmethod
   def forward(ctx, x, shape):
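
The backward pass of the new `Sigmoid` op relies on the identity σ'(x) = σ(x)·(1 − σ(x)), which is why `forward` saves its own output `ret` rather than the input. A minimal standalone numpy check of that identity against a central finite difference (a sketch independent of the tinygrad `Function`/`register` machinery; none of these names are tinygrad API):

```python
import numpy as np

def sigmoid(x):
  return 1 / (1 + np.exp(-x))

x = np.linspace(-5, 5, 11)
ret = sigmoid(x)
# what Sigmoid.backward computes, with grad_output = 1:
# the derivative written purely in terms of the forward output
analytic = ret * (1 - ret)

# central finite-difference estimate of the derivative
eps = 1e-6
numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-6)
```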
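The rewritten `MBConvBlock.__call__` follows the standard squeeze-and-excitation pattern: global average pool to 1×1, a reduce 1×1 conv, swish, an expand 1×1 conv, then a sigmoid gate multiplied back onto the feature map. A rough numpy sketch of just that data flow, with random weights and a channel-wise `mean` standing in for `avg_pool2d` (shapes and names are illustrative assumptions, not tinygrad API):

```python
import numpy as np

def sigmoid(x): return 1 / (1 + np.exp(-x))
def swish(x):   return x * sigmoid(x)

N, C, H, W = 1, 32, 56, 56
squeezed = max(1, int(C * 0.25))  # num_squeezed_channels with se_ratio=0.25

x = np.random.randn(N, C, H, W).astype(np.float32)
# a 1x1 conv on a 1x1 feature map is just a matmul over channels
w_reduce = np.random.randn(squeezed, C).astype(np.float32) * 0.01
b_reduce = np.zeros(squeezed, dtype=np.float32)
w_expand = np.random.randn(C, squeezed).astype(np.float32) * 0.01
b_expand = np.zeros(C, dtype=np.float32)

# squeeze: global average pool to (N, C)
s = x.mean(axis=(2, 3))
# excite: reduce -> swish -> expand
s = swish(s @ w_reduce.T + b_reduce)
s = s @ w_expand.T + b_expand
# gate: per-channel sigmoid scale broadcast back over H and W
out = x * sigmoid(s)[:, :, None, None]
assert out.shape == x.shape
```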