Tensor.pad_to and Tensor.shrink_to (#12210)

most of the time i want this instead of spelling out the args

also add more input validation to shrink
This commit is contained in:
chenyu
2025-09-16 12:24:55 -04:00
committed by GitHub
parent 122a50fe8c
commit 84d2d047ea
3 changed files with 19 additions and 4 deletions

View File

@@ -550,6 +550,11 @@ class TestTinygrad(unittest.TestCase):
def test_shrink(self):
t = Tensor.arange(32).contiguous().realize()
self.assertListEqual(t[16:20].tolist(), [16,17,18,19])
self.assertListEqual(t.shrink_to(16).tolist(), list(range(16)))
t = t.reshape(4, 8).contiguous().realize()
self.assertListEqual(t.shrink_to(2, 2).tolist(), [[0, 1], [8, 9]])
with self.assertRaises(ValueError): t.shrink_to(2)
with self.assertRaises(ValueError): t.shrink_to(2, 2, 2)
@unittest.skip("this test is just flaky, sync issue")
class TestMoveTensor(unittest.TestCase):
@@ -644,17 +649,22 @@ class TestZeroShapeTensor(unittest.TestCase):
def test_pad(self):
t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), value=1)
assert t.shape == (3, 2, 2)
self.assertEqual(t.shape, (3, 2, 2))
np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2)))
t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), value=1)
assert t.shape == (3, 4, 0)
self.assertEqual(t.shape, (3, 4, 0))
np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), value=1)
assert t.shape == (5, 2, 0)
self.assertEqual(t.shape, (5, 2, 0))
np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
np.testing.assert_equal(Tensor([1, 2]).pad_to(4).numpy(), [1, 2, 0, 0])
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(2, 3).numpy(), [[1, 2, 0], [0, 0, 0]])
with self.assertRaises(TypeError): Tensor([1, 2]).pad_to(2, 3)
with self.assertRaises(TypeError): Tensor([[1, 2]]).pad_to(3)
def test_shrink_into_zero(self):
t = Tensor.rand(3, 4).realize()
assert t.shrink((None, (2, 2))).realize().shape == (3, 0)