From 5f3d5cfb02a394c6c8151cd751a583a1ebd642df Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 21 Dec 2023 21:40:37 -0500 Subject: [PATCH] catch cycles in print_tree (#2891) * feat: smaller tree on references * fix: shorter line * fix: huh * fix: should be all * feat: cleaner * fix: extra imports * fix: pass by reference --- tinygrad/graph.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tinygrad/graph.py b/tinygrad/graph.py index 8758c3fcc1..12868cf575 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -2,7 +2,7 @@ import os, atexit from typing import List, Any from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp from tinygrad.device import Device -from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters +from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.shape.symbolic import NumNode @@ -86,16 +86,19 @@ def log_lazybuffer(lb, scheduled=False): if nm(lb) not in G.nodes: # realized but unseen? G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080") -def _tree(lazydata, prefix=""): - if type(lazydata).__name__ == "LazyBuffer": - return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ") + +def _tree(lazydata, cycles, cnt, prefix=""): + cnt[0] += 1 if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] + if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0: + return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"] + cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1) lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"] - childs = [_tree(c) for c in lazydata.src[:]] + childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]] for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]] return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]] -def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))])) +def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata, {}, [-1]))])) def graph_uops(uops:List[UOp]): import networkx as nx