mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
that
This commit is contained in:
@@ -308,11 +308,11 @@ def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
|
||||
def might_end_axis(idx:UOp):
|
||||
if idx.arg is None: return None
|
||||
# TODO: write a proper cost function here
|
||||
if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE, Ops.REDUCE_AXIS} for x in idx.sparents): return None
|
||||
if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None
|
||||
if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
|
||||
to_end_axis = []
|
||||
for i,a in enumerate(idx.src[1:]):
|
||||
# in RANGEIFY=1, always realize
|
||||
if not (RANGEIFY > 1) or any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
|
||||
if any(x.arg > idx.arg for x in a.toposort() if x.op is Ops.RANGE):
|
||||
to_end_axis.append(i)
|
||||
if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
|
||||
return idx.replace(arg=None)
|
||||
|
||||
Reference in New Issue
Block a user