mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
match torch api for pad2d
This commit is contained in:
@@ -63,7 +63,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
|
||||
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,1,1,1)), lambda x: x.pad2d(padding=(1,1,1,1)), gpu=self.gpu)
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu)
|
||||
|
||||
def test_conv2d(self):
|
||||
for bs in [1,8]:
|
||||
|
||||
Reference in New Issue
Block a user