Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
Add reflect mode to Tensor.pad (#7608)
* base implementation
* add tests
* actually remove the assertionerror test
* actually only have reflect for this pr
* change the 4 if-else one liner
* maybe use a lambda
* fix
* maybe a lil cleaner
* fix tests
* complete
* small change

--------

Co-authored-by: chenyu <chenyu@fastmail.com>
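The gist of the change: `Tensor.pad` gains a `mode` keyword (default `"constant"`, with `"reflect"` newly supported), inserted before `value` in the signature. Because `mode` now occupies the second positional slot, call sites that previously passed the fill value positionally are updated to `value=...` throughout the tests below. A minimal usage sketch, not part of the diff, mirroring the calls exercised by the new tests (assumes a tinygrad build that includes this change):

```python
from tinygrad import Tensor

t = Tensor.arange(9).reshape(1, 1, 3, 3)
# constant padding: the fill value is now passed by keyword, since mode sits before it
print(t.pad((1, 1, 1, 1), value=2).numpy())
# reflect padding mirrors interior elements across each edge (the edge element is not repeated)
print(t.pad((1, 2, 1, 2), mode="reflect").numpy())
```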
@@ -360,7 +360,7 @@ class TestAssign(unittest.TestCase):
   def test_permuted_assignment_masked_view_possible(self):
     a = Tensor.ones(4, 4).contiguous().realize()
-    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), 2)
+    b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2)
     a.assign(a + b)
     kc = GlobalCounters.kernel_count
     a.realize()
@@ -370,7 +370,7 @@ class TestAssign(unittest.TestCase):
   def test_permuted_assignment_masked_view_not_contiguous(self):
     a = Tensor.ones(4, 4).contiguous().realize()
     with self.assertRaisesRegex(RuntimeError, "contiguous"):
-      b = a.shrink((None, (0, 2))).pad((None, (0, 2)), 2).permute(1, 0)
+      b = a.shrink((None, (0, 2))).pad((None, (0, 2)), value=2).permute(1, 0)
       a.assign(a + b)
       a.realize()
@@ -861,7 +861,7 @@ class TestLinearizer(unittest.TestCase):
   def test_two_nested_range_alt_indexing(self):
     a = Tensor([2, 2]).realize()
-    out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum()
+    out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
     lin = helper_linearizer_opt(out, wanna_output=[24])[0]
     ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE]
     # RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN
@@ -1390,10 +1390,30 @@ class TestOps(unittest.TestCase):
     # raise error for too many or too little pads
     self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0,1,0,3,0)), lambda x: x.pad((0,0,0,0,1,0,3,0)),
                                expected=(RuntimeError, ValueError))
+    # raise error for mode string typo
+    self.helper_test_exception([(3,3,3)], lambda x: torch.nn.functional.pad(x, (3,0), mode="typo"), lambda x: x.pad((3,0), mode="typo"),
+                               expected=NotImplementedError)
     x = Tensor.ones(3,3)
     with self.assertRaises(ValueError): x.pad((None,(0,1),(3,0)))
     with self.assertRaises(ValueError): x.pad(((0,1),))

+  def test_pad_reflect_mode(self):
+    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,2,3,2), mode="reflect"), lambda x: x.pad((0,2,3,2), mode="reflect"))
+    helper_test_op([(5,5,5)], lambda x: torch.nn.functional.pad(x, (0,2), mode="reflect"), lambda x: x.pad((0,2), mode="reflect"))
+    helper_test_op([(1,1,5,5,5)], lambda x: torch.nn.functional.pad(x, (1,2,3,4,1,2), mode="reflect"),
+                                  lambda x: x.pad((1,2,3,4,1,2), mode="reflect"))
+    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,2,-1), mode="reflect"), lambda x: x.pad((-1,2,2,-1), mode="reflect"))
+    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-3,0,-3), mode="reflect"), lambda x: x.pad((3,-3,0,-3), mode="reflect"))
+    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (3,-5,1,-5), mode="reflect"), lambda x: x.pad((3,-5,1,-5), mode="reflect"))
+    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (0,0,0,-5), mode="reflect"), lambda x: x.pad((0,0,0,-5), mode="reflect"))
+
+    # max pad size for reflect is exactly once: pad < input size
+    helper_test_op([(1,1,5,5)], lambda x: torch.nn.functional.pad(x, (4,4,0,4), mode="reflect"), lambda x: x.pad((4,4,0,4), mode="reflect"))
+    # raise error for reflection padding when: pad >= input size
+    self.helper_test_exception([(1,1,5,5)],
+      lambda x: torch.nn.functional.pad(x, (3,5,0,0), mode="reflect"), lambda x: x.pad((3,5,0,0), mode="reflect"),
+      expected=(RuntimeError, ValueError))
+
   def test_pad_reshape(self):
     helper_test_op([(1, 2)],
       lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)),
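For reference on what these tests check: reflect padding mirrors values across each border without repeating the edge element, so at most size-1 elements can be added per side (the "exactly once" case above), and a pad greater than or equal to the input size raises. A small NumPy sketch of the in-range semantics (NumPy's `reflect` mode agrees with torch's when the pad is smaller than the input size); the array is illustrative only:

```python
import numpy as np

row = np.arange(5)                           # [0 1 2 3 4]
print(np.pad(row, (0, 2), mode="reflect"))   # [0 1 2 3 4 3 2]: mirrored, edge not repeated
print(np.pad(row, (4, 0), mode="reflect"))   # [4 3 2 1 0 1 2 3 4]: max pad per side is size-1
```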
@@ -1186,7 +1186,7 @@ class TestSchedule(unittest.TestCase):
     Tensor.manual_seed(0)
     a = Tensor.rand(3, 4, 5).realize()
     b = Tensor.rand(3, 4, 5).realize()
-    out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
+    out = (a + b).pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
     run_schedule(check_schedule(out, 1))
     np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
@@ -1195,7 +1195,7 @@ class TestSchedule(unittest.TestCase):
     Tensor.manual_seed(0)
     a = Tensor.randn(3, 4, 5).realize()
     b = Tensor.randn(3, 4, 5).realize()
-    out = (a.pad(((0, 1), (0, 1), (0, 1)), 1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), 1.0).sum()).contiguous()
+    out = (a.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum(keepdim=True)+b.pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()).contiguous()
     # run_schedule(check_schedule(out, 1))
     run_schedule(check_schedule(out, 2))
     np.testing.assert_allclose(out.numpy(), np.pad(a.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(keepdims=True) + \
@@ -1204,7 +1204,7 @@ class TestSchedule(unittest.TestCase):
   def test_pad_reduce_unsafe(self):
     Tensor.manual_seed(0)
     a = Tensor.rand(3, 4, 5).realize()
-    out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
+    out = a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
     run_schedule(check_schedule(out, 2))
     np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
@@ -1213,7 +1213,7 @@ class TestSchedule(unittest.TestCase):
     Tensor.manual_seed(0)
     a = Tensor.randn(3, 4, 5).abs().realize()
     b = Tensor.randn(3, 4, 5).abs().realize()
-    out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
+    out = (a.log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum()+b).abs().log2().pad(((0, 1), (0, 1), (0, 1)), value=1.0).sum().contiguous()
     # run_schedule(check_schedule(out, 1))
     run_schedule(check_schedule(out, 4))
     np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
@@ -1000,7 +1000,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
     return F.Shrink.apply(self, arg=tuple(shrink_arg))

-  def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], value:float=0.0) -> Tensor:
+  def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
     """
     Returns a tensor with padding applied based on the input `padding`.
     `padding` supports two padding structures:
@@ -1015,6 +1015,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     - `padding` must have the same length as `self.ndim`.

     Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
+    The padding mode is selected with `mode`, which supports `constant` and `reflect`.

     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor.arange(9).reshape(1, 1, 3, 3)
@@ -1030,6 +1031,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy())
     ```
     """
+    if mode not in {"constant", "reflect"}: raise NotImplementedError(f"{mode=} is not supported")
     if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads")
     # turn flat padding into group padding
     pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding
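To make the flat-to-group conversion above concrete, here is a quick illustration in plain Python (the values are chosen arbitrarily): a flat pad tuple pads the last dimensions first, two entries per dimension, and any leading dimensions get `(0,0)`.

```python
padding, ndim = (0, 2, 3, 2), 4   # e.g. the flat pad used on a (1,1,5,5) tensor in the tests
pX = ((0, 0),) * (ndim - len(padding) // 2) + tuple(zip(padding[-2::-2], padding[::-2]))
print(pX)  # ((0, 0), (0, 0), (3, 2), (0, 2)) -> last dim padded (0,2), dim -2 padded (3,2)
```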
@@ -1037,9 +1039,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
     X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX))
     def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v)
     # early return for symbolic with positive pads (no need to max)
-    if all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
-    pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape))
-    return _constant(X.shrink(shrinks), pads, value)
+    if mode == "constant" and all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value)
+    pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), lambda shape: tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, shape))
+    if mode == "constant": return _constant(X.shrink(shrinks(X.shape)), pads, value)
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    for d,(pB,pA) in enumerate(pads):
+      if pB >= (s:=X.shape[d]) or pA >= s: raise ValueError(f"Padding ({pB}, {pA}) should be less than the input size={s} for dim={d}.")
+      slcB, slcA = slice(pB,0,-1), slice(s-2 if s-2>=0 else None, s-2-pA if s-2-pA>=0 else None, -1)
+      xB, xA = (X[[slc if i == d else slice(None) for i in range(X.ndim)]] if p > 0 else None for slc, p in ((slcB, pB), (slcA, pA)))
+      X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
+    return X.shrink(shrinks(X.shape))

   # ***** movement high level ops *****
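To see what the reflect branch builds per dimension: `slcB` and `slcA` are reversed slices that exclude the edge element, they are concatenated around `X`, and the final `shrink` trims any negative pads. A standalone NumPy illustration of those slices for one axis, with arbitrary values; this mirrors the logic above rather than calling tinygrad:

```python
import numpy as np

x, (pB, pA) = np.arange(5), (2, 3)   # row [0 1 2 3 4], pad 2 before and 3 after
s = len(x)
xB = x[pB:0:-1]                                      # [2 1]   -> mirrored prefix (slcB)
xA = x[s-2:(s-2-pA if s-2-pA >= 0 else None):-1]     # [3 2 1] -> mirrored suffix (slcA)
print(np.concatenate([xB, x, xA]))                   # [2 1 0 1 2 3 4 3 2 1], i.e. reflect pad (2, 3)
```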