perf: hoist constant out of while-in-for loop in busy func (#1688)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk
2023-08-27 17:13:16 +02:00
committed by GitHub
parent b89d81330f
commit 6ca509a485

View File

@@ -109,6 +109,8 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
return ret
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
@@ -310,13 +312,12 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
mops: List[Tuple[MovementOps, Any]] = []
bx = x
# backwalk all the movement ops. don't push EXPAND, and only push PAD when SHUFFLE_PAD_OPS is set
while not bx.realized and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op != MovementOps.PAD) and len(bx.children) <= 1:
while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
bx = cast(LazyBuffer, bx.op.src[0])
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
if mops and not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())):
if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(x[0] is not MovementOps.PAD for x in mops) or all(x.op not in UNSAFE_PAD_OPS for x in bx.op.get_lazyops())):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
else:
new_srcs.append(x)