mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix enet padding
This commit is contained in:
@@ -25,8 +25,11 @@ class MBConvBlock:
|
||||
else:
|
||||
self._expand_conv = None
|
||||
|
||||
self.pad = (kernel_size-1)//2
|
||||
self.strides = strides
|
||||
if strides == (2,2):
|
||||
self.pad = [(kernel_size-1)//2-1, (kernel_size-1)//2]*2
|
||||
else:
|
||||
self.pad = [(kernel_size-1)//2]*4
|
||||
|
||||
self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size)
|
||||
self._bn1 = BatchNorm2D(oup)
|
||||
@@ -44,7 +47,7 @@ class MBConvBlock:
|
||||
x = inputs
|
||||
if self._expand_conv:
|
||||
x = swish(self._bn0(x.conv2d(self._expand_conv)))
|
||||
x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad))
|
||||
x = x.pad2d(padding=self.pad)
|
||||
x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0])
|
||||
x = swish(self._bn1(x))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user