reorder and comments to reshape [pr] (#8223)

something feels wrong... contructing a counter example next
This commit is contained in:
chenyu
2024-12-13 17:02:27 -05:00
committed by GitHub
parent c1b79c118f
commit eb0e5a14fd
2 changed files with 7 additions and 3 deletions

View File

@@ -57,6 +57,8 @@ class TestMergeDims(unittest.TestCase):
self.assertEqual(merge_dims((2, 3), (0, 1), ((1, 2), (0, 2))), ((6, 1, 3),))
# shift mask on stride 0
self.assertEqual(merge_dims((2, 3), (0, 1), ((0, 1), (0, 2))), ((6, 1, 3),))
# permute 0 / 1
self.assertEqual(merge_dims((3, 2), (1, 0), ((0, 2), (1, 2))), ((3, 1, 3), (2, 0, 0)))
if __name__ == '__main__':
unittest.main()

View File

@@ -18,16 +18,16 @@ def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
@functools.lru_cache(maxsize=None)
def merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o expand (zero stride)), ...]
if not shape: return ()
assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
# merge this dim to next dim if size is 1
merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
last_s, last_st, last_pre_expand_s = ret[-1]
# always merge 1
if s == 1: continue
last_s, last_st, last_pre_expand_s = ret[-1]
# merge last dim with this dim if merging or strides matched
if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
else: ret.append((s, st, s if st != 0 else 0))
@@ -325,7 +325,9 @@ class View:
# TODO: third resolve shouldn't be needed
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
strides.append(new_stride)
if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
acc = acc * new_dim
# TODO: likely a bug, what if expand happened before acc < real_dim happens?
if resolve(new_dim != 1): new_stride *= (new_dim if resolve(acc < real_dim) else 0)
if resolve(acc != merged_dim): return None
if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None: