diff --git a/test/test_ops.py b/test/test_ops.py
index e59eb4574c..4a30ca42e5 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -612,6 +612,10 @@ class TestOps(unittest.TestCase):
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
+  def test_pad(self):
+    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad(((3,4),(1,2))))
+    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5))
+
   def test_transpose(self):
     helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))
     helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7148a73c06..87ef53f6ff 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -243,8 +243,10 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
   def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
+    ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
+    return ret if value == 0 else ret + (value - mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg))
 
   # ***** movement hlops *****
 