update _reshape_mask for symbolic shape expand (#10726)

* don't merge shape symbolic reshape symbolic

* proper fix
This commit is contained in:
chenyu
2025-06-09 13:35:02 -07:00
committed by GitHub
parent 27dd97f688
commit 49f999d919
2 changed files with 10 additions and 3 deletions

View File

@@ -226,6 +226,13 @@ class TestSymbolicExpand(unittest.TestCase):
a = a + 1
self.assertTupleEqual(a.shape, (3, vi))
def test_pad_then_expand_into_symbols(self):
vi = Variable("i", 1, 10).bind(3)
a = Tensor(1).unsqueeze(0).pad((0, 24)).unsqueeze(0).expand((vi, 25))
self.assertEqual(a.shape, (vi, 25))
self.assertEqual(a.reshape(25*vi).shape, (vi*25,))
self.assertEqual(a.reshape(vi*25).shape, (vi*25,))
class TestSymbolicShrink(unittest.TestCase):
def test_shrink_symbols(self):
vi = Variable("i", 1, 5)

View File

@@ -3,7 +3,7 @@ import functools, operator, itertools
from dataclasses import dataclass
from typing import Optional, cast, Sequence
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
@functools.cache
@@ -51,7 +51,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
while len(new_mask) < len(new_shape):
(l, r), next_stride = mask, new_dim * curr_stride
(l, r), next_stride = mask, ssimplify(new_dim * curr_stride)
# need to split mask
if old_dim == next_stride: # simply copy the mask and get next batch for merging
@@ -66,7 +66,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
next_mask = next(r_masks, (0, 1))
# combine if the mask can unfold continuously
if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), ssimplify(old_dim * next(r_shape, 1))
return tuple(reversed(new_mask))