mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
support symbolic shapes in split/chunk when split dim is concrete (#13718)
* support symbolic shapes in split/chunk when split dim is concrete Previously split() and chunk() required all dimensions to be concrete. Now they only require the dimension being split to be concrete, allowing them to work with tensors that have symbolic shapes in other dimensions. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * update CLAUDE.md: add pre-commit and no-amend rules 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * fix dim resolution order in split/chunk Ensure dim_sz is retrieved after dim is resolved, not before. The previous one-liner evaluated self.shape[dim] with the original unresolved dim value. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -95,6 +95,8 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
|
||||
## Workflow Rules
|
||||
|
||||
- **NEVER commit without explicit user approval** - always show the diff and wait for approval
|
||||
- **NEVER amend commits** - always create a new commit instead
|
||||
- Run `pre-commit run --all-files` before committing to catch linting/type errors
|
||||
- Run tests before proposing commits
|
||||
- Test with `SPEC=2` when modifying UOp-related code
|
||||
|
||||
|
||||
@@ -95,6 +95,37 @@ class TestTensorVariable(unittest.TestCase):
|
||||
assert t.uop.base.buffer.size == 30
|
||||
assert t.uop.shape == (3, vb)
|
||||
|
||||
def test_symbolic_chunk(self):
|
||||
# chunk should work when split dimension is concrete, even if other dims are symbolic
|
||||
vv = Variable("a", 1, 10).bind(4)
|
||||
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
|
||||
chunks = t.chunk(2, dim=-1) # split along concrete dim 8
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0].shape[1] == 4
|
||||
assert chunks[1].shape[1] == 4
|
||||
# verify the values by shrinking to concrete shape first
|
||||
np.testing.assert_equal(chunks[0].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
|
||||
np.testing.assert_equal(chunks[1].shrink(((0, 4), (0, 4))).numpy(), np.ones((4, 4)))
|
||||
|
||||
def test_symbolic_split(self):
|
||||
# split should work when split dimension is concrete, even if other dims are symbolic
|
||||
vv = Variable("a", 1, 10).bind(3)
|
||||
t = Tensor.arange(30).reshape(10, 3).contiguous()[:, :vv] # shape (10, vv)
|
||||
splits = t.split(5, dim=0) # split along concrete dim 10
|
||||
assert len(splits) == 2
|
||||
assert splits[0].shape[0] == 5
|
||||
assert splits[1].shape[0] == 5
|
||||
# verify the values by shrinking to concrete shape first
|
||||
np.testing.assert_equal(splits[0].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[:5, :3])
|
||||
np.testing.assert_equal(splits[1].shrink(((0, 5), (0, 3))).numpy(), np.arange(30).reshape(10, 3)[5:, :3])
|
||||
|
||||
def test_symbolic_chunk_error_on_symbolic_dim(self):
|
||||
# chunk should fail when trying to split along a symbolic dimension
|
||||
vv = Variable("a", 1, 10).bind(4)
|
||||
t = Tensor.ones(10, 8).contiguous()[:vv, :] # shape (vv, 8)
|
||||
with self.assertRaises(AssertionError):
|
||||
t.chunk(2, dim=0) # can't split along symbolic dim
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1334,10 +1334,11 @@ class Tensor(OpMixin):
|
||||
print("\\n".join([repr(x.numpy()) for x in split]))
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
dim = self._resolve_dim(dim)
|
||||
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
|
||||
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
||||
dim_sz = self.shape[dim]
|
||||
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
|
||||
if isinstance(sizes, int): sizes = [min(sizes, dim_sz-i) for i in range(0, max(1, dim_sz), max(1, sizes))]
|
||||
assert sum(sizes) == dim_sz, f"expect sizes to sum exactly to {dim_sz}, but got {sum(sizes)}"
|
||||
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
||||
|
||||
def chunk(self, chunks:int, dim:int=0) -> list[Tensor]:
|
||||
@@ -1359,10 +1360,11 @@ class Tensor(OpMixin):
|
||||
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
||||
dim = self._resolve_dim(dim)
|
||||
return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
||||
dim_sz = self.shape[dim]
|
||||
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
|
||||
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
||||
return list(self.split(ceildiv(dim_sz, chunks) if dim_sz else [0]*chunks, dim=dim))
|
||||
|
||||
def unfold(self, dim:int, size:sint, step:int) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user