* handle reshape of contiguous subparts with explicit mask
* remove the add/remove ones logic in reshape
* accommodate ones in accumulate logic
* make multiply commutative
* fix linting
* make mypy happy
* add test for commutative mul
* merge dimensions in shape_strides for 1 range masks
* add offsets for merging
* fix linting
* add back explicit 1 reshapes
* fix mypy errors
* fix accumulate by including state
* include non-zero stride dimension in acc
* small cleanup
* more compact to_shape_strides
* more logical cleanup
* compress more
* compress reshape mask
* adding some comments
* small bug fix
* improve test coverage
* remove explicit add remove ones
* small bug in test
* enable test_reshape_splitting_combining
* small fix
* 10 lines less to_shape_strides
* shorten reshape mask
* some more cleanup
* more cleanup
* introduce some symbols for compactness
* more symbols
* more cleaner
* lessen symbols, it became less readable
* remove merge_views from view.reshape
* change to_shape_strides to _merge_dims
* improve readability
* fix corner case
* cleanup
* better handling of 1 <= Variable('i',1,10) & new_dim = Variable('i',1,10)
* rewrite _reshape_mask for readability
* fix white space
* add comment
* nice shorthands for readability
* add proof in docs
* small nit
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
"View.reshape without symbolic"
This section sketches a proof of "Complete, Fast and Correct View.reshapes without using Symbolic". The goal is to reduce the number of multi-views, which add runtime cost.
- old_shape = (s1,s2,...,si,s(i+1),...,sn)
- old_stride = (st1, st2, ... ,sti, st(i+1), ..., stn)
- merge_old_shape = (p1, p2), where p1 = s1 * ... * si and p2 = s(i+1) * ... * sn
- new_shape = (k1, ..., kp, k(p+1), ..., kl)
- prod(new_shape) = p1 * p2 (trivial)
- mask and new_mask represent valid indexes before & after reshape respectively.
Assumption
p1 and p2 are each individually mergeable (the conditions are discussed below), but p1 and p2 cannot be merged with each other.
Claim
If prod([k1 ... kp]) < p1 and prod([k1 ... k(p+1)]) > p1, reshape is not possible.
Proof
k(p+1) would need to take some dimensions from p1 and some from p2, which requires p1 and p2 to be mergeable; by assumption they are not.
Conclusion
Hence, the reshape is only possible if there exists a p such that prod([k1 ... kp]) = p1.
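The conclusion can be checked mechanically by scanning the prefix products of new_shape. The following is a minimal sketch, not the code from this change; the helper name `find_split` is hypothetical and only illustrates the claim.

```python
from itertools import accumulate
from operator import mul

def find_split(new_shape, p1):
    """Return the smallest p with prod(new_shape[:p]) == p1, or None if none exists.

    Hypothetical helper for illustration: if the running product jumps from
    below p1 to above p1, some new dimension would straddle p1 and p2.
    """
    prefix_products = [1, *accumulate(new_shape, mul)]
    for p, running in enumerate(prefix_products):
        if running == p1:
            return p
    return None

# old_shape = (2, 3, 4) grouped as p1 = 2 * 3 = 6 and p2 = 4
print(find_split((6, 2, 2), 6))  # a prefix hits p1 exactly
print(find_split((4, 6), 6))     # 4 < 6 < 24, so no valid split
```

When several prefixes match (for example because new_shape contains ones), this sketch returns the first one.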
Conditions for mergeability
Case 1 - All non-zero strides
They will merge if stx = st(x+1) * s(x+1), where x ∈ [1, ..., i-1, i+1, ..., n-1].
Proof
Let's consider merging (s1 ... si) -> p1. We need a single new stride corresponding to p1, which is only possible if the merged dimensions are contiguous.
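The contiguity condition for Case 1 can be written as a short check. This is a sketch under the assumption that all strides are non-zero; `merge_contiguous` is a hypothetical name, not a function from the codebase.

```python
from math import prod

def merge_contiguous(shape, strides):
    """Merge all dims into one if strides[x] == strides[x+1] * shape[x+1] for every x.

    Returns (merged_size, merged_stride), or None if some adjacent pair
    violates the condition. Assumes every stride is non-zero (Case 1).
    """
    for x in range(len(shape) - 1):
        if strides[x] != strides[x + 1] * shape[x + 1]:
            return None
    # The merged dimension steps through memory with the innermost stride.
    return prod(shape), strides[-1]

print(merge_contiguous((2, 3, 4), (12, 4, 1)))  # contiguous, merges
print(merge_contiguous((2, 3), (1, 2)))         # permuted, does not merge
```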
Case 2 - Some stride is zero
Let stj = 0 & st(j+1) != 0 & s(j+1) > 1, where 1 < j < i.
If sj = 1, the reshape is trivial.
If sj > 1:
- If maskj has range > 1, the reshape is not possible, because s(j+1) would need to be repeated at least once and a single stride can't capture repetition.
- If maskj has range = 1, the reshape is possible, since the dimension is virtually shape = 1 with some offset.
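To see why the mask range matters, one can enumerate the memory addresses touched by the valid elements and test whether they form a single arithmetic progression. The sketch below is illustrative only; both helper names are hypothetical, and the offset is taken to be 0.

```python
from itertools import product

def valid_addresses(strides, mask):
    """Addresses of the valid (unmasked) elements, in row-major order, with offset 0."""
    ranges = [range(lo, hi) for lo, hi in mask]
    return [sum(i * st for i, st in zip(idx, strides)) for idx in product(*ranges)]

def fits_single_stride(addrs):
    """True if consecutive addresses always differ by the same amount."""
    return len({b - a for a, b in zip(addrs, addrs[1:])}) <= 1

strides = (0, 1)  # shape (2, 3): stj = 0, st(j+1) = 1, s(j+1) = 3

full = valid_addresses(strides, ((0, 2), (0, 3)))  # maskj has range 2
one = valid_addresses(strides, ((1, 2), (0, 3)))   # maskj has range 1
print(full, fits_single_stride(full))
print(one, fits_single_stride(one))
```

With range 2 the inner dimension is visited twice, so the address sequence wraps around and no single stride describes it. With range 1 only one pass remains, which a single stride plus an offset can represent.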
Conditions for reshaping mask
Case 1 - Splitting Dimension - Mask shouldn't be cut for successful reshape.
- Example - [1,2,3,4,5,6,7,8] -> [[1,2,3,4],[5,6,7,8]]; mask = ((2,6)); new_mask[0] = (0,2) (trivial split).
- new_mask[1] = not possible. It is only possible if mask spans [1-8] or lies within a single dimension [1-4] or [5-8].
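The splitting rule can be sketched for the simplest case of one dimension split into (rows, cols). The helper `split_mask` is hypothetical, and treating any mask made of whole rows as splittable is an assumption that generalizes the "spans [1-8]" condition to more than two rows.

```python
def split_mask(mask, rows, cols):
    """Split a 1-D mask (lo, hi) over rows * cols elements into a 2-D mask, or None.

    Succeeds when the mask lies inside a single row, or when it is made of
    whole rows (an assumed generalization of "spans the full range").
    """
    lo, hi = mask
    assert 0 <= lo < hi <= rows * cols
    if lo // cols == (hi - 1) // cols:        # mask stays within one row
        r = lo // cols
        return ((r, r + 1), (lo % cols, (hi - 1) % cols + 1))
    if lo % cols == 0 and hi % cols == 0:     # mask covers whole rows only
        return ((lo // cols, hi // cols), (0, cols))
    return None                               # mask would be cut by the split

print(split_mask((2, 6), 2, 4))  # the example above: cut across two rows
print(split_mask((0, 8), 2, 4))  # spans everything
print(split_mask((5, 7), 2, 4))  # lies within the second row
```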
Case 2 - Combining Dimension - Mask should unfold continuously.
- Example - [[1,2],[3,4],[5,6]] -> [1,2,3,4,5,6]; mask = ((0,2),(0,2)).
- new_mask = (0,4); only possible because mask1 spans the whole dimension.
- If mask1 did not span the whole dimension, the only way combining would be possible is if mask0 had range 1, as shown below.
  - [[1,2,3],[4,5,6]] -> [1,2,3,4,5,6]; mask = ((1,2),(0,2)); new_mask = ((3,5))
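The combining rule for two dimensions can be sketched in the same way. `combine_mask` is a hypothetical helper that reproduces both examples from this case; it is not the implementation from this change.

```python
def combine_mask(mask, shape):
    """Combine a 2-D mask over shape (rows, cols) into a 1-D mask, or None.

    Succeeds when mask1 spans its whole dimension, or when mask0 has range 1,
    because only then do the valid elements unfold into one continuous run.
    """
    (lo0, hi0), (lo1, hi1) = mask
    _, cols = shape
    if (lo1, hi1) == (0, cols):   # inner mask covers every column
        return (lo0 * cols, hi0 * cols)
    if hi0 - lo0 == 1:            # a single row is selected
        return (lo0 * cols + lo1, lo0 * cols + hi1)
    return None                   # valid elements would leave gaps

print(combine_mask(((0, 2), (0, 2)), (3, 2)))  # first example
print(combine_mask(((1, 2), (0, 2)), (2, 3)))  # second example
print(combine_mask(((0, 2), (0, 2)), (2, 3)))  # gaps, cannot combine
```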