mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add output_padding to transposed conv (#875)
This commit is contained in:
@@ -99,8 +99,8 @@ def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=Non
|
||||
padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0 # reorder padding
|
||||
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
|
||||
|
||||
def ConvTranspose(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
  # ONNX-style ConvTranspose entry point: forwards to X.conv_transpose2d.
  #   W, B                      -> passed positionally as weight and bias
  #   strides/group/dilations   -> stride= / groups= / dilation=
  #   output_padding            -> forwarded unchanged
  #   pads                      -> reordered to (pads[1], pads[3], pads[0], pads[2]); None means padding 0
  # NOTE(review): auto_pad, kernel_shape and output_shape are accepted but never read in this function.
  # NOTE(review): the reorder indexes pads[0..3], so it assumes exactly four entries (two spatial dims) -- confirm for other ranks.
  padding = 0 if pads is None else (pads[1], pads[3], pads[0], pads[2])
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding, output_padding=output_padding)
|
||||
|
||||
# Reimplemented here because you need legacy RNG for passing ONNX tests.
|
||||
def Dropout(data, ratio=0.5, training_mode=False, seed=None):
|
||||
|
||||
@@ -456,6 +456,11 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: torch.nn.functional.conv_transpose2d(x,w).relu(),
|
||||
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_bias_conv_transpose2d(self):
  # Transposed conv with a bias term: the torch function and the Tensor method
  # are both handed to helper_test_op together with the input/weight/bias shapes.
  shapes = [(2,4,9,9), (4,4,3,3), (4,)]

  def torch_ref(x, w, b):
    return torch.nn.functional.conv_transpose2d(x,w,b).relu()

  def tensor_impl(x, w, b):
    return Tensor.conv_transpose2d(x,w,b).relu()

  helper_test_op(shapes, torch_ref, tensor_impl, atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_grouped_conv_transpose2d(self):
|
||||
helper_test_op([(2,4,9,9), (4,4,3,3)],
|
||||
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2).relu(),
|
||||
@@ -468,14 +473,22 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_dilated_conv_transpose2d(self):
  # Runs the torch function and the Tensor method through helper_test_op for
  # per-axis tuple dilations as well as plain int dilations.
  # The former standalone dilation=2 call is covered by the `2` entry below.
  for dilation in [(1,2), (2,1), 2, 1]:
    helper_test_op([(2,4,9,9), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=dilation).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w,dilation=dilation).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_strided_conv_transpose2d(self):
  # Runs the torch function and the Tensor method through helper_test_op for
  # per-axis tuple strides as well as plain int strides.
  # `2` keeps the uniform stride-2 case that used to be a standalone call, so the
  # list has the same tuple/int mix as the one in test_dilated_conv_transpose2d.
  for stride in [(2,1), (1,2), 2, 1]:
    helper_test_op([(2,4,4,5), (4,4,3,3)],
      lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
def test_output_padded_conv_transpose2d(self):
  # Each case pairs an output_padding with the stride it is exercised under.
  cases = [((1,1), (2,3)), ((2,1), (3,2))]
  for output_padding, stride in cases:
    def torch_ref(x, w, b):
      return torch.nn.functional.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu()

    def tensor_impl(x, w, b):
      return Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu()

    # input, weight and bias shapes; rebuilt on every iteration like the original literal
    shapes = [(2,4,6,5), (4,4,3,3),(4,)]
    helper_test_op(shapes, torch_ref, tensor_impl, atol=1e-4, grad_rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(IMAGE>0, "no conv3d on images")
|
||||
def test_simple_conv_transpose3d(self):
|
||||
|
||||
@@ -45,14 +45,14 @@ class Conv2d:
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
class ConvTranspose2d:
  # Layer wrapper around Tensor.conv_transpose2d that holds its own weight and optional bias.
  #   kernel_size    -> an int is expanded to (k, k); anything else is converted with tuple()
  #   weight         -> Tensor.glorot_uniform(in_channels, out_channels//groups, *kernel_size)
  #   bias           -> Tensor.zeros(out_channels), or None when bias is falsy
  #   stride, padding, output_padding, dilation, groups are stored as given and forwarded on call
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
    self.weight = Tensor.glorot_uniform(in_channels, out_channels//groups, *self.kernel_size)
    self.bias = Tensor.zeros(out_channels) if bias else None

  def __call__(self, x):
    # single return: output_padding must reach conv_transpose2d, so no earlier return may precede it
    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
class Linear:
|
||||
def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
|
||||
|
||||
@@ -391,7 +391,7 @@ class Tensor:
|
||||
def avg_pool2d(self, kernel_size=(2,2), stride=None):
  # Pool with self._pool, then average over the trailing len(window) axes.
  window = make_pair(kernel_size)
  # when no stride is given, the kernel_size argument itself is used as the stride
  step = kernel_size if stride is None else stride
  pooled_axes = tuple(range(-len(window), 0))
  return self._pool(window, step).mean(axis=pooled_axes)
|
||||
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
  # Pool with self._pool (dilation forwarded), then take the max over the trailing len(window) axes.
  window = make_pair(kernel_size)
  # when no stride is given, the kernel_size argument itself is used as the stride
  step = kernel_size if stride is None else stride
  pooled_axes = tuple(range(-len(window), 0))
  return self._pool(window, step, dilation).max(axis=pooled_axes)
|
||||
|
||||
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
|
||||
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
|
||||
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
|
||||
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
|
||||
stride = make_pair(stride, len(HW))
|
||||
@@ -400,7 +400,7 @@ class Tensor:
|
||||
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
|
||||
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
|
||||
x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
|
||||
padding = flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in reversed(list(zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW))))))
|
||||
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
|
||||
return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user