diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index bcbbcbb355..50fdaae355 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -72,8 +72,11 @@ class Scheduler: for ls in local_store_rngs: store_rngs = tuple([x for x in store_rngs if x in ls]) # filter any not in reduces + # TODO: reenable this + """ reduce_rngs = [x.ranges for x in self.ast.toposort() if x.op is Ops.REDUCE] for ls in reduce_rngs: store_rngs = tuple([x for x in store_rngs if x in ls]) + """ return [x for x in UOp.sink(*store_rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.LOOP] if store_rngs else [] diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index a6afd0c6e8..b0405f96de 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -310,15 +310,13 @@ def might_end_axis(idx:UOp): # TODO: write a proper cost function here if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None - if RANGEIFY > 1: - to_end_axis = [] - for i,a in enumerate(idx.src[1:]): - if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE): - to_end_axis.append(i) - if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None) - return idx.replace(arg=None) - # in RANGEIFY=1, always realize - return idx.replace(src=(idx.src[0].realize(),)+idx.src[1:], arg=None) + to_end_axis = [] + for i,a in enumerate(idx.src[1:]): + # in RANGEIFY=1, always realize + if not (RANGEIFY > 1) or any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE): + to_end_axis.append(i) + if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None) + return idx.replace(arg=None) def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}")