From 001cc96e254ff47be72e94ebbb73e83a886335ed Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 7 Feb 2023 11:53:21 -0600
Subject: [PATCH] Lazy refactor (#538)

* refactor lazy to return ASTs

* a lil cleaner

* oops, compare ids

* gate on GRAPH

* cleanups

* less calls to log_op

* simpler

* realize_buffers -> map_buffers

* even simpler

* think in asts

* a lil cleaner

* NOOP means contiguous
---
 tinygrad/graph.py         |  10 +++-
 tinygrad/lazy.py          | 115 +++++++++++++++-----------------------
 tinygrad/llops/ops_cpu.py |   2 +-
 3 files changed, 55 insertions(+), 72 deletions(-)

diff --git a/tinygrad/graph.py b/tinygrad/graph.py
index 83e0de2f02..1c1bf1961f 100644
--- a/tinygrad/graph.py
+++ b/tinygrad/graph.py
@@ -3,7 +3,7 @@ import atexit
 import itertools
 from collections import defaultdict
 from typing import Dict, List
-from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType
+from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType, LazyOp, get_buffers, get_lazyops
 from tinygrad.helpers import getenv
 
 GRAPH = getenv("GRAPH", 0)
@@ -34,7 +34,13 @@ if GRAPH:
   atexit.register(save_graph_exit)
 
 global_num_max = 0
-def log_op(optype : OpType, op : List[Op], ret : DeviceBuffer, inp : List[DeviceBuffer]):
+def log_op(ret : DeviceBuffer, ast : LazyOp):
+  if not DEBUG and not GRAPH: return
+  op : List[Op] = [x.op for x in get_lazyops(ast)]
+  inp : List[DeviceBuffer] = get_buffers(ast)
+  if len(inp) == 1 and inp[0] == ret: return  # don't log self loops
+  oporder = [LoadOps, ProcessingOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
+  optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
   cnts[optype] += 1
   if DEBUG >= 3: print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
 
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index d3beda27d9..e7b2e522bb 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -4,7 +4,7 @@ from copy import copy
 import sys, weakref
 from tinygrad.helpers import ConvArgs, get_available_llops, prod
 from tinygrad.shape import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, DEBUG
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, DEBUG
 from tinygrad.graph import log_op
 from tinygrad.helpers import getenv
 
@@ -28,59 +28,22 @@ class Device:
     vars()[name] = name
 
 # **** realize helpers ****
-def realize_buffers(real_srcs, x:LazyOp) -> LazyOp:
+def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
   if x in real_srcs:
-    return realize_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
-  return LazyOp(x.op, tuple(realize_buffers(real_srcs, y) for y in x.src), x.arg)
+    return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
+  return LazyOp(x.op, tuple(map_buffers(real_srcs, y) for y in x.src), x.arg)
 
 # **** realize functions ****
-# TODO: make all _realize functions return an AST, perhaps unrealized
-# NOTE: loadops and movementops aren't valid ASTs and won't become kernels
-
-def _realize_loadops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], Optional[OpType]]:
-  if self.op.op == LoadOps.FROMCPU:
-    return Device._buffers[self.device].fromCPU(self.op.arg), [], LoadOps
-  elif self.op.op == LoadOps.CONTIGUOUS:
-    # under the hood, this is an AST or a no op. rename to MetaOps?
-    real_src = self.op.src[0].realize(self.device)
-    ret = real_src.contiguous()
-    return ret, [real_src], LoadOps if id(ret) != id(real_src) else None
-  else:
-    raise NotImplementedError(f"unknown LoadOp {self.op.op}")
-
-def _realize_movementops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
-  src = self.op.src[0]
-
-  # fuse RESHAPE and ReduceOps
-  # TODO: add MetaOps.TOIMAGE instead?
-  if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
-    return _realize_reduceops_w_shape(src, output_shape = self.op.arg)
-
-  real_src = src.realize(self.device)
-  return real_src.movement_op(self.op.op, self.op.arg), [real_src], MovementOps
-
-def _realize_processingops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
-  ast = LazyOp(self.op.op, tuple(x.realize(self.device) for x in self.op.src), self.op.arg)
-  return self.dbuffer.exec_ast(ast), get_buffers(ast), ProcessingOps
-
-# this supports late merging an upstream Elementwise op
-def _realize_reduceops_w_shape(self:LazyBuffer, output_shape=None) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
+def _ast_reduceops(self:LazyBuffer) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
   src = self.op.src[0]
   if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
-    # this is the new version, deprecate _processing_op
-    real_srcs : Dict[LazyBuffer, DeviceBuffer] = {x:x.realize(self.device) for x in get_buffers(src.op)}
-    ast = LazyOp(self.op.op, (realize_buffers(real_srcs, src.op),), self.op.arg)
-  else:
-    ast = LazyOp(self.op.op, (src.realize(self.device),), self.op.arg)
-  if output_shape is not None: ast = LazyOp(MovementOps.RESHAPE, (ast, ), output_shape)
-  return self.dbuffer.exec_ast(ast), get_buffers(ast), ReduceOps
-def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]: return _realize_reduceops_w_shape(self)
+    src = src.op
+  return LazyOp(self.op.op, (src,), self.op.arg)
 
 # this supports late merging an upstream Reduce op and even an Elementwise op above that
-def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer], OpType]:
-  real_srcs : Dict[LazyBuffer, Union[None, LazyOp, DeviceBuffer]] = {x:None for x in get_buffers(self.op)}
-  op_type : OpType = BinaryOps
+def _ast_binaryops(self:LazyBuffer) -> LazyOp:
+  real_srcs : Dict[LazyBuffer, Union[None, LazyOp, LazyBuffer]] = {x:None for x in get_buffers(self.op)}
   if DEBUG >= 3:
     for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())):
       if x.optype in [ProcessingOps,ReduceOps] and x.realized is None:
@@ -89,22 +52,15 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
         for tx in x.children: print("x", tx)
   # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
   psrcs : List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype in [ProcessingOps,ReduceOps] and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
-  intermediate_shape = self.shape
+  intermediate_shape : Tuple[int, ...] = self.shape
   if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE:
     if psrcs[0][1].optype == ProcessingOps:
-      real_srcs[psrcs[0][0]] = psrcs[0][1].op
-      for x in psrcs[0][1].op.src:
-        real_srcs[x] = x.realize(self.device)
-      op_type = ProcessingOps
+      top = psrcs[0][1].op  # _ast_processingops
     elif psrcs[0][1].optype == ReduceOps:
-      src = psrcs[0][1].op.src[0]
-      if MERGE_ELEMENTWISE_INTO_REDUCE and src.realized is None and src.optype == BinaryOps and len(src.children) <= 1:
-        src = src.op
-      real_srcs[psrcs[0][0]] = LazyOp(psrcs[0][1].op.op, (src,), psrcs[0][1].op.arg)
-      for x in get_buffers(real_srcs[psrcs[0][0]]):  # type: ignore
-        # these are the early buffers
-        real_srcs[x] = x.realize(self.device)
-      op_type = ReduceOps
+      top = _ast_reduceops(psrcs[0][1])
+    real_srcs[psrcs[0][0]] = top
+    real_srcs.update({x:x for x in get_buffers(top)})  # the reduce op buffers are not modified
+
     # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
     if psrcs[0][0].shape != psrcs[0][1].shape:
       intermediate_shape = psrcs[0][1].shape
@@ -114,11 +70,8 @@ def _realize_binaryops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
   # NOTE: these RESHAPEs will return self if they don't change the shape
   for x in real_srcs.keys():
     if real_srcs[x] is None:
-      real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape).realize(self.device)
-  ast = LazyOp(MovementOps.RESHAPE, (realize_buffers(real_srcs, self.op), ), self.shape)
-  return self.dbuffer.exec_ast(ast), get_buffers(ast), op_type
-
-_realize = {LoadOps:_realize_loadops, ReduceOps:_realize_reduceops, MovementOps:_realize_movementops, BinaryOps:_realize_binaryops, ProcessingOps:_realize_processingops}
+      real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
+  return LazyOp(MovementOps.RESHAPE, (map_buffers(real_srcs, self.op), ), self.shape)
 
 # **** lazy operations ****
 
@@ -159,14 +112,38 @@ class LazyBuffer:
     if required_device is not None: assert required_device == self.device
     if self.realized is None:
-      # we haven't realized the Buffer yet
-      self.realized, real_srcs, real_type = _realize[self.optype](self)
-      # in lazy mode, we don't log until we realize
-      if real_type is not None:
-        log_op(real_type, [x.op for x in get_lazyops(self.op)], self.realized, real_srcs)
+      # get real ops first
+      if self.op.op == LoadOps.FROMCPU:
+        self.realized = Device._buffers[self.device].fromCPU(self.op.arg)
+        ast = LazyOp(self.op.op, tuple())
+      elif self.op.op == LoadOps.CONTIGUOUS:
+        real_src = self.op.src[0].realize(self.device)
+        self.realized = real_src.contiguous()
+        ast = LazyOp(self.op.op, (real_src, ))
+      elif self.optype == MovementOps:
+        src = self.op.src[0]
+
+        # fuse RESHAPE and ReduceOps
+        if src.realized is None and src.optype == ReduceOps and self.op.op == MovementOps.RESHAPE and len(src.children) <= 1:
+          # it's okay to add a RESHAPE to the ast here
+          ast = LazyOp(MovementOps.RESHAPE, (_ast_reduceops(src), ), self.op.arg)
+        else:
+          # movement ops aren't an AST, just run them
+          real_src = src.realize(self.device)
+          self.realized = real_src.movement_op(self.op.op, self.op.arg)
+          ast = LazyOp(self.op.op, (real_src, ))
+      elif self.optype == ProcessingOps: ast = self.op  # no ast modifications for ProcessingOps
+      elif self.optype == ReduceOps: ast = _ast_reduceops(self)
+      elif self.optype == BinaryOps: ast = _ast_binaryops(self)
+
       # no need to keep the op after realization
       del self.op
 
+      # run the ast if we still have to, and log the op
+      if self.realized is None:
+        self.realized = self.dbuffer.exec_ast(map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast))
+      log_op(self.realized, ast)
+
     assert self.realized.shape == self.shape
     assert isinstance(self.realized, Device._buffers[self.device])
     return self.realized
 
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 6dde86e20c..fcde5bd60f 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -6,7 +6,7 @@ from tinygrad.helpers import shape_to_axis
 
 class CPUBuffer(np.ndarray, GenericExecAST):
   fxn_for_op = {
-    UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
+    UnaryOps.NOOP: lambda x: x[:].contiguous(), UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
     UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
     BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
     BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
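
A minimal, self-contained sketch of the pattern the refactored realize path uses: build a LazyOp AST whose leaves are lazy buffers, swap each leaf for its realized device buffer with map_buffers, then hand the whole tree to the backend's exec_ast. FakeBuffer, the tiny interpreter, and the concrete op values below are stand-ins chosen for illustration, not tinygrad's real classes:

  from collections import namedtuple
  from enum import Enum, auto

  class BinaryOps(Enum): MUL = auto()
  class ReduceOps(Enum): SUM = auto()

  # same shape as tinygrad's LazyOp: an op, a tuple of sources, an optional arg
  LazyOp = namedtuple("LazyOp", ["op", "src", "arg"], defaults=[None])

  def map_buffers(real_srcs, x):
    # mirrors the patch: swap leaves for their realized buffers, rebuild interior nodes
    if x in real_srcs:
      return map_buffers(real_srcs, real_srcs[x]) if isinstance(real_srcs[x], LazyOp) else real_srcs[x]
    return LazyOp(x.op, tuple(map_buffers(real_srcs, y) for y in x.src), x.arg)

  class FakeBuffer:  # stand-in for a realized DeviceBuffer holding a flat list of floats
    def __init__(self, data): self.data = data

  def exec_ast(ast):
    # toy interpreter standing in for a backend's exec_ast: evaluate the tree bottom-up
    if isinstance(ast, FakeBuffer): return ast
    srcs = [exec_ast(s) for s in ast.src]
    if ast.op == BinaryOps.MUL: return FakeBuffer([a*b for a, b in zip(srcs[0].data, srcs[1].data)])
    if ast.op == ReduceOps.SUM: return FakeBuffer([sum(srcs[0].data)])
    raise NotImplementedError(ast.op)

  # two opaque objects stand in for unrealized LazyBuffer leaves
  lazy_a, lazy_b = object(), object()
  ast = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, (lazy_a, lazy_b)),), (1,))

  # "realize": map the leaves to concrete buffers, then run the whole tree in one exec_ast call
  real = {lazy_a: FakeBuffer([1., 2., 3.]), lazy_b: FakeBuffer([4., 5., 6.])}
  print(exec_ast(map_buffers(real, ast)).data)  # [32.0]

Because the elementwise and reduce sources are merged into a single AST before exec_ast runs, the backend sees the whole expression at once rather than one realized buffer per op, which is what the MERGE_ELEMENTWISE_INTO_REDUCE and MERGE_ONE_REDUCE_INTO_ELEMENTWISE paths in the patch rely on.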
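The new log_op likewise derives both its inputs and the graph node's category from the AST: it takes the category of whichever op sorts earliest in oporder, so a reduce kernel with merged elementwise and movement ops is still counted as a ReduceOps node. A sketch of just that selection rule, with stand-in enums:

  from enum import Enum, auto

  class LoadOps(Enum): FROMCPU = auto()
  class ProcessingOps(Enum): CONV = auto()
  class ReduceOps(Enum): SUM = auto()
  class BinaryOps(Enum): MUL = auto()
  class UnaryOps(Enum): NOOP = auto()
  class MovementOps(Enum): RESHAPE = auto()

  # same rule as the new log_op: sort the AST's ops by their category's position
  # in oporder and take the category of the first one
  oporder = [LoadOps, ProcessingOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
  def pick_optype(ops): return type(sorted(ops, key=lambda x: oporder.index(type(x)))[0])

  print(pick_optype([BinaryOps.MUL, ReduceOps.SUM, MovementOps.RESHAPE]))  # <enum 'ReduceOps'>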