support symbolic shapes in split/chunk when split dim is concrete (#13718)

* support symbolic shapes in split/chunk when split dim is concrete

Previously split() and chunk() required all dimensions to be concrete.
Now they only require the dimension being split to be concrete, allowing
them to work with tensors that have symbolic shapes in other dimensions.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* update CLAUDE.md: add pre-commit and no-amend rules

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* fix dim resolution order in split/chunk

Ensure dim_sz is retrieved after dim is resolved, not before.
The previous one-liner evaluated self.shape[dim] with the original
unresolved dim value.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
George Hotz
2025-12-16 13:55:06 -04:00
committed by GitHub
parent e428fbfab6
commit bfe374c7f5
3 changed files with 41 additions and 6 deletions

View File

@@ -95,6 +95,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
## Workflow Rules
- **NEVER commit without explicit user approval** - always show the diff and wait for approval
- **NEVER amend commits** - always create a new commit instead
- Run `pre-commit run --all-files` before committing to catch linting/type errors
- Run tests before proposing commits
- Test with `SPEC=2` when modifying UOp-related code

View File

@@ -95,6 +95,37 @@ class TestTensorVariable(unittest.TestCase):
assert t.uop.base.buffer.size == 30
assert t.uop.shape == (3, vb)
def test_symbolic_chunk(self):
# chunk should work when split dimension is concrete, even if other dims are symbolic
vv = Variable("a", 1, 10).bind(4)
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
chunks = t.chunk(2, dim=-1) # split along concrete dim 8
assert len(chunks) == 2
assert chunks[0].shape[1] == 4
assert chunks[1].shape[1] == 4
# verify the values by shrinking to concrete shape first
np.testing.assert_equal(chunks[0].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
np.testing.assert_equal(chunks[1].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
def test_symbolic_split(self):
# split should work when split dimension is concrete, even if other dims are symbolic
vv = Variable("a", 1, 10).bind(3)
t = Tensor.arange(30).reshape(10, 3).contiguous()[:, :vv] # shape (10, vv)
splits = t.split(5, dim=0) # split along concrete dim 10
assert len(splits) == 2
assert splits[0].shape[0] == 5
assert splits[1].shape[0] == 5
# verify the values by shrinking to concrete shape first
np.testing.assert_equal(splits[0].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[:5, :3])
np.testing.assert_equal(splits[1].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[5:, :3])
def test_symbolic_chunk_error_on_symbolic_dim(self):
# chunk should fail when trying to split along a symbolic dimension
vv = Variable("a", 1, 10).bind(4)
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
with self.assertRaises(AssertionError):
t.chunk(2, dim=0) # can't split along symbolic dim
if __name__ == '__main__':
unittest.main()

View File

@@ -1334,10 +1334,11 @@ class Tensor(OpMixin):
print("\\n".join([repr(x.numpy()) for x in split]))
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim = self._resolve_dim(dim)
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
dim_sz = self.shape[dim]
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
if isinstance(sizes, int): sizes = [min(sizes, dim_sz-i) for i in range(0, max(1, dim_sz), max(1, sizes))]
assert sum(sizes) == dim_sz, f"expect sizes to sum exactly to {dim_sz}, but got {sum(sizes)}"
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
@@ -1359,10 +1360,11 @@ class Tensor(OpMixin):
print("\\n".join([repr(x.numpy()) for x in chunked]))
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
dim = self._resolve_dim(dim)
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
dim_sz = self.shape[dim]
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
return list(self.split(ceildiv(dim_sz, chunks) if dim_sz else [0]*chunks, dim=dim))
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
"""