support None in pad_to and shrink_to (#12700)

This commit is contained in:
chenyu
2025-10-15 09:25:31 -04:00
committed by GitHub
parent 612e3d6143
commit 312c622d35
2 changed files with 10 additions and 4 deletions

View File

@@ -527,6 +527,7 @@ class TestTinygrad(unittest.TestCase):
self.assertListEqual(t.shrink_to(16).tolist(), list(range(16)))
t = t.reshape(4, 8).contiguous().realize()
self.assertListEqual(t.shrink_to(2, 2).tolist(), [[0, 1], [8, 9]])
self.assertListEqual(t.shrink_to(None, 2).tolist(), t.shrink_to(4, 2).tolist())
with self.assertRaises(ValueError): t.shrink_to(2)
with self.assertRaises(ValueError): t.shrink_to(2, 2, 2)
@@ -636,8 +637,10 @@ class TestZeroShapeTensor(unittest.TestCase):
np.testing.assert_equal(Tensor([1, 2]).pad_to(4).numpy(), [1, 2, 0, 0])
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(2, 3).numpy(), [[1, 2, 0], [0, 0, 0]])
with self.assertRaises(TypeError): Tensor([1, 2]).pad_to(2, 3)
with self.assertRaises(TypeError): Tensor([[1, 2]]).pad_to(3)
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(1, 3).numpy(), [[1, 2, 0]])
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(None, 3).numpy(), [[1, 2, 0]])
with self.assertRaises(ValueError): Tensor([1, 2]).pad_to(2, 3)
with self.assertRaises(ValueError): Tensor([[1, 2]]).pad_to(3)
def test_shrink_into_zero(self):
t = Tensor.rand(3, 4).realize()