lazy cleanup (#2567)

Author: chenyu
Date: 2023-12-02 13:21:43 -05:00
Committed by: GitHub
Parent: 5068e99d18
Commit: c8774713c5
2 changed files with 23 additions and 28 deletions


@@ -365,6 +365,7 @@ class TestOpt(unittest.TestCase):
     a.realize()
     cache_len = len(CacheCollector.cache)
     np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-3, atol=1e-5)
+    # TODO: what do these `return cache_len` lines do?
     return cache_len
   def test_expand_reduce_is_folded_on_same_axis(self):
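Aside: a hedged reading of this hunk, not stated in the commit: the helper realizes a tensor while CacheCollector records the generated kernels, then returns the cache length so callers can assert kernel counts. A standalone sketch of that idea, using a hypothetical recorder rather than the real CacheCollector API:

# minimal sketch; kernel_log and count_kernels are hypothetical stand-ins
kernel_log: list = []

def count_kernels(fn) -> int:
  kernel_log.clear()
  fn()                    # a realize() here would append one entry per kernel
  return len(kernel_log)  # analogous to the `return cache_len` above

def two_kernel_workload(): kernel_log.extend(["k0", "k1"])

assert count_kernels(two_kernel_workload) == 2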


@@ -37,24 +37,24 @@ def _ast_binaryops(op:LazyOp, shape:Tuple[sint, ...]) -> LazyOp:
   real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in op.buffers}
   # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
   # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
-  psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
-  intermediate_shape: Tuple[sint, ...] = shape
+  psrcs = [(buf,root) for buf in op.buffers if len(buf.children) <= 1 and (root:=get_movementroot_contiguous(buf)).optype == ReduceOps and not root.realized and prod(root.shape) == prod(buf.shape) and len(root.children) <= 1]
+  intermediate_shape = shape
   if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
-    psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
-    if psrc[1].optype == ReduceOps:
-      top = _ast_reduceops(psrc[1].op)
-    real_srcs[psrc[0]] = top
+    # NOTE: right now we can't handle multiple, as we'd have to check for loop
+    buf,root = psrcs[0]
+    top = _ast_reduceops(root.op)
+    real_srcs[buf] = top
     real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified
     # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
-    if psrc[0].shape != psrc[1].shape:
-      intermediate_shape = psrc[1].shape
-      assert psrc[0].shape == shape, f"shape mismatch {psrc[0].shape} != {shape}"
+    if buf.shape != root.shape:
+      intermediate_shape = root.shape
+      assert buf.shape == shape, f"shape mismatch {buf.shape} != {shape}"
   # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
-  for x in real_srcs.keys():
-    if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
+  for buf,src in real_srcs.items():
+    if src is None: real_srcs[buf] = buf.reshape(intermediate_shape)
   # NOTE: cast the type to remove the Optional
   ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
   return LazyOp(MovementOps.RESHAPE, (ast, ), shape) if intermediate_shape != shape else ast
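Aside: the new psrcs comprehension leans on the walrus operator to bind root once per buffer and reuse it in the remaining filter clauses, replacing the old zip(keys, map(...)) construction. A toy sketch of the pattern, with half() as a hypothetical stand-in for get_movementroot_contiguous:

def half(x: int) -> int: return x // 2

bufs = [4, 6, 8]
pairs = [(buf, root) for buf in bufs if (root := half(buf)) % 2 == 0]  # root bound once, reused
assert pairs == [(4, 2), (8, 4)]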
@@ -102,16 +102,14 @@ UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2,
 class LazyBuffer:
   __deletable__ = ('op',)
   def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[Buffer]=None, base:Optional[LazyBuffer]=None):
-    self.st: ShapeTracker = st
-    self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
-    self._realized: Optional[Buffer] = src
+    self.device, self.st, self.shape, self.optype, self._dtype, self._realized = device, st, st.shape, optype, dtype, src
     self.output_buffer: Optional[Buffer] = None # TODO: do we really need this? or can we just use realized
     # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
     self.children: WeakSet[LazyBuffer] = WeakSet()
     self.views: WeakSet[LazyBuffer] = WeakSet()
     # NOTE: op should be read only after construction of LazyBuffer. it is now with schedule
     if op is not None:
-      self.op: LazyOp = op
+      self.op = op
       for x in op.buffers: x.children.add(self)
     assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
     self._base = base
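Aside: the children and views sets above are WeakSets, and the lazy cleanup relies on their defining behavior: an entry vanishes once its referent is garbage collected, so dead consumers never pin a parent's children set. A minimal self-contained sketch:

import gc
from weakref import WeakSet

class Child: pass

children: WeakSet = WeakSet()
c = Child()
children.add(c)
assert len(children) == 1
del c
gc.collect()  # ensure collection on non-refcounting implementations
assert len(children) == 0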
@@ -127,13 +125,13 @@ class LazyBuffer:
   @property
   def realized(self): return self.base._realized
   @realized.setter
-  def realized(self, val):
+  def realized(self, val:Buffer):
     assert self._base is None, "no setting realized of based LazyBuffers"
     self._realized = val
   @property
   def dtype(self): return self.base._dtype
   @dtype.setter
-  def dtype(self, val):
+  def dtype(self, val:DType):
     assert self._base is None, "no setting dtype of based LazyBuffers"
     self._dtype = val
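Aside: the getters above always read through self.base while the setters assert the buffer is not based, so a view can never diverge from its root. A minimal sketch of that delegation pattern with a hypothetical Value class:

from typing import Optional

class Value:
  def __init__(self, data: int, base: Optional["Value"] = None): self._data, self._base = data, base
  @property
  def base(self) -> "Value": return self._base if self._base is not None else self
  @property
  def data(self) -> int: return self.base._data
  @data.setter
  def data(self, val: int):
    assert self._base is None, "no setting data of based Values"
    self._data = val

root = Value(1)
view = Value(0, base=root)
root.data = 42
assert view.data == 42  # the view reads through its base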
@@ -241,15 +239,14 @@ class LazyBuffer:
     # get outputs now
     out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
-    # push all contiguous to the end of BinaryOps. kernels 198 -> 196
+    # push all contiguous to the end of BinaryOps
     if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
       new_srcs: List[LazyBuffer] = []
       for x in srcs:
         if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
           x.op.src[0].children.discard(x)
-          new_srcs.append(cast(LazyBuffer, x.op.src[0]))
-        else:
-          new_srcs.append(x)
+          x = cast(LazyBuffer, x.op.src[0])
+        new_srcs.append(x)
       return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous()
     if MERGE_ELEMENTWISE_OPS:
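Aside: the PUSH_CONTIGUOUS rewrite above swaps an if/else with duplicate appends for a conditional rebind of the loop variable and a single append. A toy sketch of the same restructuring:

def unwrap_evens(xs):
  out = []
  for x in xs:
    if x % 2 == 0:
      x = x // 2  # rebind, analogous to unwrapping x.op.src[0]
    out.append(x)
  return out

assert unwrap_evens([2, 3, 8]) == [1, 3, 4]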
@@ -283,8 +280,7 @@ class LazyBuffer:
       return self.op.replace_with_movement_ops([(op, arg)])
     if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
       # MovementOps aren't stacked any more, they each have one parent, find the root
-      root = get_movementroot(self)
-      if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
+      if (root:=get_movementroot(self)) != self and root.st.contiguous and prod(st.shape) == prod(root.shape):
         return root.reshape(st.shape)
     return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
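Aside: besides inlining root with a walrus, the new condition puts the cheap root != self identity check before the contiguity and size checks, so those are short-circuited for self-rooted buffers. A toy illustration (find_root and contiguous are hypothetical stand-ins):

checks = []
def contiguous(v) -> bool:
  checks.append(v)
  return True

def find_root(v): return v

x = 3
if (root := find_root(x)) != x and contiguous(root):
  raise AssertionError("unreachable: root is x here")
assert checks == []  # short-circuit: contiguous() never ran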
@@ -351,15 +347,13 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
     bx = x
     # backwalk all the movement ops. don't push PAD or EXPAND
     while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
-      assert isinstance(bx.op.op, MovementOps)
+      assert isinstance(bx.op.op, MovementOps) and isinstance(bx.op.src[0], LazyBuffer)
       mops.append((bx.op.op, bx.op.arg))
-      assert isinstance(bx.op.src[0], LazyBuffer)
       bx = bx.op.src[0]
     # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
     if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(y[0] is not MovementOps.PAD for y in mops) or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops())):
-      new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
-    else:
-      new_srcs.append(x)
+      x = bx.op.replace_with_movement_ops(mops[::-1])
+    new_srcs.append(x)
   return tuple(new_srcs)
 MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
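Aside: a worked example of the pad-safety NOTE in the last hunk (pushing a PAD past any op with f(0) != 0 is unsafe, which is what UNSAFE_PAD_OPS encodes). PAD inserts zeros and exp2 maps 0 to 1, so the two orders disagree:

import numpy as np

x = np.array([1.0, 2.0])
pad_then_exp2 = np.exp2(np.pad(x, (1, 1)))  # [1., 2., 4., 1.]
exp2_then_pad = np.pad(np.exp2(x), (1, 1))  # [0., 2., 4., 0.]
assert not np.array_equal(pad_then_exp2, exp2_then_pad)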