reset ending ranges

This commit is contained in:
George Hotz
2025-10-17 22:52:17 +08:00
parent c5617ed8cf
commit dad778564c

View File

@@ -206,15 +206,15 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
# we have to (partially) realize here if there's new ranges
if len(_realize_axis): rctx.realize_map[x] = _realize_axis
# if this element has weight and there's ended ranges, we might have to end some other ranges
# if this element is a reduce and there's ended ranges, we might have to end some other ranges
if len(ending_ranges[x]) and x.op in GroupOp.Elementwise.union({Ops.REDUCE_AXIS}):
_realize_axis = rctx.realize_map.get(x, [])
assert _realize_axis is not None
_realize_axis = rctx.realize_map.get(x, []) or []
local_ending_ranges = ending_ranges[x]
for i,r in enumerate(out_rngs):
if i in _realize_axis: continue
if any(any(rr.arg > e.arg for e in local_ending_ranges) for rr in r.ranges) or not (PCONTIG > 1):
if not (PCONTIG > 1) or any(any(rr.arg > e.arg for e in local_ending_ranges) for rr in r.ranges):
_realize_axis.append(i)
ending_ranges[x] = []
if len(_realize_axis):
rctx.realize_map[x] = _realize_axis
out_rngs = tuple([(rctx.new_range(x.shape[i]) if i in _realize_axis else r) for i,r in enumerate(out_rngs)])