reorg opts, nicer graph

This commit is contained in:
George Hotz
2022-07-02 22:29:09 -07:00
parent f9a8412b68
commit e822aae9ec
2 changed files with 17 additions and 9 deletions

View File

@@ -12,14 +12,20 @@ from enum import Enum
LoadOps = Enum("LoadOps", ["FROMCPU", "CONTIGUOUS"])
Op = Union[BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]
MERGE_MOVEMENT_OPS = True
SHUFFLE_MOVEMENT_OPS = True
SHUFFLE_SLICE_OPS = False # NOTE: 0/0 is NaN if you slice, so this can change the output
REMOVE_MOVEMENT_NOPS = True
MERGE_ELEMENTWISE_OPS = True
MERGE_ELEMENTWISE_INTO_CONV_OUTPUT = False # TODO: this should be done at resolve time
FOLD_CONSTANTS_INTO_KERNELS = True
# -O1
CACHE_LAZYBUFFERS = True # this leaks tons of memory. TODO: only cache unresolved LazyBuffers
MERGE_MOVEMENT_OPS = True
MERGE_UNARY_OPS = True
REMOVE_MOVEMENT_NOPS = True
# -O2
SHUFFLE_MOVEMENT_OPS = True
MERGE_ELEMENTWISE_OPS = True
FOLD_CONSTANTS_INTO_KERNELS = True # should depend on the JIT if it's a float or a number
# -O3
SHUFFLE_SLICE_OPS = False # NOTE: 0/0 is NaN if you slice, so this can change the output
MERGE_ELEMENTWISE_INTO_CONV_OUTPUT = False # TODO: should this be done at resolve time?
class LazyOp(NamedTuple):
op: Op
@@ -159,7 +165,7 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer:
#srcs = [srcs[0], srcs[1].op]
return LazyBuffer(out_shape, ProcessingOps, LazyOp(op, tuple(srcs)))
if MERGE_ELEMENTWISE_OPS:
if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS:
# remove the buffers from any BinaryOps that feed into this
srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs)
@@ -196,6 +202,7 @@ def _realize_binary_op(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBu
real_dict[s] = f"({root.op.arg[0]}f)"
else:
# TODO: this is a terrible hack, and it's very unclear if it's always right
# can't we just replace the getter function?
inline_valid = s.st.expr().replace("valid=valid && ", "").replace(";idx=0", "").replace("//", "/").replace("idx", "gid")
if ';' not in inline_valid:
real_dict[s] = f"(({inline_valid}) * {str(root.op.arg[0])}f)"

View File

@@ -46,7 +46,8 @@ def log_op(optype, op, ret, inp):
for x in inp:
if not isinstance(op, list): op = [op]
if GRAPH == 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
if len(op) <= 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
elif len(op) <= 4: sop = '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
else: sop = str(len(op))
G.add_edge(nm(x), nm(ret), label=sop)
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)