fix enet padding

This commit is contained in:
George Hotz
2020-11-09 17:56:57 -08:00
parent 866b759d3b
commit 9db95ab942

View File

@@ -25,8 +25,11 @@ class MBConvBlock:
else:
self._expand_conv = None
self.pad = (kernel_size-1)//2
self.strides = strides
if strides == (2,2):
self.pad = [(kernel_size-1)//2-1, (kernel_size-1)//2]*2
else:
self.pad = [(kernel_size-1)//2]*4
self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size)
self._bn1 = BatchNorm2D(oup)
@@ -44,7 +47,7 @@ class MBConvBlock:
x = inputs
if self._expand_conv:
x = swish(self._bn0(x.conv2d(self._expand_conv)))
x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad))
x = x.pad2d(padding=self.pad)
x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0])
x = swish(self._bn1(x))