Add replicate mode to Tensor.pad (#7608)

* base implementation

* add tests

* actually remove the assertionerror test

* actually only have reflect for this pr

* change the 4 if-else one liner

* maybe use a lambda

* fix

* maybe a lil cleaner

* fix tests

* complete

* small change

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: geohotstan
Date: 2024-11-18 23:55:38 +08:00
Committed by: GitHub
Parent: 62db6398a5
Commit: 8100109c9d
5 changed files with 40 additions and 11 deletions


@@ -360,7 +360,7 @@ class TestAssign(unittest.TestCase):
def test_permuted_assignment_masked_view_possible(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), 2)
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)
a.assign(a + b)
kc = GlobalCounters.kernel_count
a.realize()
@@ -370,7 +370,7 @@ class TestAssign(unittest.TestCase):
def test_permuted_assignment_masked_view_not_contiguous(self):
a = Tensor.ones(4, 4).contiguous().realize()
with self.assertRaisesRegex(RuntimeError, "contiguous"):
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), 2).permute(1, 0)
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
a.assign(a + b)
a.realize()
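These call sites change because `Tensor.pad` gains a `mode` parameter placed before `value`, so a bare positional second argument would now bind to `mode`; passing `value` by keyword preserves the old meaning. The same keyword update appears in the linearizer and schedule tests below. A minimal sketch of the updated usage (illustrative only, assuming the top-level `tinygrad` import):

```python
from tinygrad import Tensor

a = Tensor.ones(4, 4)
# before: a.shrink((None, (0, 2))).pad((None, (0, 2)), 2)  -- the 2 would now bind to `mode`
b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)  # pad the last two columns back, filled with 2
print(b.numpy())  # 4x4 array: ones in the first two columns, 2.0 in the last two
```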


@@ -861,7 +861,7 @@ class TestLinearizer(unittest.TestCase):
def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
lin = helper_linearizer_opt(out, wanna_output=[24])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE]
# RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN


@@ -1390,10 +1390,30 @@ class TestOps(unittest.TestCase):
# raise error for too many or too few pads
self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0,1,0,3,0)), lambda x: x.pad((0,0,0,0,1,0,3,0)),
expected=(RuntimeError, ValueError))
# raise error for mode string typo
self.helper_test_exception([(3,3,3)], lambda x: torch.nn.functional.pad(x, (3,0), mode="typo"), lambda x: x.pad((3,0), mode="typo"),
expected=NotImplementedError)
x = Tensor.ones(3,3)
with self.assertRaises(ValueError): x.pad((None,(0,1),(3,0)))
with self.assertRaises(ValueError): x.pad(((0,1),))
def test_pad_reflect_mode(self):
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="reflect"), lambda x: x.pad((0,2,3,2), mode="reflect"))
helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="reflect"), lambda x: x.pad((0,2), mode="reflect"))
helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,4,1,2), mode="reflect"),
lambda x: x.pad((1,2,3,4,1,2), mode="reflect"))
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,2,-1), mode="reflect"), lambda x: x.pad((-1,2,2,-1), mode="reflect"))
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-3,0,-3), mode="reflect"), lambda x: x.pad((3,-3,0,-3), mode="reflect"))
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-5,1,-5), mode="reflect"), lambda x: x.pad((3,-5,1,-5), mode="reflect"))
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,0,0,-5), mode="reflect"), lambda x: x.pad((0,0,0,-5), mode="reflect"))
# max pad size for reflect: pad < input size (the mirror is applied at most once)
helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (4,4,0,4), mode="reflect"), lambda x:x.pad((4,4,0,4),mode="reflect"))
# raise error for reflection padding when pad >= input size
self.helper_test_exception([(1,1,5,5)],
lambda x: torch.nn.functional.pad(x, (3,5,0,0),mode="reflect"), lambda x: x.pad((3,5,0,0),mode="reflect"),
expected=(RuntimeError, ValueError))
def test_pad_reshape(self):
helper_test_op([(1, 2)],
lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),

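As the new tests above exercise, `reflect` mode mirrors the tensor's interior across each edge without repeating the edge element, matching `torch.nn.functional.pad(..., mode="reflect")`, and a pad that is not strictly smaller than the dimension raises `ValueError`. A small sketch of the behavior (illustrative only):

```python
from tinygrad import Tensor

t = Tensor([[1, 2, 3]])                        # shape (1, 3)
print(t.pad((2, 2), mode="reflect").numpy())   # [[3 2 1 2 3 2 1]] -- edge values are not duplicated
# t.pad((3, 0), mode="reflect") would raise ValueError: pad must be < input size along that dim
```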

@@ -1186,7 +1186,7 @@ class TestSchedule(unittest.TestCase):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
b = Tensor.rand(3, 4, 5).realize()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
@@ -1195,7 +1195,7 @@ class TestSchedule(unittest.TestCase):
Tensor.manual_seed(0)
a = Tensor.randn(3, 4, 5).realize()
b = Tensor.randn(3, 4, 5).realize()
out = (a.pad(((0, 1), (0, 1), (0, 1)), 1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), 1.0).sum()).contiguous()
out = (a.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()).contiguous()
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(keepdims=True) + \
@@ -1204,7 +1204,7 @@ class TestSchedule(unittest.TestCase):
def test_pad_reduce_unsafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
@@ -1213,7 +1213,7 @@ class TestSchedule(unittest.TestCase):
Tensor.manual_seed(0)
a = Tensor.randn(3, 4, 5).abs().realize()
b = Tensor.randn(3, 4, 5).abs().realize()
out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
# run_schedule(check_schedule(out, 1))
run_schedule(check_schedule(out, 4))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \


@@ -1000,7 +1000,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return F.Shrink.apply(self, arg=tuple(shrink_arg))
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], value:float=0.0) -> Tensor:
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
"""
Returns a tensor with padding applied based on the input `padding`.
`padding` supports two padding structures:
@@ -1015,6 +1015,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
- `padding` must have the same length as `self.ndim`.
Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
The padding mode is selected with `mode`, which supports `constant` and `reflect`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
@@ -1030,6 +1031,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
```
"""
if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
# turn flat padding into group padding
pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
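For flat padding, the conversion above follows the torch convention that the flat list pads dimensions starting from the last one; a quick standalone check of the slicing (illustrative):

```python
padding, ndim = (1, 2, 3, 4), 3   # flat form: last dim gets (1, 2), second-to-last gets (3, 4)
pX = ((0, 0),) * (ndim - len(padding) // 2) + tuple(zip(padding[-2::-2], padding[::-2]))
print(pX)                         # ((0, 0), (3, 4), (1, 2))
```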
@@ -1037,9 +1039,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
# early return for symbolic with positive pads (no need to max)
if all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))
return _constant(X.shrink(shrinks), pads, value)
if mode == "constant" and all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), lambda shape: tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, shape))
if mode == "constant": return _constant(X.shrink(shrinks(X.shape)), pads, value)
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
for d,(pB,pA) in enumerate(pads):
if pB >= (s:=X.shape[d]) or pA>=s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
slcB, slcA, = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
return X.shrink(shrinks(X.shape))
# ***** movement high level ops *****
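The reflect branch above builds each padded dimension by concatenating a reversed slice taken just inside each edge, then applies the final `shrink` to realize any negative pads. A rough NumPy sketch of the same idea for a single dimension with non-negative pads (an illustration, not the tinygrad code):

```python
import numpy as np

def reflect_pad_1d(x: np.ndarray, pB: int, pA: int) -> np.ndarray:
    s = x.shape[0]
    assert 0 <= pB < s and 0 <= pA < s, "reflect pad must be smaller than the input size"
    before = x[pB:0:-1]                                     # mirror of x[1:pB+1], edge element excluded
    stop = s - 2 - pA
    after = x[s-2:(stop if stop >= 0 else None):-1] if pA > 0 else x[:0]
    return np.concatenate([before, x, after])

print(reflect_pad_1d(np.array([1, 2, 3, 4, 5]), 2, 3))      # [3 2 1 2 3 4 5 4 3 2]
```

In the actual change this is done per dimension with fancy indexing and `Tensor.cat`, and negative pads are handled afterwards by the shared `shrinks` lambda.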