From f265e8523af5ffa26647e98dc6fa90a173833d79 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 26 Jun 2023 15:01:28 -0700
Subject: [PATCH] movement ops aren't really ops (#1056)

---
 tinygrad/lazy.py  | 50 +++++++++++++++++++++++------------------------
 tinygrad/mlops.py | 42 +++++++++++++++++++--------------------
 tinygrad/ops.py   | 14 +++++++------
 3 files changed, 53 insertions(+), 53 deletions(-)

diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index b27fc2d4cc..d191927a37 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -53,7 +53,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   # reshape all the late ops into the output shape
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
-    if not real_srcs[x]: real_srcs[x] = x.reshape_op(intermediate_shape)
+    if not real_srcs[x]: real_srcs[x] = x.reshape(intermediate_shape)
   ast = self.op.map_buffers(real_srcs)
   return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast

@@ -154,8 +154,7 @@ class LazyBuffer:
   # create a constant with the shape and dtype of self
   def const_like(self, val) -> LazyBuffer:
     # NOTE: dtypes.from_np(self.dtype.np) to deal with image types
-    return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val) \
-      .reshape_op((1,)*len(self.shape)).expand_op(self.shape)
+    return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)

   # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this?
   def toCPU(self):
@@ -178,7 +177,7 @@ class LazyBuffer:
       # MovementOps aren't stacked any more, they each have one parent, find the root
       root = get_movementroot(self)
       if root.st.contiguous and root != self and prod(ret.st.shape) == prod(root.shape):
-        return root.reshape_op(ret.st.shape)
+        return root.reshape(ret.st.shape)
     return ret

   def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
@@ -186,27 +185,27 @@ class LazyBuffer:
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
     return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)

-  def reshape_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.RESHAPE:
       self.op.src[0].children.discard(self)  # NOTE: this is only required in reshape and when pushing permutes, why??
-      return self.op.src[0].reshape_op(arg)
+      return self.op.src[0].reshape(arg)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg)

-  def pad_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all([b == 0 and e == 0 for b,e in arg]): return self
-    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad_op(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
+    if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).pad(arg), MovementOps.PAD, arg)

-  def expand_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def expand(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.EXPAND:
-      return self.op.src[0].expand_op(arg)
+      return self.op.src[0].expand(arg)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).expand(arg), MovementOps.EXPAND, arg)

-  def permute_op(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if arg == tuple(range(len(self.shape))): return self
-    if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute_op(tuple([self.op.arg[i] for i in arg]))
+    if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
     if not self.realized:
       if PUSH_PERMUTES and self.optype == ReduceOps:
         # reduceops have one buffer input, permute it
@@ -214,29 +213,28 @@ class LazyBuffer:
         src, rop = self.op.src[0], self.op.op
         src.children.discard(self)
         del self  # TODO: why doesn't this delete remove it from the children
-        return src.permute_op(arg).reduce_op(cast(ReduceOps, rop), narg)
+        return src.permute(arg).reduce_op(cast(ReduceOps, rop), narg)

       # move permutes before expands (always, this is safe)
       if self.op.op == MovementOps.EXPAND:
-        return self.op.src[0].permute_op(arg).expand_op(tuple([self.op.arg[a] for a in arg]))
+        return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))

       # move permutes before reshapes if we can
       if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and self.op.src[0].__class__ is LazyBuffer:
         if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
           self.op.src[0].children.discard(self)  # NOTE: this is only required in reshape and when pushing permutes, why??
-          return self.op.src[0].permute_op(tuple(flatten(shape_idx_groups[i] for i in arg))) \
-            .reshape_op(ShapeTracker(self.st).permute(arg).shape)
+          return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(ShapeTracker(self.st).permute(arg).shape)
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).permute(arg), MovementOps.PERMUTE, arg)

-  def shrink_op(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def shrink(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all([b - a == s for s, (a, b) in zip(self.shape, arg)]): return self
-    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink_op(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
+    if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).shrink(arg), MovementOps.SHRINK, arg)

-  def stride_op(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
+  def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     local_st = ShapeTracker(self.shape).stride(arg)
     if self.shape == local_st.shape and local_st.contiguous: return self
-    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride_op(tuple(map(operator.mul, arg, self.op.arg)))
+    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
     return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).stride(arg), MovementOps.STRIDE, arg)

   @property
@@ -353,10 +351,10 @@ REALIZE_DISPATCHER: Dict[Any, Callable] = {
 }

 MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
-  MovementOps.RESHAPE: LazyBuffer.reshape_op,
-  MovementOps.EXPAND: LazyBuffer.expand_op,
-  MovementOps.SHRINK: LazyBuffer.shrink_op,
-  MovementOps.PERMUTE: LazyBuffer.permute_op,
-  MovementOps.PAD: LazyBuffer.pad_op,
-  MovementOps.STRIDE: LazyBuffer.stride_op,
+  MovementOps.RESHAPE: LazyBuffer.reshape,
+  MovementOps.EXPAND: LazyBuffer.expand,
+  MovementOps.SHRINK: LazyBuffer.shrink,
+  MovementOps.PERMUTE: LazyBuffer.permute,
+  MovementOps.PAD: LazyBuffer.pad,
+  MovementOps.STRIDE: LazyBuffer.stride,
 }
\ No newline at end of file

diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index fd3a4fa425..a5ded1e073 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -63,8 +63,8 @@ class Sum(Function):
     self.input_shape = x.shape
     return x.reduce_op(ReduceOps.SUM, new_shape)

-  def backward(self, grad_output):
-    return grad_output.expand_op(self.input_shape)
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.expand(self.input_shape)

 class Max(Function):
   __slots__ = "x", "ret"
@@ -74,13 +74,13 @@ class Max(Function):

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand_op(self.x.shape))
+    max_is_1s = self.x.binary_op(BinaryOps.CMPEQ, self.ret.expand(self.x.shape))

     # sum of locations, averaged
-    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand_op(self.x.shape)
+    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
     max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

-    grad_output_expanded = grad_output.expand_op(self.x.shape)
+    grad_output_expanded = grad_output.expand(self.x.shape)
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)

 # ************* binary ops *************
@@ -96,7 +96,7 @@ class Maximum(Function):
     self.ret = x.binary_op(BinaryOps.MAX, y)
     return self.ret

-  def backward(self, grad_output):
+  def backward(self, grad_output:LazyBuffer):
     mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
     eq = self.x.binary_op(BinaryOps.CMPEQ, self.y)
     splitter = eq.const_like(2).binary_op(BinaryOps.SUB, eq).binary_op(BinaryOps.DIV, eq.const_like(2))
@@ -113,7 +113,7 @@ class Add(Function):
            grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
-  def forward(self, x:LazyBuffer, y:LazyBuffer):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     return x.binary_op(BinaryOps.SUB, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
@@ -122,7 +122,7 @@ class Sub(Function):

 class Mul(Function):
   __slots__ = 'x', 'y'
-  def forward(self, x:LazyBuffer, y:LazyBuffer):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y = x, y
     return x.binary_op(BinaryOps.MUL, y)

@@ -132,7 +132,7 @@ class Mul(Function):

 class Pow(Function):
   __slots__ = 'x', 'y', 'ret'
-  def forward(self, x:LazyBuffer, y:LazyBuffer):
+  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
     self.x, self.y, self.ret = x, y, x.binary_op(BinaryOps.POW, y)
     return self.ret

@@ -157,7 +157,7 @@ class Expand(Function):
   __slots__ = 'input_shape'
   def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.expand_op(shape)
+    return x.expand(shape)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
@@ -166,43 +166,43 @@ class Reshape(Function):
   __slots__ = 'input_shape'
   def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.reshape_op(shape)
+    return x.reshape(shape)

-  def backward(self, grad_output):
-    return grad_output.reshape_op(self.input_shape)
+  def backward(self, grad_output:LazyBuffer):
+    return grad_output.reshape(self.input_shape)

 class Permute(Function):
   __slots__ = 'input_order'
   def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
     self.input_order = order
-    return x.permute_op(order)
+    return x.permute(order)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.permute_op(argsort(self.input_order))
+    return grad_output.permute(argsort(self.input_order))

 class Pad(Function):
   __slots__ = 'narg'
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
-    return x.pad_op(arg)
+    return x.pad(arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.shrink_op(self.narg)
+    return grad_output.shrink(self.narg)

 class Shrink(Function):
   __slots__ = 'narg'
   def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
-    return x.shrink_op(arg)
+    return x.shrink(arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad_op(self.narg)
+    return grad_output.pad(self.narg)

 class Flip(Function):
   __slots__ = 'arg'
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]):
     self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
-    return x.stride_op(self.arg)
+    return x.stride(self.arg)

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.stride_op(self.arg)
\ No newline at end of file
+    return grad_output.stride(self.arg)
\ No newline at end of file
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5db967f605..4f035ba199 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -65,12 +65,14 @@ class LazyOp:
   @property
   def optype(self): raise NotImplementedError
   def realize(self): raise NotImplementedError
-  def reshape_op(self, _): raise NotImplementedError
-  def pad_op(self, _): raise NotImplementedError
-  def expand_op(self, _): raise NotImplementedError
-  def permute_op(self, _): raise NotImplementedError
-  def shrink_op(self, _): raise NotImplementedError
-  def stride_op(self, _): raise NotImplementedError
+
+  # movement ops
+  def reshape(self, _): raise NotImplementedError
+  def pad(self, _): raise NotImplementedError
+  def expand(self, _): raise NotImplementedError
+  def permute(self, _): raise NotImplementedError
+  def shrink(self, _): raise NotImplementedError
+  def stride(self, _): raise NotImplementedError

 # **************** for Interpreted Buffers ****************
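
Note (not part of the patch): the renamed movement methods above fold two stacked ops of the same kind into one, using the tuple arithmetic visible in LazyBuffer.pad/permute/shrink/stride. The standalone Python sketch below reproduces just that arithmetic on plain tuples; the helper names are illustrative and are not tinygrad API.

# Standalone sketch of the movement-op folding arithmetic in the patch.
from typing import Tuple

def fold_pads(inner: Tuple[Tuple[int, int], ...], outer: Tuple[Tuple[int, int], ...]) -> Tuple[Tuple[int, int], ...]:
  # pad(pad(x, inner), outer) == pad(x, summed) -- matches (b1+b2, e1+e2) in LazyBuffer.pad
  return tuple((b1 + b2, e1 + e2) for (b1, e1), (b2, e2) in zip(inner, outer))

def fold_permutes(inner: Tuple[int, ...], outer: Tuple[int, ...]) -> Tuple[int, ...]:
  # permute(permute(x, inner), outer) == permute(x, composed) -- matches self.op.arg[i] for i in arg
  return tuple(inner[i] for i in outer)

def fold_shrinks(inner: Tuple[Tuple[int, int], ...], outer: Tuple[Tuple[int, int], ...]) -> Tuple[Tuple[int, int], ...]:
  # x[b1:e1][b2:e2] == x[b1+b2 : b1+e2] -- matches (b1+b2, b1+e2) in LazyBuffer.shrink
  return tuple((b1 + b2, b1 + e2) for (b1, _), (b2, e2) in zip(inner, outer))

def fold_strides(inner: Tuple[int, ...], outer: Tuple[int, ...]) -> Tuple[int, ...]:
  # stride(stride(x, inner), outer) == stride(x, elementwise product) -- matches map(operator.mul, arg, self.op.arg)
  return tuple(a * b for a, b in zip(inner, outer))

# quick checks of the folding rules
assert fold_pads(((1, 1), (0, 2)), ((2, 0), (1, 1))) == ((3, 1), (1, 3))
assert fold_permutes((2, 0, 1), (1, 2, 0)) == (0, 1, 2)
assert fold_shrinks(((2, 10), (0, 5)), ((1, 4), (2, 5))) == ((3, 6), (2, 5))
assert fold_strides((1, -1), (-1, 2)) == (-1, -2)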