mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
add output_padding to transposed conv (#875)
This commit is contained in:
@@ -45,14 +45,14 @@ class Conv2d:
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
class ConvTranspose2d:
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
|
||||
self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
|
||||
self.weight = Tensor.glorot_uniform(in_channels, out_channels//groups, *self.kernel_size)
|
||||
self.bias = Tensor.zeros(out_channels) if bias else None
|
||||
|
||||
def __call__(self, x):
|
||||
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
class Linear:
|
||||
def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
|
||||
|
||||
Reference in New Issue
Block a user