movement ops aren't really ops (#1056)

This commit is contained in:
George Hotz
2023-06-26 15:01:28 -07:00
committed by GitHub
parent 65cbaa3429
commit f265e8523a
3 changed files with 53 additions and 53 deletions

View File

@@ -53,7 +53,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
- if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape)
+ if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
ast = self.op.map_buffers(real_srcs)
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
@@ -154,8 +154,7 @@ class LazyBuffer:
# create a constant with the shape and dtype of self
def const_like(self, val) -> LazyBuffer:
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
- return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
-   .reshape_op((1,)*len(self.shape)).expand_op(self.shape)
+ return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
# NOTE: we also have to copy the numpy array on the way out... otherwise the underlying Tensor could be freed, causing a use-after-free. improve this?
def toCPU(self):
@@ -178,7 +177,7 @@ class LazyBuffer:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
- return root.reshape_op(ret.st.shape)
+ return root.reshape(ret.st.shape)
return ret
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
@@ -186,27 +185,27 @@ class LazyBuffer:
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
- def reshape_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+ def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
- return self.op.src[0].reshape_op(arg)
+ return self.op.src[0].reshape(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)
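For intuition, the RESHAPE branch above folds reshape-of-reshape into a single node by re-dispatching on the original source. A minimal self-contained sketch of that rewrite rule, using a hypothetical MiniBuf stand-in rather than tinygrad's LazyBuffer:

from math import prod

class MiniBuf:
  def __init__(self, shape, src=None, op=None): self.shape, self.src, self.op = tuple(shape), src, op
  def reshape(self, arg):
    if self.shape == tuple(arg): return self  # no-op reshape returns self
    assert prod(self.shape) == prod(arg), "reshape must preserve element count"
    if self.op == "RESHAPE": return self.src.reshape(arg)  # fold reshape(reshape(x))
    return MiniBuf(arg, src=self, op="RESHAPE")

x = MiniBuf((2, 3, 4))
y = x.reshape((6, 4)).reshape((24,))
assert y.op == "RESHAPE" and y.src is x  # two reshapes collapsed into one node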
- def pad_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+ def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all([b == 0 and e == 0 for b,e in arg]): return self
- if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad_op(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
+ if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)
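Note: the PAD branch composes two pads by summing the per-axis (begin, end) amounts. A quick plain-tuple check of that identity (no tinygrad types involved):

def compose_pads(inner, outer): return tuple((b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(inner, outer))

inner = ((1, 1), (0, 2))  # axis 0: pad 1 before / 1 after; axis 1: 0 / 2
outer = ((2, 0), (3, 3))
assert compose_pads(inner, outer) == ((3, 1), (3, 5))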
- def expand_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+ def expand(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
- return self.op.src[0].expand_op(arg)
+ return self.op.src[0].expand(arg)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)
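Note: expand only grows size-1 axes, so the EXPAND branch can drop the inner expand and let the outer arg supply the final shape directly. A small validity check of the broadcast rule assumed here (my reading of what ShapeTracker.expand enforces, not its actual code):

def check_expand(shape, arg):
  assert len(shape) == len(arg), "expand keeps the number of dims"
  assert all(s == ns or s == 1 for s, ns in zip(shape, arg)), "only size-1 axes can grow"
  return tuple(arg)

assert check_expand((1, 3, 1), (4, 3, 5)) == (4, 3, 5)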
- def permute_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+ def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
- if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute_op(tuple([self.op.arg[i] for i in arg]))
+ if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
if not self.realized:
if PUSH_PERMUTES and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
@@ -214,29 +213,28 @@ class LazyBuffer:
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this delete remove it from the children
- return src.permute_op(arg).reduce_op(cast(ReduceOps, rop), narg)
+ return src.permute(arg).reduce_op(cast(ReduceOps, rop), narg)
# move permutes before expands (always, this is safe)
if self.op.op == MovementOps.EXPAND:
- return self.op.src[0].permute_op(arg).expand_op(tuple([self.op.arg[a] for a in arg]))
+ return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))
# move permutes before reshapes if we can
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
- return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \
-   .reshape_op(ShapeTracker(self.st).permute(arg).shape)
+ return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)
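Sanity check for the PERMUTE branch above, which composes two permutes by indexing the inner order with the outer order (plain tuples only):

def apply_perm(shape, order): return tuple(shape[i] for i in order)
def compose_perms(inner, outer): return tuple(inner[i] for i in outer)

shape, p, q = (2, 3, 4), (2, 0, 1), (1, 2, 0)
assert apply_perm(apply_perm(shape, p), q) == apply_perm(shape, compose_perms(p, q))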
- def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+ def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
- if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
+ if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)
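Sanity check for the SHRINK branch above: two crops compose by shifting the outer (start, stop) window by the inner start, as a list-slicing identity shows:

def compose_shrinks(inner, outer): return tuple((b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(inner, outer))

data = list(range(10))
(start, stop), = compose_shrinks(((2, 8),), ((1, 4),))  # -> (3, 6)
assert data[2:8][1:4] == data[start:stop]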
- def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+ def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
local_st = ShapeTracker(self.shape).stride(arg)
if self.shape == local_st.shape and local_st.contiguous: return self
- if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg)))
+ if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)
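Note: the STRIDE branch composes by elementwise multiplication of the step arguments; a step of -1 is a flip, so two flips cancel. Plain-Python check (the slicing identity below holds for positive steps starting at index 0):

import operator

def compose_strides(inner, outer): return tuple(map(operator.mul, outer, inner))

assert compose_strides((-1,), (-1,)) == (1,)  # flip of a flip is the identity
data = list(range(24))
assert data[::2][::3] == data[::compose_strides((2,), (3,))[0]]  # == data[::6]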
@property
@@ -353,10 +351,10 @@ REALIZE_DISPATCHER: Dict[Any, Callable] = {
}
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
- MovementOps.RESHAPE: LazyBuffer.reshape_op,
- MovementOps.EXPAND: LazyBuffer.expand_op,
- MovementOps.SHRINK: LazyBuffer.shrink_op,
- MovementOps.PERMUTE: LazyBuffer.permute_op,
- MovementOps.PAD: LazyBuffer.pad_op,
- MovementOps.STRIDE: LazyBuffer.stride_op,
+ MovementOps.RESHAPE: LazyBuffer.reshape,
+ MovementOps.EXPAND: LazyBuffer.expand,
+ MovementOps.SHRINK: LazyBuffer.shrink,
+ MovementOps.PERMUTE: LazyBuffer.permute,
+ MovementOps.PAD: LazyBuffer.pad,
+ MovementOps.STRIDE: LazyBuffer.stride,
}
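The dispatcher maps each MovementOps value to an unbound method, so a recorded op/arg pair can be replayed as DISPATCHER[op](buf, arg). A toy version with stand-in enum and buffer classes (hypothetical, not tinygrad's):

from enum import Enum, auto

class Ops(Enum):
  RESHAPE = auto()
  PERMUTE = auto()

class Buf:
  def __init__(self, shape): self.shape = tuple(shape)
  def reshape(self, arg): return Buf(arg)
  def permute(self, arg): return Buf(tuple(self.shape[i] for i in arg))

DISPATCHER = {Ops.RESHAPE: Buf.reshape, Ops.PERMUTE: Buf.permute}
y = DISPATCHER[Ops.PERMUTE](Buf((2, 3)), (1, 0))  # same as Buf((2, 3)).permute((1, 0))
assert y.shape == (3, 2)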

View File

@@ -63,8 +63,8 @@ class Sum(Function):
self.input_shape = x.shape
return x.reduce_op(ReduceOps.SUM, new_shape)
- def backward(self, grad_output):
- return grad_output.expand_op(self.input_shape)
+ def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+ return grad_output.expand(self.input_shape)
class Max(Function):
__slots__ = "x", "ret"
@@ -74,13 +74,13 @@ class Max(Function):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
- max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand_op(self.x.shape))
+ max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand(self.x.shape))
# sum of locations, averaged
- div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand_op(self.x.shape)
+ div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)
- grad_output_expanded = grad_output.expand_op(self.x.shape)
+ grad_output_expanded = grad_output.expand(self.x.shape)
return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
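The Max backward above routes the gradient only to positions that achieved the max, split evenly across ties. A NumPy mirror of the same three steps (equality mask, tie count, divide):

import numpy as np

x = np.array([[1., 3., 3.], [2., 0., 1.]])
ret = x.max(axis=1, keepdims=True)             # forward result
max_is_1s = (x == ret).astype(x.dtype)         # CMPEQ against the expanded max
div = max_is_1s.sum(axis=1, keepdims=True)     # number of tied max locations
grad = (max_is_1s / div) * np.ones_like(ret)   # grad_output of ones, expanded
assert np.allclose(grad, [[0., .5, .5], [1., 0., 0.]])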
# ************* binary ops *************
@@ -96,7 +96,7 @@ class Maximum(Function):
self.ret = x.binary_op(BinaryOps.MAX, y)
return self.ret
- def backward(self, grad_output):
+ def backward(self, grad_output:LazyBuffer):
mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
eq = self.x.binary_op(BinaryOps.CMPEQ, self.y)
splitter = eq.const_like(2).binary_op(BinaryOps.SUB, eq).binary_op(BinaryOps.DIV, eq.const_like(2))
@@ -113,7 +113,7 @@ class Add(Function):
grad_output if self.needs_input_grad[1] else None
class Sub(Function):
- def forward(self, x:LazyBuffer, y:LazyBuffer):
+ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
return x.binary_op(BinaryOps.SUB, y)
def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
@@ -122,7 +122,7 @@ class Sub(Function):
class Mul(Function):
__slots__ = 'x', 'y'
- def forward(self, x:LazyBuffer, y:LazyBuffer):
+ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y = x, y
return x.binary_op(BinaryOps.MUL, y)
@@ -132,7 +132,7 @@ class Mul(Function):
class Pow(Function):
__slots__ = 'x', 'y', 'ret'
- def forward(self, x:LazyBuffer, y:LazyBuffer):
+ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
return self.ret
@@ -157,7 +157,7 @@ class Expand(Function):
__slots__ = 'input_shape'
def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
self.input_shape = x.shape
- return x.expand_op(shape)
+ return x.expand(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
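Expand's adjoint is a sum-reduce, which is why Expand.backward reduces grad_output back to input_shape. NumPy mirror:

import numpy as np

x = np.array([[1.], [2.]])                           # input_shape (2, 1)
y = np.broadcast_to(x, (2, 3))                       # forward: expand
grad_x = np.ones((2, 3)).sum(axis=1, keepdims=True)  # backward: SUM to (2, 1)
assert grad_x.shape == x.shape and (grad_x == 3.).all()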
@@ -166,43 +166,43 @@ class Reshape(Function):
__slots__ = 'input_shape'
def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
self.input_shape = x.shape
- return x.reshape_op(shape)
+ return x.reshape(shape)
- def backward(self, grad_output):
- return grad_output.reshape_op(self.input_shape)
+ def backward(self, grad_output:LazyBuffer):
+ return grad_output.reshape(self.input_shape)
class Permute(Function):
__slots__ = 'input_order'
def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
self.input_order = order
- return x.permute_op(order)
+ return x.permute(order)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
- return grad_output.permute_op(argsort(self.input_order))
+ return grad_output.permute(argsort(self.input_order))
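The argsort of a permutation is its inverse, which is why Permute.backward permutes by argsort(input_order). Plain-Python check:

def argsort(order): return tuple(sorted(range(len(order)), key=order.__getitem__))

order = (2, 0, 1)
inv = argsort(order)                              # (1, 2, 0)
assert tuple(order[i] for i in inv) == (0, 1, 2)  # composes to the identity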
class Pad(Function):
__slots__ = 'narg'
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
- return x.pad_op(arg)
+ return x.pad(arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
- return grad_output.shrink_op(self.narg)
+ return grad_output.shrink(self.narg)
class Shrink(Function):
__slots__ = 'narg'
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
- return x.shrink_op(arg)
+ return x.shrink(arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
- return grad_output.pad_op(self.narg)
+ return grad_output.pad(self.narg)
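Pad and Shrink undo each other through narg: Pad's backward shrinks the padding back out, and Shrink's backward pads the crop back in (the cropped-away positions get zero gradient). NumPy mirror of the Pad case:

import numpy as np

s, pad = 4, (1, 2)                 # axis of size 4, pad 1 before / 2 after
x = np.arange(float(s))
y = np.pad(x, pad)                 # forward Pad -> size 7
start, stop = pad[0], s + pad[0]   # Pad's narg: the shrink window (1, 5)
assert (y[start:stop] == x).all()  # backward Shrink recovers x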
class Flip(Function):
__slots__ = 'arg'
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
- return x.stride_op(self.arg)
+ return x.stride(self.arg)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
- return grad_output.stride_op(self.arg)
+ return grad_output.stride(self.arg)
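Flip is realized as a stride of -1 on the chosen axes and is its own inverse, so backward reuses the same arg. NumPy mirror:

import numpy as np

x = np.arange(6.).reshape(2, 3)
flipped = x[:, ::-1]                  # stride arg (1, -1): flip axis 1
assert (flipped[:, ::-1] == x).all()  # flipping twice is the identity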

View File

@@ -65,12 +65,14 @@ class LazyOp:
@property
def optype(self): raise NotImplementedError
def realize(self): raise NotImplementedError
- def reshape_op(self, _): raise NotImplementedError
- def pad_op(self, _): raise NotImplementedError
- def expand_op(self, _): raise NotImplementedError
- def permute_op(self, _): raise NotImplementedError
- def shrink_op(self, _): raise NotImplementedError
- def stride_op(self, _): raise NotImplementedError
+ # movement ops
+ def reshape(self, _): raise NotImplementedError
+ def pad(self, _): raise NotImplementedError
+ def expand(self, _): raise NotImplementedError
+ def permute(self, _): raise NotImplementedError
+ def shrink(self, _): raise NotImplementedError
+ def stride(self, _): raise NotImplementedError
# **************** for Interpreted Buffers ****************