diff --git a/tinygrad/schedule/indexing.py b/tinygrad/schedule/indexing.py index 4147267c36..9788d745b5 100644 --- a/tinygrad/schedule/indexing.py +++ b/tinygrad/schedule/indexing.py @@ -206,15 +206,15 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]: # we have to (partially) realize here if there's new ranges if len(_realize_axis): rctx.realize_map[x] = _realize_axis - # if this element has weight and there's ended ranges, we might have to end some other ranges + # if this element is a reduce and there's ended ranges, we might have to end some other ranges if len(ending_ranges[x]) and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}): - _realize_axis = rctx.realize_map.get(x, []) - assert _realize_axis is not None + _realize_axis = rctx.realize_map.get(x, []) or [] local_ending_ranges = ending_ranges[x] for i,r in enumerate(out_rngs): if i in _realize_axis: continue - if any(any(rr.arg > e.arg for e in local_ending_ranges) for rr in r.ranges) or not (PCONTIG > 1): + if not (PCONTIG > 1) or any(any(rr.arg > e.arg for e in local_ending_ranges) for rr in r.ranges): _realize_axis.append(i) + ending_ranges[x] = [] if len(_realize_axis): rctx.realize_map[x] = _realize_axis out_rngs = tuple([(rctx.new_range(x.shape[i]) if i in _realize_axis else r) for i,r in enumerate(out_rngs)])