mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-15 01:48:23 -05:00
reorg opts, nicer graph
This commit is contained in:
@@ -12,14 +12,20 @@ from enum import Enum
|
||||
LoadOps = Enum("LoadOps", ["FROMCPU", "CONTIGUOUS"])
|
||||
Op = Union[BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps]
|
||||
|
||||
MERGE_MOVEMENT_OPS = True
|
||||
SHUFFLE_MOVEMENT_OPS = True
|
||||
SHUFFLE_SLICE_OPS = False # NOTE: 0/0 is NaN if you slice, so this can change the output
|
||||
REMOVE_MOVEMENT_NOPS = True
|
||||
MERGE_ELEMENTWISE_OPS = True
|
||||
MERGE_ELEMENTWISE_INTO_CONV_OUTPUT = False # TODO: this should be done at resolve time
|
||||
FOLD_CONSTANTS_INTO_KERNELS = True
|
||||
# -O1
|
||||
CACHE_LAZYBUFFERS = True # this leaks tons of memory. TODO: only cache unresolved LazyBuffers
|
||||
MERGE_MOVEMENT_OPS = True
|
||||
MERGE_UNARY_OPS = True
|
||||
REMOVE_MOVEMENT_NOPS = True
|
||||
|
||||
# -O2
|
||||
SHUFFLE_MOVEMENT_OPS = True
|
||||
MERGE_ELEMENTWISE_OPS = True
|
||||
FOLD_CONSTANTS_INTO_KERNELS = True # should depend on the JIT if it's a float or a number
|
||||
|
||||
# -O3
|
||||
SHUFFLE_SLICE_OPS = False # NOTE: 0/0 is NaN if you slice, so this can change the output
|
||||
MERGE_ELEMENTWISE_INTO_CONV_OUTPUT = False # TODO: should this be done at resolve time?
|
||||
|
||||
class LazyOp(NamedTuple):
|
||||
op: Op
|
||||
@@ -159,7 +165,7 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer:
|
||||
#srcs = [srcs[0], srcs[1].op]
|
||||
return LazyBuffer(out_shape, ProcessingOps, LazyOp(op, tuple(srcs)))
|
||||
|
||||
if MERGE_ELEMENTWISE_OPS:
|
||||
if (MERGE_UNARY_OPS and len(srcs) == 1) or MERGE_ELEMENTWISE_OPS:
|
||||
# remove the buffers from any BinaryOps that feed into this
|
||||
srcs = tuple(x.op if x.optype == BinaryOps and x.realized is None else x for x in srcs)
|
||||
|
||||
@@ -196,6 +202,7 @@ def _realize_binary_op(self:LazyBuffer) -> Tuple[gops.GPUBuffer, List[gops.GPUBu
|
||||
real_dict[s] = f"({root.op.arg[0]}f)"
|
||||
else:
|
||||
# TODO: this is a terrible hack, and it's very unclear if it's always right
|
||||
# can't we just replace the getter function?
|
||||
inline_valid = s.st.expr().replace("valid=valid && ", "").replace(";idx=0", "").replace("//", "/").replace("idx", "gid")
|
||||
if ';' not in inline_valid:
|
||||
real_dict[s] = f"(({inline_valid}) * {str(root.op.arg[0])}f)"
|
||||
|
||||
@@ -46,7 +46,8 @@ def log_op(optype, op, ret, inp):
|
||||
|
||||
for x in inp:
|
||||
if not isinstance(op, list): op = [op]
|
||||
if GRAPH == 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
|
||||
if len(op) <= 2: sop = '.'.join([str(y).split(".")[1] for y in op][::-1])
|
||||
elif len(op) <= 4: sop = '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
|
||||
else: sop = str(len(op))
|
||||
G.add_edge(nm(x), nm(ret), label=sop)
|
||||
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
|
||||
|
||||
Reference in New Issue
Block a user