mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
se optional, track time better
This commit is contained in:
@@ -70,7 +70,7 @@ def fake_torch_load(b0):
|
||||
return ret
|
||||
|
||||
class MBConvBlock:
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
|
||||
oup = expand_ratio * input_filters
|
||||
if expand_ratio != 1:
|
||||
self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
|
||||
@@ -87,11 +87,13 @@ class MBConvBlock:
|
||||
self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size)
|
||||
self._bn1 = BatchNorm2D(oup)
|
||||
|
||||
num_squeezed_channels = max(1, int(input_filters * se_ratio))
|
||||
self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
|
||||
self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
|
||||
self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
|
||||
self._se_expand_bias = Tensor.zeros(oup)
|
||||
self.has_se = has_se
|
||||
if self.has_se:
|
||||
num_squeezed_channels = max(1, int(input_filters * se_ratio))
|
||||
self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
|
||||
self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
|
||||
self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
|
||||
self._se_expand_bias = Tensor.zeros(oup)
|
||||
|
||||
self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
|
||||
self._bn2 = BatchNorm2D(output_filters)
|
||||
@@ -105,10 +107,11 @@ class MBConvBlock:
|
||||
x = self._bn1(x).swish()
|
||||
|
||||
# has_se
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
if x.shape == inputs.shape:
|
||||
@@ -116,7 +119,7 @@ class MBConvBlock:
|
||||
return x
|
||||
|
||||
class EfficientNet:
|
||||
def __init__(self, number=0, classes=1000):
|
||||
def __init__(self, number=0, classes=1000, has_se=True):
|
||||
self.number = number
|
||||
global_params = [
|
||||
# width, depth
|
||||
@@ -163,7 +166,7 @@ class EfficientNet:
|
||||
args[3] = round_filters(args[3])
|
||||
args[4] = round_filters(args[4])
|
||||
for n in range(round_repeats(b[0])):
|
||||
self._blocks.append(MBConvBlock(*args))
|
||||
self._blocks.append(MBConvBlock(*args, has_se=has_se))
|
||||
args[3] = args[4]
|
||||
args[1] = (1,1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user