diff --git a/tinygrad/nn.py b/tinygrad/nn.py index ea9508dc96..3052b541c0 100644 --- a/tinygrad/nn.py +++ b/tinygrad/nn.py @@ -33,20 +33,15 @@ class BatchNorm2D: class Conv2d: def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): - self.out_channels = out_channels self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else (kernel_size[0], kernel_size[1]) self.stride = (stride, stride) if isinstance(stride, int) else (stride[0], stride[1]) self.padding = (padding, ) * 4 if isinstance(padding, int) else (padding[0], padding[0], padding[1], padding[1]) - self.use_bias = bias self.weight = Tensor.uniform(out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]) - if self.use_bias: - self.bias = Tensor.uniform(out_channels) + self.bias = Tensor.uniform(out_channels) if bias else None def __call__(self, x): - if self.padding[0] > 0: + if self.padding[0] > 0 or self.padding[2] > 0: x = x.pad2d(padding=self.padding) - x = x.conv2d(self.weight, stride=self.stride) - if self.use_bias: - x = x.add(self.bias.reshape(shape=(1, -1, 1, 1))) + x = x.conv2d(self.weight, self.bias, stride=self.stride) return x