mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
lazy cleanup
This commit is contained in:
@@ -2,11 +2,10 @@ from __future__ import annotations
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, Type
|
||||
import os, sys, weakref, importlib, inspect
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.helpers import prod, getenv
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers
|
||||
from tinygrad.graph import log_op
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
# lazy can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
@@ -46,17 +45,10 @@ def _ast_reduceops(self:LazyBuffer) -> LazyOp:
|
||||
# this supports late merging an upstream Reduce op and even an Elementwise op above that
|
||||
def _ast_binaryops(self:LazyBuffer) -> LazyOp:
|
||||
real_srcs : Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
|
||||
if DEBUG >= 4:
|
||||
for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
|
||||
if x.optype == ReduceOps and x.realized is None:
|
||||
print("\nHIT", k,x, "UNFOLDABLE" if len(k.children) > 1 or len(x.children) > 1 else str())
|
||||
for tk in k.children: print("k", tk)
|
||||
for tx in x.children: print("x", tx)
|
||||
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
|
||||
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
|
||||
intermediate_shape : Tuple[int, ...] = self.shape
|
||||
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
|
||||
if DEBUG >= 4: print("FOLDING", psrcs[0])
|
||||
if psrcs[0][1].optype == ReduceOps:
|
||||
top = _ast_reduceops(psrcs[0][1])
|
||||
real_srcs[psrcs[0][0]] = top
|
||||
@@ -193,14 +185,10 @@ class LazyBuffer:
|
||||
if self.realized is None and self.op.op == op:
|
||||
# TODO: why is deleting self from children needed? shouldn't GC do it?
|
||||
self.op.src[0].children.discard(self)
|
||||
if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]:
|
||||
return self.op.src[0].movement_op(op, arg)
|
||||
if op == MovementOps.PERMUTE:
|
||||
return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
|
||||
if op == MovementOps.PAD:
|
||||
return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
|
||||
if op == MovementOps.FLIP:
|
||||
return self.op.src[0].movement_op(op, tuple(i for i in arg+self.op.arg if not (i in arg and i in self.op.arg)))
|
||||
if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]: return self.op.src[0].movement_op(op, arg)
|
||||
if op == MovementOps.PERMUTE: return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
|
||||
if op == MovementOps.PAD: return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
|
||||
if op == MovementOps.FLIP: return self.op.src[0].movement_op(op, tuple(i for i in arg+self.op.arg if not (i in arg and i in self.op.arg)))
|
||||
|
||||
# push permutes before reduce ops
|
||||
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.optype == ReduceOps:
|
||||
@@ -212,8 +200,7 @@ class LazyBuffer:
|
||||
return src.movement_op(op, arg).reduce_op(rop, narg)
|
||||
|
||||
# some permutes are actually just reshapes
|
||||
if op == MovementOps.PERMUTE and local_st.contiguous:
|
||||
return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
|
||||
if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
|
||||
|
||||
# move permutes before expands
|
||||
if op == MovementOps.PERMUTE and PUSH_PERMUTES and self.realized is None and self.op.op == MovementOps.EXPAND:
|
||||
|
||||
Reference in New Issue
Block a user