From 6ca509a485686459bc9e42dbad5dbe68d1801edf Mon Sep 17 00:00:00 2001
From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com>
Date: Sun, 27 Aug 2023 17:13:16 +0200
Subject: [PATCH] perf: constant in while in for in busy func (#1688)

Co-authored-by: Roelof van Dijk
---
 tinygrad/lazy.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index bd66dc7a91..8d1c0f784f 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -109,6 +109,8 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
   lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
   return ret
 
+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
+
 class LazyBuffer:
   __deletable__ = ('op',)
   def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
@@ -310,13 +312,12 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
     mops: List[Tuple[MovementOps, Any]] = []
     bx = x
     # backwalk all the movement ops. don't push PAD or EXPAND
-    while not bx.realized and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op != MovementOps.PAD) and len(bx.children) <= 1:
+    while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
       assert isinstance(bx.op.op, MovementOps)
       mops.append((bx.op.op, bx.op.arg))
       bx = cast(LazyBuffer, bx.op.src[0])
     # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
-    unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
-    if mops and not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())):
+    if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(x[0] is not MovementOps.PAD for x in mops) or all(x.op not in UNSAFE_PAD_OPS for x in bx.op.get_lazyops())):
       new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
     else:
       new_srcs.append(x)