From 6ea5df19b26dcf82889eb3b37eff4feaa95fbca9 Mon Sep 17 00:00:00 2001
From: Marcello Fuschi
Date: Mon, 29 May 2023 16:57:06 +0200
Subject: [PATCH] Fix conv_transpose2d asymmetric padding (#840)

---
 test/test_ops.py   | 7 ++++---
 tinygrad/tensor.py | 4 ++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 2c0621286a..a7bfaeaa9c 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -441,9 +441,10 @@ class TestOps(unittest.TestCase):
       lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_padded_conv_transpose2d(self):
-    helper_test_op([(2,4,9,9), (4,4,3,3)],
-      lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=1).relu(),
-      lambda x,w: Tensor.conv_transpose2d(x,w,padding=1).relu(), atol=1e-4, grad_rtol=1e-5)
+    for padding in [(1,2), (2,1), 2, 1, 0]:
+      helper_test_op([(2,4,9,9), (4,4,3,3)],
+        lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=padding).relu(),
+        lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_dilated_conv_transpose2d(self):
     helper_test_op([(2,4,9,9), (4,4,3,3)],
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 6f315e654d..214624c4fe 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -385,8 +385,8 @@ class Tensor:
     x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
     x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
     x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
-    # TODO: the make_pair on padding is wrong in the asymmetric padding case
-    return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW)))))
+    padding = flatten(((k-1)*d-p,(k-1)*d-p) for k,p,d in reversed(list(zip(HW, make_pair(padding, len(HW)), make_pair(dilation, len(HW))))))
+    return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, dilation=dilation, padding=padding)
 
   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
     (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]