mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
no repeating work with globals
This commit is contained in:
@@ -72,8 +72,11 @@ class Scheduler:
|
||||
for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
|
||||
|
||||
# filter any not in reduces
|
||||
# TODO: reenable this
|
||||
"""
|
||||
reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE]
|
||||
for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls])
|
||||
"""
|
||||
|
||||
return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else []
|
||||
|
||||
|
||||
@@ -310,15 +310,13 @@ def might_end_axis(idx:UOp):
|
||||
# TODO: write a proper cost function here
|
||||
if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None
|
||||
if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
|
||||
if RANGEIFY > 1:
|
||||
to_end_axis = []
|
||||
for i,a in enumerate(idx.src[1:]):
|
||||
if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
|
||||
to_end_axis.append(i)
|
||||
if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
|
||||
return idx.replace(arg=None)
|
||||
# in RANGEIFY=1, always realize
|
||||
return idx.replace(src=(idx.src[0].realize(),)+idx.src[1:], arg=None)
|
||||
to_end_axis = []
|
||||
for i,a in enumerate(idx.src[1:]):
|
||||
# in RANGEIFY=1, always realize
|
||||
if not (RANGEIFY > 1) or any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
|
||||
to_end_axis.append(i)
|
||||
if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
|
||||
return idx.replace(arg=None)
|
||||
|
||||
def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user