lazy cleanups

George Hotz
2022-10-28 08:43:43 -07:00
parent d02f8f9bc0
commit 8517b69bfb

@@ -133,9 +133,7 @@ def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBu
def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:None for x in get_buffers(self.op)}
op_type : OpType = BinaryOps
# not for ProcessingOps
psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
# TODO: this is broken, the reshape just shouldn't be pushed, not hacked out later
if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
if psrcs[0][1].optype == ProcessingOps:
real_srcs[psrcs[0][0]] = psrcs[0][1].op
@@ -149,7 +147,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
real_srcs[x] = None
real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
op_type = ReduceOps
# reinject the reshape
# a reshape is allowed after the ReduceOp, it's a nop in the backend
if psrcs[0][0].shape != psrcs[0][1].shape:
real_srcs[psrcs[0][0]] = LazyOp(MovementOps.RESHAPE, (real_srcs[psrcs[0][0]],), psrcs[0][0].shape)
for x in real_srcs.keys():
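
To see why the reinjected RESHAPE after the ReduceOp is harmless, here is a minimal sketch using numpy rather than tinygrad's DeviceBuffer: the reduce already produces the elements contiguously and in order, so the trailing reshape only reinterprets the shape.

# minimal numpy sketch (not tinygrad) of "a reshape is allowed after the ReduceOp, it's a nop in the backend"
import numpy as np

x = np.arange(2*3*4, dtype=np.float32).reshape(2, 3, 4)

# ReduceOps.SUM keeping the reduced axis -> shape (2, 3, 1)
reduced = x.sum(axis=2, keepdims=True)

# the reinjected MovementOps.RESHAPE to the shape the consumer expects -> (2, 3)
reshaped = reduced.reshape(2, 3)

# same contents, same contiguous layout, same memory: a nop for the backend
assert reshaped.flags["C_CONTIGUOUS"]
assert np.shares_memory(reduced, reshaped)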
@@ -247,20 +245,22 @@ class LazyBuffer:
if local_st.contiguous and self.shape == local_st.shape:
return self
# two ops in a row is one op
if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and self.realized is None and self.op.op == op:
return self.op.src[0].movement_op(op, arg)
if op == MovementOps.PERMUTE and self.realized is None and self.op.op == op:
return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
if op == MovementOps.PAD and self.realized is None and self.op.op == op:
return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
# two ops in a row is one op. merge them if unresolved
if self.realized is None and self.op.op == op:
if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]:
return self.op.src[0].movement_op(op, arg)
if op == MovementOps.PERMUTE:
return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
if op == MovementOps.PAD:
return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
# TODO: MovementOps.FLIP / MovementOps.STRIDED?
# some permutes are actually just reshapes
if op == MovementOps.PERMUTE and local_st.contiguous:
return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
if isinstance(y, LazyBuffer):
return y.movement_op(op, arg)
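
The merge rules above just compose the arguments of two identical movement ops in a row. A standalone sketch of that composition (plain Python with hypothetical helper names, not the LazyBuffer API):

def merge_permute(first_arg, second_arg):
  # two PERMUTEs in a row: index the first arg by the second,
  # exactly like `tuple(self.op.arg[i] for i in arg)` in the diff
  return tuple(first_arg[i] for i in second_arg)

def merge_pad(first_arg, second_arg):
  # two PADs in a row: the per-axis (begin, end) amounts simply add up
  return tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(first_arg, second_arg))

assert merge_permute((2,0,1), (1,2,0)) == (0,1,2)                # the second permute undoes the first
assert merge_pad(((1,1),(0,2)), ((0,1),(3,0))) == ((1,2),(3,2))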
@@ -271,6 +271,7 @@ class LazyBuffer:
# create the buffer
ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg))
# if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
# NOTE: if ret is in the cache, it can already be realized
if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
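
A minimal numpy sketch (not the real ShapeTracker) of the REMOVE_MOVEMENT_NOPS path: when a chain of movement ops lands back on a contiguous view of the root buffer, the whole chain collapses to at most a reshape of that root.

import numpy as np

root = np.arange(6, dtype=np.float32).reshape(2, 3)

# PERMUTE (1,0) followed by PERMUTE (1,0): the result is contiguous again
chained = root.transpose(1, 0).transpose(1, 0)
assert chained.flags["C_CONTIGUOUS"]

# so the whole chain can be replaced by a single reshape of the root buffer (or nothing)
collapsed = root.reshape(chained.shape)
assert np.array_equal(chained, collapsed) and np.shares_memory(root, collapsed)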
@@ -295,7 +296,6 @@ class LazyBuffer:
(C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
#print(x.st.views, w.st.views)
return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
.movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
elif x.device == "OPENCL":
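
The image path above builds the convolution out of RESHAPE/EXPAND movement ops, a BinaryOps.MUL, a ReduceOps.SUM and a final RESHAPE. A hedged numpy sketch of the same pattern, checked against a naive loop; groups=1, stride=1 and no dilation are assumed, and stride tricks stand in for the movement ops:

import numpy as np

bs, cin, cout, H, W = 2, 3, 4, 3, 3
iy, ix = 6, 6
oy, ox = iy - H + 1, ix - W + 1

x = np.random.randn(bs, cin, iy, ix).astype(np.float32)
w = np.random.randn(cout, cin, H, W).astype(np.float32)

# build the (bs, cout, oy, ox, cin, H, W) "unfolded" view of x with stride tricks,
# playing the role of the MovementOps applied to x in the diff
sbs, scin, sy, sx = x.strides
xu = np.lib.stride_tricks.as_strided(x, (bs, oy, ox, cin, H, W), (sbs, sy, sx, scin, sy, sx))
xu = np.broadcast_to(xu[:, None], (bs, cout, oy, ox, cin, H, W))

# RESHAPE + EXPAND on w, then the elementwise MUL and the SUM reduce
wu = np.broadcast_to(w[None, :, None, None], (bs, cout, oy, ox, cin, H, W))
out = (xu * wu).sum(axis=(4, 5, 6))            # -> (bs, cout, oy, ox)

# reference: naive convolution loops
ref = np.zeros((bs, cout, oy, ox), dtype=np.float32)
for b in range(bs):
  for c in range(cout):
    for yy in range(oy):
      for xx in range(ox):
        ref[b, c, yy, xx] = (x[b, :, yy:yy+H, xx:xx+W] * w[c]).sum()

assert np.allclose(out, ref, atol=1e-4)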