mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
support bias in conv like linear
This commit is contained in:
@@ -44,8 +44,8 @@ class MBConvBlock:
|
||||
# has_se
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
|
||||
x_squeezed = x_squeezed.conv2d(self._se_expand, self._se_expand_bias)
|
||||
x = x.mul(x_squeezed.sigmoid())
|
||||
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
|
||||
@@ -308,6 +308,10 @@ class Tensor:
|
||||
def max_pool2d(self, kernel_size=(2,2)):
|
||||
return self._pool2d(*kernel_size).max(axis=(3,5))
|
||||
|
||||
def conv2d(self, weight, bias=None, stride=1, groups=1):
|
||||
ret = self._conv2d(weight, stride=stride, groups=groups)
|
||||
return ret if bias is None else ret.add(bias.reshape(shape=[1, -1, 1, 1]))
|
||||
|
||||
# ***** functional nn ops *****
|
||||
|
||||
def linear(self, weight, bias):
|
||||
|
||||
Reference in New Issue
Block a user