match torch api for pad2d

This commit is contained in:
George Hotz
2020-11-09 17:48:56 -08:00
parent daf073535f
commit 866b759d3b
3 changed files with 7 additions and 7 deletions

View File

@@ -63,7 +63,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,1,1,1)), lambda x: x.pad2d(padding=(1,1,1,1)), gpu=self.gpu)
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
def test_conv2d(self):
for bs in [1,8]: