diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index a48f24fc8c..fef10c155a 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -133,9 +133,7 @@ def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBu
 def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
   real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:None for x in get_buffers(self.op)}
   op_type : OpType = BinaryOps
-  # not for ProcessingOps
   psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
-  # TODO: this is broken, the reshape just shouldn't be pushed, not hacked out later
   if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
     if psrcs[0][1].optype == ProcessingOps:
       real_srcs[psrcs[0][0]] = psrcs[0][1].op
@@ -149,7 +147,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
         real_srcs[x] = None
       real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
       op_type = ReduceOps
-      # reinject the reshape
+      # a reshape is allowed after the ReduceOp, it's a nop in the backend
       if psrcs[0][0].shape != psrcs[0][1].shape:
         real_srcs[psrcs[0][0]] = LazyOp(MovementOps.RESHAPE, (real_srcs[psrcs[0][0]],), psrcs[0][0].shape)
   for x in real_srcs.keys():
@@ -247,20 +245,22 @@ class LazyBuffer:
     if local_st.contiguous and self.shape == local_st.shape: return self

-    # two ops in a row is one op
-    if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and self.realized is None and self.op.op == op:
-      return self.op.src[0].movement_op(op, arg)
-    if op == MovementOps.PERMUTE and self.realized is None and self.op.op == op:
-      return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
-    if op == MovementOps.PAD and self.realized is None and self.op.op == op:
-      return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
+    # two ops in a row is one op. merge them if unresolved
+    if self.realized is None and self.op.op == op:
+      if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]:
+        return self.op.src[0].movement_op(op, arg)
+      if op == MovementOps.PERMUTE:
+        return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
+      if op == MovementOps.PAD:
+        return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
+      # TODO: MovementOps.FLIP / MovementOps.STRIDED?

     # some permutes are actually just reshapes
     if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))

+    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
-      # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
       def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
         if isinstance(y, LazyBuffer): return y.movement_op(op, arg)
@@ -271,6 +271,7 @@ class LazyBuffer:
     # create the buffer
     ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg))

+    # if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
     # NOTE: if ret is in the cache, it can already be realized
     if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
       # MovementOps aren't stacked any more, they each have one parent, find the root
@@ -295,7 +296,6 @@ class LazyBuffer:
         (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
       w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
            .movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
-      #print(x.st.views, w.st.views)
       return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
              .movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
     elif x.device == "OPENCL":
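For reference, the merged movement ops in movement_op compose their args rather than stacking: two back-to-back PERMUTEs compose by indexing the earlier arg with the later one, and two back-to-back PADs add their (begin, end) amounts per axis. Below is a minimal standalone sketch of that arithmetic on plain tuples; merge_permute and merge_pad are hypothetical helper names for illustration, not tinygrad API.

from typing import Tuple

def merge_permute(first: Tuple[int, ...], second: Tuple[int, ...]) -> Tuple[int, ...]:
  # permute(permute(x, first), second) == permute(x, merged); same rule as tuple(self.op.arg[i] for i in arg)
  return tuple(first[i] for i in second)

def merge_pad(first: Tuple[Tuple[int, int], ...], second: Tuple[Tuple[int, int], ...]) -> Tuple[Tuple[int, int], ...]:
  # pad(pad(x, first), second) == pad(x, merged); the per-axis (begin, end) amounts simply add
  return tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(first, second))

assert merge_permute((2,0,1), (1,2,0)) == (0,1,2)                # the second permute undoes the first
assert merge_pad(((1,1),(0,2)), ((0,0),(3,0))) == ((1,1),(3,2))  # padding accumulates per axis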