mirror of https://github.com/tinygrad/tinygrad.git
movement ops aren't really ops (#1056)
@@ -53,7 +53,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
-    if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape)
+    if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
   ast = self.op.map_buffers(real_srcs)
   return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast

@@ -154,8 +154,7 @@ class LazyBuffer:
   # create a constant with the shape and dtype of self
   def const_like(self, val) -> LazyBuffer:
     # NOTE: dtypes.from_np(self.dtype.np) to deal with image types
-    return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
-      .reshape_op((1,)*len(self.shape)).expand_op(self.shape)
+    return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)

   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
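
Aside from the dropped _op suffix, const_like is unchanged: it loads a scalar CONST, reshapes it to a rank-matching (1,)*ndim shape, and expands it to self.shape. Below is a standalone NumPy sketch of that broadcast pattern, not part of this commit; the target shape and value are made up.

    import numpy as np

    # reshape((1,)*ndim) then expand(shape), with NumPy broadcasting standing in for EXPAND
    target_shape = (3, 4)            # hypothetical self.shape
    scalar = np.array(2.0)           # the CONST load, shape ()
    expanded = np.broadcast_to(scalar.reshape((1,) * len(target_shape)), target_shape)
    assert expanded.shape == target_shape and (expanded == 2.0).all()
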
@@ -178,7 +177,7 @@ class LazyBuffer:
       # MovementOps aren't stacked any more, they each have one parent, find the root
       root = get_movementroot(self)
       if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
-        return root.reshape_op(ret.st.shape)
+        return root.reshape(ret.st.shape)
     return ret

   def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
@@ -186,27 +185,27 @@ class LazyBuffer:
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
     return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)

-  def reshape_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.RESHAPE:
       self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
-      return self.op.src[0].reshape_op(arg)
+      return self.op.src[0].reshape(arg)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)

-  def pad_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all([b == 0 and e == 0 for b,e in arg]): return self
-    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad_op(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
+    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)

-  def expand_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def expand(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.EXPAND:
-      return self.op.src[0].expand_op(arg)
+      return self.op.src[0].expand(arg)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)

-  def permute_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if arg == tuple(range(len(self.shape))): return self
-    if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute_op(tuple([self.op.arg[i] for i in arg]))
+    if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
     if not self.realized:
       if PUSH_PERMUTES and self.optype == ReduceOps:
         # reduceops have one buffer input, permute it
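
The unrealized-parent branches in pad and permute above fold two stacked movement ops into one. Below is a standalone pure-Python check of those two folding rules, not part of this commit; the example arguments are made up.

    # two stacked PADs fold by adding the begin/end amounts per axis (the rule in pad above)
    def fold_pads(inner, outer): return tuple((b1 + b2, e1 + e2) for (b1, e1), (b2, e2) in zip(inner, outer))
    assert fold_pads(((1, 2), (0, 0)), ((3, 0), (1, 1))) == ((4, 2), (1, 1))

    # two stacked PERMUTEs fold by indexing the inner order with the outer one (the rule in permute above)
    def fold_permutes(inner, outer): return tuple(inner[i] for i in outer)
    assert fold_permutes((2, 0, 1), (1, 2, 0)) == (0, 1, 2)  # inverse orders cancel to the identity
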
@@ -214,29 +213,28 @@ class LazyBuffer:
         src, rop = self.op.src[0], self.op.op
         src.children.discard(self)
         del self # TODO: why doesn't this delete remove it from the children
-        return src.permute_op(arg).reduce_op(cast(ReduceOps, rop), narg)
+        return src.permute(arg).reduce_op(cast(ReduceOps, rop), narg)

       # move permutes before expands (always, this is safe)
       if self.op.op == MovementOps.EXPAND:
-        return self.op.src[0].permute_op(arg).expand_op(tuple([self.op.arg[a] for a in arg]))
+        return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))

       # move permutes before reshapes if we can
       if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
         if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
           self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
-          return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \
-            .reshape_op(ShapeTracker(self.st).permute(arg).shape)
+          return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)

-  def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
-    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
+    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)

-  def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     local_st = ShapeTracker(self.shape).stride(arg)
     if self.shape == local_st.shape and local_st.contiguous: return self
-    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg)))
+    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)

   @property
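
shrink and stride above fold stacked ops the same way. Below is a standalone check of those rules, not part of this commit; the windows and strides are made up.

    import operator

    # an outer SHRINK of an inner SHRINK re-bases the outer window on the inner window's start
    def fold_shrinks(inner, outer): return tuple((b1 + b2, b1 + e2) for (b1, _), (b2, e2) in zip(inner, outer))
    assert fold_shrinks(((2, 8),), ((1, 4),)) == ((3, 6),)  # rows [1, 4) of rows [2, 8) are rows [3, 6)

    # stacked STRIDEs multiply element-wise: every n-th of every m-th element is every (n*m)-th
    assert tuple(map(operator.mul, (3, 1), (2, -1))) == (6, -1)
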
@@ -353,10 +351,10 @@ REALIZE_DISPATCHER: Dict[Any, Callable] = {
 }

 MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
-  MovementOps.RESHAPE: LazyBuffer.reshape_op,
-  MovementOps.EXPAND: LazyBuffer.expand_op,
-  MovementOps.SHRINK: LazyBuffer.shrink_op,
-  MovementOps.PERMUTE: LazyBuffer.permute_op,
-  MovementOps.PAD: LazyBuffer.pad_op,
-  MovementOps.STRIDE: LazyBuffer.stride_op,
+  MovementOps.RESHAPE: LazyBuffer.reshape,
+  MovementOps.EXPAND: LazyBuffer.expand,
+  MovementOps.SHRINK: LazyBuffer.shrink,
+  MovementOps.PERMUTE: LazyBuffer.permute,
+  MovementOps.PAD: LazyBuffer.pad,
+  MovementOps.STRIDE: LazyBuffer.stride,
 }
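
MOVEMENT_OPS_DISPATCHER maps each MovementOps member to the corresponding unbound LazyBuffer method, so a lookup plus a call is the same as calling the method directly. A minimal standalone sketch of that pattern, not part of this commit; Ops and Buf here are stand-ins, not tinygrad classes.

    from enum import Enum, auto

    class Ops(Enum):
      RESHAPE = auto()
      PERMUTE = auto()

    class Buf:
      def reshape(self, arg): return ("reshape", arg)
      def permute(self, arg): return ("permute", arg)

    # unbound methods keyed by enum: DISPATCH[op](buf, arg) is the same call as buf.reshape(arg)
    DISPATCH = {Ops.RESHAPE: Buf.reshape, Ops.PERMUTE: Buf.permute}
    assert DISPATCH[Ops.RESHAPE](Buf(), (2, 3)) == Buf().reshape((2, 3))
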
@@ -63,8 +63,8 @@ class Sum(Function):
     self.input_shape = x.shape
     return x.reduce_op(ReduceOps.SUM, new_shape)

-  def backward(self, grad_output):
-    return grad_output.expand_op(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.expand(self.input_shape)

 class Max(Function):
   __slots__ = "x", "ret"
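
Sum.backward above expands the incoming gradient back to the input shape, since every input element contributes to the sum with weight 1. A standalone NumPy sketch, not part of this commit; shapes and values are made up.

    import numpy as np

    x = np.arange(6.0).reshape(2, 3)
    grad_output = np.array([[1.0], [2.0]])           # gradient w.r.t. x.sum(axis=1, keepdims=True)
    grad_x = np.broadcast_to(grad_output, x.shape)   # the EXPAND back to the input shape
    assert grad_x.shape == x.shape and np.allclose(grad_x, [[1, 1, 1], [2, 2, 2]])
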
@@ -74,13 +74,13 @@ class Max(Function):

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand_op(self.x.shape))
+    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand(self.x.shape))

     # sum of locations, averaged
-    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand_op(self.x.shape)
+    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
     max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

-    grad_output_expanded = grad_output.expand_op(self.x.shape)
+    grad_output_expanded = grad_output.expand(self.x.shape)
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)

 # ************* binary ops *************
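
Max.backward above marks every location that equals the max, then divides by the number of tied locations so the gradient is split evenly among ties. A standalone NumPy sketch, not part of this commit; the input values are made up.

    import numpy as np

    x = np.array([[1.0, 3.0, 3.0]])                  # the max 3.0 is chosen in two locations
    ret = x.max(axis=1, keepdims=True)
    max_is_1s = (x == ret).astype(np.float64)        # CMPEQ against the expanded max
    div = max_is_1s.sum(axis=1, keepdims=True)       # number of tied locations
    grad_output = np.array([[1.0]])
    grad_x = (max_is_1s / div) * grad_output
    assert np.allclose(grad_x, [[0.0, 0.5, 0.5]])    # each tied maximum gets half the gradient
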
@@ -96,7 +96,7 @@ class Maximum(Function):
     self.ret = x.binary_op(BinaryOps.MAX, y)
     return self.ret

-  def backward(self, grad_output):
+  def backward(self, grad_output:LazyBuffer):
     mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
     eq = self.x.binary_op(BinaryOps.CMPEQ, self.y)
     splitter = eq.const_like(2).binary_op(BinaryOps.SUB, eq).binary_op(BinaryOps.DIV, eq.const_like(2))
@@ -113,7 +113,7 @@ class Add(Function):
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     return x.binary_op(BinaryOps.SUB, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
@@ -122,7 +122,7 @@ class Sub(Function):

 class Mul(Function):
   __slots__ = 'x', 'y'
-  def forward(self, x:LazyBuffer, y:LazyBuffer):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
     return x.binary_op(BinaryOps.MUL, y)

@@ -132,7 +132,7 @@ class Mul(Function):

 class Pow(Function):
   __slots__ = 'x', 'y', 'ret'
-  def forward(self, x:LazyBuffer, y:LazyBuffer):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
     return self.ret

@@ -157,7 +157,7 @@ class Expand(Function):
   __slots__ = 'input_shape'
   def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.expand_op(shape)
+    return x.expand(shape)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
@@ -166,43 +166,43 @@ class Reshape(Function):
   __slots__ = 'input_shape'
   def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.reshape_op(shape)
+    return x.reshape(shape)

-  def backward(self, grad_output):
-    return grad_output.reshape_op(self.input_shape)
+  def backward(self, grad_output:LazyBuffer):
+    return grad_output.reshape(self.input_shape)

 class Permute(Function):
   __slots__ = 'input_order'
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
-    return x.permute_op(order)
+    return x.permute(order)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.permute_op(argsort(self.input_order))
+    return grad_output.permute(argsort(self.input_order))

 class Pad(Function):
   __slots__ = 'narg'
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
-    return x.pad_op(arg)
+    return x.pad(arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.shrink_op(self.narg)
+    return grad_output.shrink(self.narg)

 class Shrink(Function):
   __slots__ = 'narg'
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
-    return x.shrink_op(arg)
+    return x.shrink(arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad_op(self.narg)
+    return grad_output.pad(self.narg)

 class Flip(Function):
   __slots__ = 'arg'
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
-    return x.stride_op(self.arg)
+    return x.stride(self.arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride_op(self.arg)
+    return grad_output.stride(self.arg)
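
Two of the backward rules above are easy to sanity-check in isolation: Permute.backward permutes by argsort(order), the inverse permutation, and Pad.backward shrinks by narg = (begin, begin + size), which recovers the original extent of each padded axis. A standalone pure-Python check, not part of this commit; argsort here is a local stand-in for the helper the code imports, and the shapes are made up.

    def argsort(seq): return tuple(sorted(range(len(seq)), key=seq.__getitem__))

    order = (2, 0, 1)
    assert tuple(order[i] for i in argsort(order)) == (0, 1, 2)   # applying order then its argsort is the identity

    shape, pad_arg = (5,), ((2, 3),)
    narg = tuple((p[0], s + p[0]) for s, p in zip(shape, pad_arg))
    assert narg == ((2, 7),)   # in the padded length 2 + 5 + 3 = 10, elements [2, 7) are the original 5
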
@@ -65,12 +65,14 @@ class LazyOp:
   @property
   def optype(self): raise NotImplementedError
   def realize(self): raise NotImplementedError
-  def reshape_op(self, _): raise NotImplementedError
-  def pad_op(self, _): raise NotImplementedError
-  def expand_op(self, _): raise NotImplementedError
-  def permute_op(self, _): raise NotImplementedError
-  def shrink_op(self, _): raise NotImplementedError
-  def stride_op(self, _): raise NotImplementedError
+
+  # movement ops
+  def reshape(self, _): raise NotImplementedError
+  def pad(self, _): raise NotImplementedError
+  def expand(self, _): raise NotImplementedError
+  def permute(self, _): raise NotImplementedError
+  def shrink(self, _): raise NotImplementedError
+  def stride(self, _): raise NotImplementedError

 # **************** for Interpreted Buffers ****************
