From 49f999d91978a3cbb65bae4d9003a7dda7938949 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 9 Jun 2025 13:35:02 -0700 Subject: [PATCH] update _reshape_mask for symbolic shape expand (#10726) * don't merge shape symbolic reshape symbolic * proper fix --- test/test_symbolic_shapetracker.py | 7 +++++++ tinygrad/shape/view.py | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index 4cd542feb4..ed065d0576 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -226,6 +226,13 @@ class TestSymbolicExpand(unittest.TestCase): a = a + 1 self.assertTupleEqual(a.shape, (3, vi)) + def test_pad_then_expand_into_symbols(self): + vi = Variable("i", 1, 10).bind(3) + a = Tensor(1).unsqueeze(0).pad((0, 24)).unsqueeze(0).expand((vi, 25)) + self.assertEqual(a.shape, (vi, 25)) + self.assertEqual(a.reshape(25*vi).shape, (vi*25,)) + self.assertEqual(a.reshape(vi*25).shape, (vi*25,)) + class TestSymbolicShrink(unittest.TestCase): def test_shrink_symbols(self): vi = Variable("i", 1, 5) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 2fb971b4cc..e140238aee 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -3,7 +3,7 @@ import functools, operator, itertools from dataclasses import dataclass from typing import Optional, cast, Sequence from tinygrad.dtype import dtypes -from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops +from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv @functools.cache @@ -51,7 +51,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1)) while len(new_mask) < len(new_shape): - (l, r), next_stride = mask, new_dim * curr_stride + (l, r), next_stride = mask, ssimplify(new_dim * curr_stride) # need to split mask if old_dim == next_stride: # simply copy the mask and get next batch for merging @@ -66,7 +66,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple next_mask = next(r_masks, (0, 1)) # combine if the mask can unfold continuously if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None - mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1) + mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), ssimplify(old_dim * next(r_shape, 1)) return tuple(reversed(new_mask))