mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.pad_to and Tensor.shrink_to (#12210)
Most of the time I want this instead of spelling out the args. Also adds more input validation to shrink.
This commit is contained in:
@@ -109,7 +109,7 @@ class TextDecoder:
|
|||||||
|
|
||||||
def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor):
|
def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor):
|
||||||
seqlen = x.shape[-1]
|
seqlen = x.shape[-1]
|
||||||
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None, None))
|
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None))
|
||||||
for block in self.blocks: x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
|
for block in self.blocks: x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
|
||||||
return self.output_tok(x)
|
return self.output_tok(x)
|
||||||
|
|
||||||
|
|||||||
@@ -550,6 +550,11 @@ class TestTinygrad(unittest.TestCase):
|
|||||||
def test_shrink(self):
|
def test_shrink(self):
|
||||||
t = Tensor.arange(32).contiguous().realize()
|
t = Tensor.arange(32).contiguous().realize()
|
||||||
self.assertListEqual(t[16:20].tolist(), [16,17,18,19])
|
self.assertListEqual(t[16:20].tolist(), [16,17,18,19])
|
||||||
|
self.assertListEqual(t.shrink_to(16).tolist(), list(range(16)))
|
||||||
|
t = t.reshape(4, 8).contiguous().realize()
|
||||||
|
self.assertListEqual(t.shrink_to(2, 2).tolist(), [[0, 1], [8, 9]])
|
||||||
|
with self.assertRaises(ValueError): t.shrink_to(2)
|
||||||
|
with self.assertRaises(ValueError): t.shrink_to(2, 2, 2)
|
||||||
|
|
||||||
@unittest.skip("this test is just flaky, sync issue")
|
@unittest.skip("this test is just flaky, sync issue")
|
||||||
class TestMoveTensor(unittest.TestCase):
|
class TestMoveTensor(unittest.TestCase):
|
||||||
@@ -644,17 +649,22 @@ class TestZeroShapeTensor(unittest.TestCase):
|
|||||||
|
|
||||||
def test_pad(self):
|
def test_pad(self):
|
||||||
t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), value=1)
|
t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), value=1)
|
||||||
assert t.shape == (3, 2, 2)
|
self.assertEqual(t.shape, (3, 2, 2))
|
||||||
np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2)))
|
np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2)))
|
||||||
|
|
||||||
t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), value=1)
|
t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), value=1)
|
||||||
assert t.shape == (3, 4, 0)
|
self.assertEqual(t.shape, (3, 4, 0))
|
||||||
np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
|
np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
|
||||||
|
|
||||||
t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), value=1)
|
t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), value=1)
|
||||||
assert t.shape == (5, 2, 0)
|
self.assertEqual(t.shape, (5, 2, 0))
|
||||||
np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
|
np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
|
||||||
|
|
||||||
|
np.testing.assert_equal(Tensor([1, 2]).pad_to(4).numpy(), [1, 2, 0, 0])
|
||||||
|
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(2, 3).numpy(), [[1, 2, 0], [0, 0, 0]])
|
||||||
|
with self.assertRaises(TypeError): Tensor([1, 2]).pad_to(2, 3)
|
||||||
|
with self.assertRaises(TypeError): Tensor([[1, 2]]).pad_to(3)
|
||||||
|
|
||||||
def test_shrink_into_zero(self):
|
def test_shrink_into_zero(self):
|
||||||
t = Tensor.rand(3, 4).realize()
|
t = Tensor.rand(3, 4).realize()
|
||||||
assert t.shrink((None, (2, 2))).realize().shape == (3, 0)
|
assert t.shrink((None, (2, 2))).realize().shape == (3, 0)
|
||||||
|
|||||||
@@ -1065,6 +1065,7 @@ class Tensor(MathTrait):
|
|||||||
print(t.shrink((((0, 2), (0, 2)))).numpy())
|
print(t.shrink((((0, 2), (0, 2)))).numpy())
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
|
||||||
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
|
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
|
||||||
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
|
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
|
||||||
|
|
||||||
@@ -1131,6 +1132,10 @@ class Tensor(MathTrait):
|
|||||||
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
X = Tensor.cat(*(X_ for X_ in (xB, X, xA) if X_ is not None), dim=d)
|
||||||
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
||||||
|
|
||||||
|
# convenience
|
||||||
|
def pad_to(self, shape, *args): return self.pad(tuple([(0, ns-s) for s,ns in itertools.zip_longest(self.shape, argfix(shape, *args))]))
|
||||||
|
def shrink_to(self, shape, *args): return self.shrink(tuple([(0, ns) for ns in argfix(shape, *args)]))
|
||||||
|
|
||||||
# ***** movement high level ops *****
|
# ***** movement high level ops *****
|
||||||
|
|
||||||
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
|
def _getitem(self, indices, v: Tensor|None = None) -> Tensor:
|
||||||
|
|||||||
Reference in New Issue
Block a user