no repeating work with globals

This commit is contained in:
George Hotz
2025-09-30 17:35:25 +08:00
parent 348188a0b5
commit e73f0bbf98
2 changed files with 10 additions and 9 deletions

View File

@@ -72,8 +72,11 @@ class Scheduler:
for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
# filter any not in reduces
# TODO: reenable this
"""
reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE]
for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
"""
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else []

View File

@@ -310,15 +310,13 @@ def might_end_axis(idx:UOp):
# TODO: write a proper cost function here
if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None
if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
if RANGEIFY > 1:
to_end_axis = []
for i,a in enumerate(idx.src[1:]):
if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
to_end_axis.append(i)
if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
return idx.replace(arg=None)
# in RANGEIFY=1, always realize
return idx.replace(src=(idx.src[0].realize(),)+idx.src[1:], arg=None)
to_end_axis = []
for i,a in enumerate(idx.src[1:]):
# in RANGEIFY=1, always realize
if not (RANGEIFY > 1) or any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
to_end_axis.append(i)
if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
return idx.replace(arg=None)
def unprocessed_index(x:UOp):
  # unconditionally fails, reporting the op of the node's first source
  src_op = x.src[0].op
  raise RuntimeError(f"unprocessed index on {src_op}")