This commit is contained in:
George Hotz
2025-09-30 18:50:04 +08:00
parent 036803f8e7
commit c963e44ea0

View File

@@ -308,11 +308,11 @@ def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
def might_end_axis(idx:UOp):
  """Pick which index sources of `idx` should force a realize, then clear `idx.arg`.

  Returns None (no change) when `idx.arg` is None, or when the nodes yielded by
  `idx.toposort()` do not contain both a buffer-like op (BUFFER, REALIZE or
  BUFFERIZE) and a REDUCE_AXIS.

  Otherwise returns a copy of `idx` with `arg=None`. If any source in
  `idx.src[1:]` reaches a RANGE whose arg is greater than `idx.arg`, the copy's
  first source is replaced by `idx.src[0].realize(arg=...)`, where the arg is the
  tuple of positions (0-based within `idx.src[1:]`) of those sources.
  """
  if idx.arg is None: return None
  # TODO: write a proper cost function here
  # gather the ops once; both gates below only ask "is this op present anywhere?"
  # NOTE(review): presumably toposort() covers idx and everything it depends on -- confirm
  seen_ops = {x.op for x in idx.toposort()}
  if seen_ops.isdisjoint({Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE}): return None
  if Ops.REDUCE_AXIS not in seen_ops: return None
  to_end_axis = []
  for i,a in enumerate(idx.src[1:]):
    # a source qualifies when it depends on a RANGE numbered above idx.arg
    if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
      to_end_axis.append(i)
  if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
  return idx.replace(arg=None)