fix dim resolution order in split/chunk

Ensure dim_sz is retrieved after dim is resolved, not before.
The previous one-liner evaluated self.shape[dim] with the original
unresolved dim value.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
George Hotz
2025-12-16 12:37:41 -04:00
parent 41e8a3f2b2
commit d30e3b3cad

View File

@@ -1334,7 +1334,8 @@ class Tensor(OpMixin):
print("\\n".join([repr(x.numpy()) for x in split]))
```
"""
dim, dim_sz = self._resolve_dim(dim), self.shape[dim]
dim = self._resolve_dim(dim)
dim_sz = self.shape[dim]
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
if isinstance(sizes, int): sizes = [min(sizes, dim_sz-i) for i in range(0, max(1, dim_sz), max(1, sizes))]
assert sum(sizes) == dim_sz, f"expect sizes to sum exactly to {dim_sz}, but got {sum(sizes)}"
@@ -1359,7 +1360,8 @@ class Tensor(OpMixin):
print("\\n".join([repr(x.numpy()) for x in chunked]))
```
"""
dim, dim_sz = self._resolve_dim(dim), self.shape[dim]
dim = self._resolve_dim(dim)
dim_sz = self.shape[dim]
assert isinstance(dim_sz, int), f"does not support symbolic shape in split dimension {dim}: {self.shape}"
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
return list(self.split(ceildiv(dim_sz, chunks) if dim_sz else [0]*chunks, dim=dim))