Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-01-09 15:08:02 -05:00.
support None in pad_to and shrink_to (#12700)
This commit is contained in:
@@ -527,6 +527,7 @@ class TestTinygrad(unittest.TestCase):
|
|||||||
self.assertListEqual(t.shrink_to(16).tolist(), list(range(16)))
|
self.assertListEqual(t.shrink_to(16).tolist(), list(range(16)))
|
||||||
t = t.reshape(4, 8).contiguous().realize()
|
t = t.reshape(4, 8).contiguous().realize()
|
||||||
self.assertListEqual(t.shrink_to(2, 2).tolist(), [[0, 1], [8, 9]])
|
self.assertListEqual(t.shrink_to(2, 2).tolist(), [[0, 1], [8, 9]])
|
||||||
|
self.assertListEqual(t.shrink_to(None, 2).tolist(), t.shrink_to(4, 2).tolist())
|
||||||
with self.assertRaises(ValueError): t.shrink_to(2)
|
with self.assertRaises(ValueError): t.shrink_to(2)
|
||||||
with self.assertRaises(ValueError): t.shrink_to(2, 2, 2)
|
with self.assertRaises(ValueError): t.shrink_to(2, 2, 2)
|
||||||
|
|
||||||
@@ -636,8 +637,10 @@ class TestZeroShapeTensor(unittest.TestCase):
|
|||||||
|
|
||||||
np.testing.assert_equal(Tensor([1, 2]).pad_to(4).numpy(), [1, 2, 0, 0])
|
np.testing.assert_equal(Tensor([1, 2]).pad_to(4).numpy(), [1, 2, 0, 0])
|
||||||
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(2, 3).numpy(), [[1, 2, 0], [0, 0, 0]])
|
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(2, 3).numpy(), [[1, 2, 0], [0, 0, 0]])
|
||||||
with self.assertRaises(TypeError): Tensor([1, 2]).pad_to(2, 3)
|
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(1, 3).numpy(), [[1, 2, 0]])
|
||||||
with self.assertRaises(TypeError): Tensor([[1, 2]]).pad_to(3)
|
np.testing.assert_equal(Tensor([[1, 2]]).pad_to(None, 3).numpy(), [[1, 2, 0]])
|
||||||
|
with self.assertRaises(ValueError): Tensor([1, 2]).pad_to(2, 3)
|
||||||
|
with self.assertRaises(ValueError): Tensor([[1, 2]]).pad_to(3)
|
||||||
|
|
||||||
def test_shrink_into_zero(self):
|
def test_shrink_into_zero(self):
|
||||||
t = Tensor.rand(3, 4).realize()
|
t = Tensor.rand(3, 4).realize()
|
||||||
|
|||||||
@@ -1191,8 +1191,11 @@ class Tensor(MathTrait):
|
|||||||
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
||||||
|
|
||||||
# convenience
|
# convenience
|
||||||
def pad_to(self, shape, *args):
  """Zero-pad this tensor up to `shape`.

  Each entry of the target shape gives the new size of that axis; a `None`
  entry leaves the corresponding axis untouched.  Padding is applied only at
  the high end of each axis (existing data stays at the front).

  Raises ValueError when the number of entries does not match `self.ndim`
  (a plain `zip` would otherwise silently truncate the shorter side).
  """
  new_shape = argfix(shape, *args)
  if len(new_shape) != self.ndim: raise ValueError(f"dim mismatch, cannot pad {self.shape} to {new_shape}")
  # per-axis spec: None passes the axis through unchanged, otherwise grow the high side by ns-s
  return self.pad(tuple(None if ns is None else (0, ns-s) for s,ns in zip(self.shape, new_shape)))
def shrink_to(self, shape, *args):
  """Shrink this tensor down to `shape`, keeping the leading elements of each axis.

  Each entry of the target shape gives the new size of that axis; a `None`
  entry leaves the corresponding axis untouched.

  Raises ValueError when the number of entries does not match `self.ndim`,
  mirroring the explicit dim-mismatch check in `pad_to` so both convenience
  wrappers fail early with a clear message instead of deferring to `shrink`.
  """
  if len(new_shape := argfix(shape, *args)) != self.ndim: raise ValueError(f"dim mismatch, cannot shrink {self.shape} to {new_shape}")
  # per-axis spec: None passes the axis through unchanged, otherwise keep the first ns elements
  return self.shrink(tuple(None if ns is None else (0, ns) for ns in new_shape))
|
|
||||||
# ***** movement high level ops *****
|
# ***** movement high level ops *****
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user