mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
perf: constant in while in for in busy func (#1688)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
@@ -109,6 +109,8 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
|
||||
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype)
|
||||
return ret
|
||||
|
||||
+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
|
||||
|
||||
class LazyBuffer:
|
||||
__deletable__ = ('op',)
|
||||
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None):
|
||||
@@ -310,13 +312,12 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
||||
mops: List[Tuple[MovementOps, Any]] = []
|
||||
bx = x
|
||||
# backwalk all the movement ops. don't push PAD or EXPAND
|
||||
-while not bx.realized and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op != MovementOps.PAD) and len(bx.children) <= 1:
|
||||
+while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
|
||||
assert isinstance(bx.op.op, MovementOps)
|
||||
mops.append((bx.op.op, bx.op.arg))
|
||||
bx = cast(LazyBuffer, bx.op.src[0])
|
||||
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
|
||||
-unsafe_pad_ops = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
|
||||
-if mops and not bx.realized and bx.optype == BinaryOps and len(bx.children) <= 1 and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op not in unsafe_pad_ops for x in bx.op.get_lazyops())):
|
||||
+if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(x[0] is not MovementOps.PAD for x in mops) or all(x.op not in UNSAFE_PAD_OPS for x in bx.op.get_lazyops())):
|
||||
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
|
||||
else:
|
||||
new_srcs.append(x)
|
||||
|
||||
Reference in New Issue
Block a user