From ae83e9844c7ed07f04087af7c865df97540be745 Mon Sep 17 00:00:00 2001
From: kposborne2 <53231580+kposborne2@users.noreply.github.com>
Date: Thu, 1 Jun 2023 00:03:22 -0700
Subject: [PATCH] add output_padding to transposed conv (#875)

---
 extra/onnx_ops.py       |  4 ++--
 test/test_ops.py        | 25 +++++++++++++++++++------
 tinygrad/nn/__init__.py |  6 +++---
 tinygrad/tensor.py      |  4 ++--
 4 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index 20b239ad77..87465dfade 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -99,8 +99,8 @@ def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=Non
   padding = [p for ps in zip(pads[:len(pads)//2][::-1], pads[len(pads)//2:][::-1]) for p in ps] if pads is not None else 0 # reorder padding
   return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding)
 
-def ConvTranspose(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
-  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0)
+def ConvTranspose(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1):
+  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0, output_padding=output_padding)
 
 # Reimplemented here because you need legacy RNG for passing ONNX tests.
 def Dropout(data, ratio=0.5, training_mode=False, seed=None):
diff --git a/test/test_ops.py b/test/test_ops.py
index e1272a4efe..38abd02a87 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -456,6 +456,11 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv_transpose2d(x,w).relu(),
       lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
 
+  def test_bias_conv_transpose2d(self):
+    helper_test_op([(2,4,9,9), (4,4,3,3), (4,)],
+      lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b).relu(),
+      lambda x,w,b: Tensor.conv_transpose2d(x,w,b).relu(), atol=1e-4, grad_rtol=1e-5)
+
   def test_grouped_conv_transpose2d(self):
     helper_test_op([(2,4,9,9), (4,4,3,3)],
       lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2).relu(),
@@ -468,14 +473,22 @@ class TestOps(unittest.TestCase):
       lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_dilated_conv_transpose2d(self):
-    helper_test_op([(2,4,9,9), (4,4,3,3)],
-      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=2).relu(),
-      lambda x,w: Tensor.conv_transpose2d(x,w,dilation=2).relu(), atol=1e-4, grad_rtol=1e-5)
+    for dilation in [(1,2), (2,1), 2, 1]:
+      helper_test_op([(2,4,9,9), (4,4,3,3)],
+        lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=dilation).relu(),
+        lambda x,w: Tensor.conv_transpose2d(x,w,dilation=dilation).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_strided_conv_transpose2d(self):
-    helper_test_op([(2,4,9,9), (4,4,3,3)],
-      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,stride=2).relu(),
-      lambda x,w: Tensor.conv_transpose2d(x,w,stride=2).relu(), atol=1e-4, grad_rtol=1e-5)
+    for stride in [(2,1), (1,2), 1]:
+      helper_test_op([(2,4,4,5), (4,4,3,3)],
+        lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(),
+        lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
+
+  def test_output_padded_conv_transpose2d(self):
+    for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]:
+      helper_test_op([(2,4,6,5), (4,4,3,3),(4,)],
+        lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(),
+        lambda x,w,b: Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
 
   @unittest.skipIf(IMAGE>0, "no conv3d on images")
   def test_simple_conv_transpose3d(self):
diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 75eaffb4ff..b54d7b1fe0 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -45,14 +45,14 @@ class Conv2d:
     return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
 class ConvTranspose2d:
-  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-    self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
+    self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups
     self.weight = Tensor.glorot_uniform(in_channels, out_channels//groups, *self.kernel_size)
     self.bias = Tensor.zeros(out_channels) if bias else None
 
   def __call__(self, x):
-    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
+    return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
 
 class Linear:
   def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c4d8e2a9a7..63b994e9d8 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -391,7 +391,7 @@ class Tensor:
   def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
   def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
 
-  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
+  def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
     HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
     x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
     stride = make_pair(stride, len(HW))
@@ -400,7 +400,7 @@
       x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
       x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
       x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
-    padding = flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in reversed(list(zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW))))))
+    padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
     return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
 
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
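
Note on the padding math in tinygrad/tensor.py: conv_transpose2d is lowered to a plain conv2d over a zero-stuffed input, and per spatial dim the conv padding becomes ((k-1)*d - p, (k-1)*d - p + op), so output_padding only grows the high side of each dim. That reproduces the usual transposed-conv size formula out = (in - 1)*stride - 2*padding + dilation*(k - 1) + output_padding + 1. Below is a quick shape check against the new parameter; it uses the API this patch adds, but the concrete shapes are illustrative and not taken from the test suite (as in torch, each output_padding entry should stay below the matching stride/dilation for results to line up with torch):

    from tinygrad.tensor import Tensor

    x = Tensor.randn(2, 4, 6, 5)  # NCHW input
    w = Tensor.randn(4, 4, 3, 3)  # (in_ch, out_ch//groups, kH, kW) weight layout for transposed conv
    out = x.conv_transpose2d(w, stride=(2,3), output_padding=(1,1))

    # out = (in - 1)*stride - 2*padding + dilation*(k - 1) + output_padding + 1
    # H: (6-1)*2 - 0 + 1*(3-1) + 1 + 1 = 14
    # W: (5-1)*3 - 0 + 1*(3-1) + 1 + 1 = 16
    assert out.shape == (2, 4, 14, 16)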