diff --git a/CLAUDE.md b/CLAUDE.md index 858e5e9204..eb262afc0f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -95,6 +95,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()" ## Workflow Rules - **NEVER commit without explicit user approval** - always show the diff and wait for approval +- **NEVER amend commits** - always create a new commit instead +- Run `pre-commit run --all-files` before committing to catch linting/type errors - Run tests before proposing commits - Test with `SPEC=2` when modifying UOp-related code diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index cbffc2dbb5..643b1c5dec 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -95,6 +95,37 @@ class TestTensorVariable(unittest.TestCase): assert t.uop.base.buffer.size == 30 assert t.uop.shape == (3, vb) + def test_symbolic_chunk(self): + # chunk should work when split dimension is concrete, even if other dims are symbolic + vv = Variable("a", 1, 10).bind(4) + t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8) + chunks = t.chunk(2, dim=-1) # split along concrete dim 8 + assert len(chunks) == 2 + assert chunks[0].shape[1] == 4 + assert chunks[1].shape[1] == 4 + # verify the values by shrinking to concrete shape first + np.testing.assert_equal(chunks[0].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4))) + np.testing.assert_equal(chunks[1].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4))) + + def test_symbolic_split(self): + # split should work when split dimension is concrete, even if other dims are symbolic + vv = Variable("a", 1, 10).bind(3) + t = Tensor.arange(30).reshape(10, 3).contiguous()[:, :vv] # shape (10, vv) + splits = t.split(5, dim=0) # split along concrete dim 10 + assert len(splits) == 2 + assert splits[0].shape[0] == 5 + assert splits[1].shape[0] == 5 + # verify the values by shrinking to concrete shape first + np.testing.assert_equal(splits[0].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[:5, :3]) + np.testing.assert_equal(splits[1].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[5:, :3]) + + def test_symbolic_chunk_error_on_symbolic_dim(self): + # chunk should fail when trying to split along a symbolic dimension + vv = Variable("a", 1, 10).bind(4) + t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8) + with self.assertRaises(AssertionError): + t.chunk(2, dim=0) # can't split along symbolic dim + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 291f068655..1a5c25959a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1334,10 +1334,11 @@ class Tensor(OpMixin): print("\\n".join([repr(x.numpy()) for x in split])) ``` """ - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" dim = self._resolve_dim(dim) - if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))] - assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}" + dim_sz = self.shape[dim] + assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}" + if isinstance(sizes, int): sizes = [min(sizes, dim_sz-i) for i in range(0, max(1, dim_sz), max(1, sizes))] + assert sum(sizes) == dim_sz, f"expect sizes to sum exactly to {dim_sz}, but got {sum(sizes)}" return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))]) def chunk(self, chunks:int, dim:int=0) -> list[Tensor]: @@ -1359,10 +1360,11 @@ class Tensor(OpMixin): print("\\n".join([repr(x.numpy()) for x in chunked])) ``` """ - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}" dim = self._resolve_dim(dim) - return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim)) + dim_sz = self.shape[dim] + assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}" + assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}" + return list(self.split(ceildiv(dim_sz, chunks) if dim_sz else [0]*chunks, dim=dim)) def unfold(self, dim:int, size:sint, step:int) -> Tensor: """