diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 101bb2e669..ab60a52332 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -76,13 +76,13 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple return tuple(reversed(new_mask)) -def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]: - result = [] +def un1d(shape:Tuple[sint, ...], offset:sint) -> List[sint]: + # find the position of offset on each dimension based on shape + ret = [] for stride in strides_for_shape(shape): - here = offs // stride if stride != 0 else 0 - result.append(here) - offs -= here * stride - return result + ret.append(offset // stride if stride != 0 else 0) + offset -= ret[-1] * stride + return ret @dataclass(frozen=True) class View: @@ -123,7 +123,7 @@ class View: strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape) # canonicalize 0 in shape if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True) - # canonicalize empty mask + # canonicalize no-op mask if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None # if any dimension has size >1, but is masked such that only one index in the dimension is unmasked # then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset @@ -167,12 +167,12 @@ class View: if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret if vm1.mask: for b,e in vm1.mask: - if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape)) + if not resolve(b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape)) return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape))) # Project vm1's offset and strides on to vm2. origin = un1d(vm2.shape, vm1.offset) - terms: List[List[Tuple[int, sint]]] = [[] for _ in origin] + terms: List[List[Tuple[int, sint]]] = [[] for _ in vm2.shape] strides: List[sint] = [0] * len(vm1.shape) for d1, st in enumerate(vm1.strides): if st == 0: continue @@ -195,8 +195,7 @@ class View: merged_size, merged_term = 1, UOp.const(dtypes.int, 0) if resolve(merged_term != 0): return None if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape: - reshaped_vm2 = vm2.reshape(vm2_shape) - if reshaped_vm2 is None: return None + if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1 if vm2.mask: @@ -212,7 +211,7 @@ class View: else: bad = True continue d1, s1 = term[0] - if not isinstance(s1, int) or not isinstance(newe[d1], int): + if not all_int([s1, newe[d1]]): bad = True continue newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))