se optional, track time better

This commit is contained in:
George Hotz
2020-12-06 12:29:42 -08:00
parent 609d11e699
commit 521098cc2f
2 changed files with 26 additions and 15 deletions

View File

@@ -70,7 +70,7 @@ def fake_torch_load(b0):
return ret
class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
oup = expand_ratio * input_filters
if expand_ratio != 1:
self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
@@ -87,11 +87,13 @@ class MBConvBlock:
self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size)
self._bn1 = BatchNorm2D(oup)
num_squeezed_channels = max(1, int(input_filters * se_ratio))
self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
self._se_expand_bias = Tensor.zeros(oup)
self.has_se = has_se
if self.has_se:
num_squeezed_channels = max(1, int(input_filters * se_ratio))
self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
self._se_expand_bias = Tensor.zeros(oup)
self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
self._bn2 = BatchNorm2D(output_filters)
@@ -105,10 +107,11 @@ class MBConvBlock:
x = self._bn1(x).swish()
# has_se
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
x = x.mul(x_squeezed.sigmoid())
if self.has_se:
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
x = x.mul(x_squeezed.sigmoid())
x = self._bn2(x.conv2d(self._project_conv))
if x.shape == inputs.shape:
@@ -116,7 +119,7 @@ class MBConvBlock:
return x
class EfficientNet:
def __init__(self, number=0, classes=1000):
def __init__(self, number=0, classes=1000, has_se=True):
self.number = number
global_params = [
# width, depth
@@ -163,7 +166,7 @@ class EfficientNet:
args[3] = round_filters(args[3])
args[4] = round_filters(args[4])
for n in range(round_repeats(b[0])):
self._blocks.append(MBConvBlock(*args))
self._blocks.append(MBConvBlock(*args, has_se=has_se))
args[3] = args[4]
args[1] = (1,1)