From e822aae9ec337dfa918aa1a0bbf71cda131aa2ed Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 2 Jul 2022 22:29:09 -0700 Subject: [PATCH] reorg opts, nicer graph --- accel/lazy/ops_lazy.py | 23 +++++++++++++++-------- tinygrad/ops.py | 3 ++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/accel/lazy/ops_lazy.py b/accel/lazy/ops_lazy.py index b3a09770c4..94423a53af 100644 --- a/accel/lazy/ops_lazy.py +++ b/accel/lazy/ops_lazy.py @@ -12,14 +12,20 @@ from enum import Enum LoadOps = Enum("LoadOps", ["FROMCPU", "CONTIGUOUS"]) Op = Union[BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps] -MERGE_MOVEMENT_OPS = True -SHUFFLE_MOVEMENT_OPS = True -SHUFFLE_SLICE_OPS = False # NOTE: 0/0 is NaN if you slice, so this can change the output -REMOVE_MOVEMENT_NOPS = True -MERGE_ELEMENTWISE_OPS = True -MERGE_ELEMENTWISE_INTO_CONV_OUTPUT = False # TODO: this should be done at resolve time -FOLD_CONSTANTS_INTO_KERNELS = True +# -O1 CACHE_LAZYBUFFERS = True # this leaks tons of memory. TODO: only cache unresolved LazyBuffers +MERGE_MOVEMENT_OPS = True +MERGE_UNARY_OPS = True +REMOVE_MOVEMENT_NOPS = True + +# -O2 +SHUFFLE_MOVEMENT_OPS = True +MERGE_ELEMENTWISE_OPS = True +FOLD_CONSTANTS_INTO_KERNELS = True # should depend on the JIT if it's a float or a number + +# -O3 +SHUFFLE_SLICE_OPS = False # NOTE: 0/0 is NaN if you slice, so this can change the output +MERGE_ELEMENTWISE_INTO_CONV_OUTPUT = False # TODO: should this be done at resolve time? class LazyOp(NamedTuple): op: Op @@ -159,7 +165,7 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer: #srcs = [srcs[0], srcs[1].op] return LazyBuffer(out_shape, ProcessingOps, LazyOp(op, tuple(srcs))) - if MERGE_ELEMENTWISE_OPS: + if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS: # remove the buffers from any BinaryOps that feed into this srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs) @@ -196,6 +202,7 @@ def _realize_binary_op(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBu real_dict[s] = f"({root.op.arg[0]}f)" else: # TODO: this is a terrible hack, and it's very unclear if it's always right + # can't we just replace the getter function? inline_valid = s.st.expr().replace("valid=valid && ", "").replace(";idx=0", "").replace("//", "/").replace("idx", "gid") if ';' not in inline_valid: real_dict[s] = f"(({inline_valid}) * {str(root.op.arg[0])}f)" diff --git a/tinygrad/ops.py b/tinygrad/ops.py index bada0f3447..cba09007ab 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -46,7 +46,8 @@ def log_op(optype, op, ret, inp): for x in inp: if not isinstance(op, list): op = [op] - if GRAPH == 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1]) + if len(op) <= 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1]) + elif len(op) <= 4: sop = '.'.join([str(y).split(".")[1][0:2] for y in op][::-1]) else: sop = str(len(op)) G.add_edge(nm(x), nm(ret), label=sop) if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)