mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
clean up Conv2d
@@ -33,20 +33,15 @@ class BatchNorm2D:
 class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
-    self.out_channels = out_channels
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else (kernel_size[0], kernel_size[1])
     self.stride = (stride, stride) if isinstance(stride, int) else (stride[0], stride[1])
     self.padding = (padding, ) * 4 if isinstance(padding, int) else (padding[0], padding[0], padding[1], padding[1])
-    self.use_bias = bias
     self.weight = Tensor.uniform(out_channels, in_channels, self.kernel_size[0], self.kernel_size[1])
-    if self.use_bias:
-      self.bias = Tensor.uniform(out_channels)
+    self.bias = Tensor.uniform(out_channels) if bias else None

   def __call__(self, x):
-    if self.padding[0] > 0:
+    if self.padding[0] > 0 or self.padding[2] > 0:
       x = x.pad2d(padding=self.padding)
-    x = x.conv2d(self.weight, stride=self.stride)
-    if self.use_bias:
-      x = x.add(self.bias.reshape(shape=(1, -1, 1, 1)))
+    x = x.conv2d(self.weight, self.bias, stride=self.stride)
     return x
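For context, a minimal usage sketch of the cleaned-up module. The import path, input sizes, and parameter choices below are assumptions for illustration, not part of this commit:

    from tinygrad.tensor import Tensor
    from tinygrad.nn import Conv2d   # assumed import path for this module

    # NCHW input batch; sizes chosen only for illustration
    x = Tensor.uniform(1, 3, 32, 32)

    # with a 3x3 kernel, stride 1, and padding 1, the spatial size is preserved
    conv = Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=True)
    out = conv(x)      # pad2d, then a single conv2d call with the bias folded in
    print(out.shape)   # expected: (1, 8, 32, 32)

Design note: folding the bias into the conv2d call and dropping the use_bias flag lets self.bias serve as both the parameter and the on/off switch (it is None when bias=False), and the padding check now covers both the width and height entries of the 4-tuple.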