lazy cleanups
@@ -133,9 +133,7 @@ def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBu

def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
  real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:None for x in get_buffers(self.op)}
  op_type : OpType = BinaryOps
  # not for ProcessingOps
  psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and len(x.children) <= 1 and len(k.children) <= 1]
  # TODO: this is broken, the reshape just shouldn't be pushed, not hacked out later
  if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE and (self.device != "OPENCL" or self.shape[-1] == 4):
    if psrcs[0][1].optype == ProcessingOps:
      real_srcs[psrcs[0][0]] = psrcs[0][1].op
@@ -149,7 +147,7 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer

      real_srcs[x] = None
      real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
      op_type = ReduceOps
    # reinject the reshape
    # a reshape is allowed after the ReduceOp, it's a nop in the backend
    if psrcs[0][0].shape != psrcs[0][1].shape:
      real_srcs[psrcs[0][0]] = LazyOp(MovementOps.RESHAPE, (real_srcs[psrcs[0][0]],), psrcs[0][0].shape)
  for x in real_srcs.keys():
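For context, the fusion that MERGE_ONE_REDUCE_INTO_ELEMENTWISE enables computes the reduce and the following elementwise op in a single pass, so the intermediate reduce buffer is never materialized; the reinjected RESHAPE only relabels the output shape, which is why it is a nop in the backend. A minimal NumPy sketch of the idea (illustrative only, not tinygrad code):

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
bias = np.ones((3, 1), dtype=np.float32)

# Unfused: realize the ReduceOps.SUM into its own buffer, then do BinaryOps.ADD.
unfused = x.sum(axis=1, keepdims=True) + bias

# Fused: one loop nest computes the partial sums and applies the elementwise
# op before writing out, so the (3, 1) intermediate never hits memory.
def fused_sum_add(x, bias):
  out = np.empty((x.shape[0], 1), dtype=x.dtype)
  for i in range(x.shape[0]):
    acc = np.float32(0)
    for j in range(x.shape[1]):   # the reduce loop
      acc += x[i, j]
    out[i, 0] = acc + bias[i, 0]  # the elementwise op, applied in-register
  return out

assert np.allclose(unfused, fused_sum_add(x, bias))
```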
@@ -247,20 +245,22 @@ class LazyBuffer:

    if local_st.contiguous and self.shape == local_st.shape:
      return self

    # two ops in a row is one op
    if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK] and self.realized is None and self.op.op == op:
      return self.op.src[0].movement_op(op, arg)
    if op == MovementOps.PERMUTE and self.realized is None and self.op.op == op:
      return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
    if op == MovementOps.PAD and self.realized is None and self.op.op == op:
      return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
    # two ops in a row is one op. merge them if unresolved
    if self.realized is None and self.op.op == op:
      if op in [MovementOps.RESHAPE, MovementOps.EXPAND, MovementOps.SHRINK]:
        return self.op.src[0].movement_op(op, arg)
      if op == MovementOps.PERMUTE:
        return self.op.src[0].movement_op(op, tuple(self.op.arg[i] for i in arg))
      if op == MovementOps.PAD:
        return self.op.src[0].movement_op(op, tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)))
      # TODO: MovementOps.FLIP / MovementOps.STRIDED?

    # some permutes are actually just reshapes
    if op == MovementOps.PERMUTE and local_st.contiguous:
      return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))

    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
      def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer:
        if isinstance(y, LazyBuffer):
          return y.movement_op(op, arg)
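Each merge rule above is just composition on the movement arguments: consecutive PERMUTEs compose by indexing the first permutation with the second, and consecutive PADs add their (begin, end) pairs elementwise. A NumPy check of both identities (illustrative, not tinygrad code):

```python
import numpy as np

x = np.random.rand(2, 3, 4)

# permute(permute(x, p1), p2) == permute(x, p1 indexed by p2),
# mirroring tuple(self.op.arg[i] for i in arg) above.
p1, p2 = (2, 0, 1), (2, 1, 0)
composed = tuple(p1[i] for i in p2)
assert np.array_equal(x.transpose(p1).transpose(p2), x.transpose(composed))

# pad(pad(x, a1), a2) == pad(x, elementwise sum of the (begin, end) pairs),
# mirroring tuple((b1+b2, e1+e2) for ...) above.
a1, a2 = ((1, 0), (0, 2)), ((0, 1), (3, 0))
merged = tuple((b1 + b2, e1 + e2) for (b1, e1), (b2, e2) in zip(a1, a2))
assert np.array_equal(np.pad(np.pad(x[0], a1), a2), np.pad(x[0], merged))
```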
@@ -271,6 +271,7 @@ class LazyBuffer:

    # create the buffer
    ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg))

    # if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match)
    # NOTE: if ret is in the cache, it can already be realized
    if REMOVE_MOVEMENT_NOPS and ret.realized is None and self.realized is None and ret.st.contiguous:
      # MovementOps aren't stacked any more, they each have one parent, find the root
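The contiguity test works because a contiguous ShapeTracker reads its buffer in plain row-major order, so the whole chain of MovementOps collapses to nothing, or to a lone RESHAPE if only the shape labels differ. A toy stride-based illustration (is_contiguous is a hypothetical helper, not tinygrad's):

```python
import numpy as np

def is_contiguous(shape, strides):
  # row-major contiguity: strides are the reversed suffix products of shape
  expected, acc = [], 1
  for s in reversed(shape):
    expected.append(acc)
    acc *= s
  return list(strides) == list(reversed(expected))

x = np.arange(24).reshape(2, 3, 4)
y = x.transpose(1, 0, 2)   # PERMUTE: view is no longer contiguous
z = y.transpose(1, 0, 2)   # PERMUTE back: contiguous again, same layout as x
elem = x.itemsize
assert not is_contiguous(y.shape, [s // elem for s in y.strides])
assert is_contiguous(z.shape, [s // elem for s in z.strides])
# z reads x's memory in order, so the PERMUTE pair can be removed entirely
# (the REMOVE_MOVEMENT_NOPS case); shapes match here, so not even a RESHAPE remains.
assert np.shares_memory(x, z) and z.shape == x.shape
```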
@@ -295,7 +296,6 @@ class LazyBuffer:

      (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
    w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
         .movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
    #print(x.st.views, w.st.views)
    return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
            .movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))
  elif x.device == "OPENCL":
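This hunk expresses convolution entirely as movement ops plus a MUL and a SUM: both operands are expanded to a common (bs, groups, rcout, oy, ox, cin, H, W) shape, multiplied elementwise, and reduced over the window axes. A NumPy equivalent for the stride-1, unpadded, groups=1 case (conv2d here is a sketch, not tinygrad's API):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d(x, w):
  # x: (bs, cin, iy, ix), w: (cout, cin, H, W)
  H, W = w.shape[2:]
  # movement ops only: a (bs, cin, oy, ox, H, W) window view of x
  windows = sliding_window_view(x, (H, W), axis=(2, 3))
  # broadcast to (bs, cout, cin, oy, ox, H, W): the MUL on expanded views
  prod = windows[:, None] * w[None, :, :, None, None, :, :]
  # the SUM reduce over cin, H, W, leaving (bs, cout, oy, ox)
  return prod.sum(axis=(2, 5, 6))

x = np.random.rand(1, 3, 5, 5).astype(np.float32)
w = np.random.rand(4, 3, 3, 3).astype(np.float32)
assert conv2d(x, w).shape == (1, 4, 3, 3)  # (bs, cout, oy, ox)
```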