clean up Conv2d

This commit is contained in:
George Hotz
2021-11-30 11:02:55 -05:00
parent 835869974c
commit 7d7e2b690d

View File

@@ -33,20 +33,15 @@ class BatchNorm2D:
class Conv2d:
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
self.out_channels = out_channels
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else (kernel_size[0], kernel_size[1])
self.stride = (stride, stride) if isinstance(stride, int) else (stride[0], stride[1])
self.padding = (padding, ) * 4 if isinstance(padding, int) else (padding[0], padding[0], padding[1], padding[1])
self.use_bias = bias
self.weight = Tensor.uniform(out_channels, in_channels, self.kernel_size[0], self.kernel_size[1])
if self.use_bias:
self.bias = Tensor.uniform(out_channels)
self.bias = Tensor.uniform(out_channels) if bias else None
def __call__(self, x):
if self.padding[0] > 0:
if self.padding[0] > 0 or self.padding[2] > 0:
x = x.pad2d(padding=self.padding)
x = x.conv2d(self.weight, stride=self.stride)
if self.use_bias:
x = x.add(self.bias.reshape(shape=(1, -1, 1, 1)))
x = x.conv2d(self.weight, self.bias, stride=self.stride)
return x