mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
reorder and comments to reshape [pr] (#8223)
something feels wrong... contructing a counter example next
This commit is contained in:
@@ -57,6 +57,8 @@ class TestMergeDims(unittest.TestCase):
|
||||
self.assertEqual(merge_dims((2, 3), (0, 1), ((1, 2), (0, 2))), ((6, 1, 3),))
|
||||
# shift mask on stride 0
|
||||
self.assertEqual(merge_dims((2, 3), (0, 1), ((0, 1), (0, 2))), ((6, 1, 3),))
|
||||
# permute 0 / 1
|
||||
self.assertEqual(merge_dims((3, 2), (1, 0), ((0, 2), (1, 2))), ((3, 1, 3), (2, 0, 0)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -18,16 +18,16 @@ def strides_for_shape(shape:Tuple[sint, ...]) -> Tuple[sint, ...]:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tuple[Tuple[int, int], ...]]=None) -> Tuple[Tuple[int, int, int], ...]:
|
||||
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o zero stride), ...]
|
||||
# merge contiguous sub-parts or zero strided dims. ret = Tuple[(merged_size, stride, merged size w/o expand (zero stride)), ...]
|
||||
if not shape: return ()
|
||||
assert len(shape) == len(strides) and (mask is None or len(shape) == len(mask))
|
||||
ret = [(shape[0], strides[0], shape[0] if strides[0] != 0 else 0)]
|
||||
# merge this dim to next dim if size is 1
|
||||
merging = (mask[0][1] - mask[0][0] == 1) if mask is not None else shape[0] == 1
|
||||
for i, (s, st) in enumerate(zip(shape[1:], strides[1:]), start=1):
|
||||
last_s, last_st, last_pre_expand_s = ret[-1]
|
||||
# always merge 1
|
||||
if s == 1: continue
|
||||
last_s, last_st, last_pre_expand_s = ret[-1]
|
||||
# merge last dim with this dim if merging or strides matched
|
||||
if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
|
||||
else: ret.append((s, st, s if st != 0 else 0))
|
||||
@@ -325,7 +325,9 @@ class View:
|
||||
# TODO: third resolve shouldn't be needed
|
||||
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
|
||||
strides.append(new_stride)
|
||||
if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
|
||||
acc = acc * new_dim
|
||||
# TODO: likely a bug, what if expand happened before acc < real_dim happens?
|
||||
if resolve(new_dim != 1): new_stride *= (new_dim if resolve(acc < real_dim) else 0)
|
||||
if resolve(acc != merged_dim): return None
|
||||
|
||||
if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
|
||||
|
||||
Reference in New Issue
Block a user