diff --git a/test/test_ops.py b/test/test_ops.py
index ee5adf86f9..5543b861e1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -614,6 +614,8 @@ class TestOps(unittest.TestCase):
   def test_pad(self):
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2))))
     helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5))
+    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("inf")), lambda x: x.pad(((3,4), (1,2)), value=float("inf")))
+    helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("-inf")), lambda x: x.pad(((3,4), (1,2)), value=float("-inf")))
 
   def test_transpose(self):
     helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 1be8b7eaea..675d5138e7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -7,7 +7,7 @@ import operator
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
-from math import ceil, pi, prod, sqrt, log, cos, copysign
+from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.ops import LoadOps
 
@@ -246,6 +246,7 @@ class Tensor:
   def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
   def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
     ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
+    if isinf(value): return ret + copysign(1,value)/mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg)
     return ret if 0 == value else ret + (value - mlops.Pad.apply(Tensor.full(self.shape, value), arg=arg))
 
   # ***** movement hlops *****