mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update _reshape_mask for symbolic shape expand (#10726)
* don't merge shape symbolic reshape symbolic * proper fix
This commit is contained in:
@@ -226,6 +226,13 @@ class TestSymbolicExpand(unittest.TestCase):
|
||||
a = a + 1
|
||||
self.assertTupleEqual(a.shape, (3, vi))
|
||||
|
||||
def test_pad_then_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 10).bind(3)
|
||||
a = Tensor(1).unsqueeze(0).pad((0, 24)).unsqueeze(0).expand((vi, 25))
|
||||
self.assertEqual(a.shape, (vi, 25))
|
||||
self.assertEqual(a.reshape(25*vi).shape, (vi*25,))
|
||||
self.assertEqual(a.reshape(vi*25).shape, (vi*25,))
|
||||
|
||||
class TestSymbolicShrink(unittest.TestCase):
|
||||
def test_shrink_symbols(self):
|
||||
vi = Variable("i", 1, 5)
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools, operator, itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, cast, Sequence
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops
|
||||
from tinygrad.uop.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop, Ops, ssimplify
|
||||
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
|
||||
|
||||
@functools.cache
|
||||
@@ -51,7 +51,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
|
||||
curr_stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
|
||||
|
||||
while len(new_mask) < len(new_shape):
|
||||
(l, r), next_stride = mask, new_dim * curr_stride
|
||||
(l, r), next_stride = mask, ssimplify(new_dim * curr_stride)
|
||||
|
||||
# need to split mask
|
||||
if old_dim == next_stride: # simply copy the mask and get next batch for merging
|
||||
@@ -66,7 +66,7 @@ def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple
|
||||
next_mask = next(r_masks, (0, 1))
|
||||
# combine if the mask can unfold continuously
|
||||
if mask != (0, old_dim) and l != r and next_mask[1] - next_mask[0] != 1: return None
|
||||
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), old_dim * next(r_shape, 1)
|
||||
mask, old_dim = (next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r), ssimplify(old_dim * next(r_shape, 1))
|
||||
|
||||
return tuple(reversed(new_mask))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user