catch cycles in print_tree (#2891)

* feat: smaller tree on references

* fix: shorter line

* fix: huh

* fix: should be all

* feat: cleaner

* fix: extra imports

* fix: pass by reference
This commit is contained in:
wozeparrot
2023-12-21 21:40:37 -05:00
committed by GitHub
parent 4432cb17bb
commit 5f3d5cfb02

View File

@@ -2,7 +2,7 @@ import os, atexit
from typing import List, Any
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.symbolic import NumNode
@@ -86,16 +86,19 @@ def log_lazybuffer(lb, scheduled=False):
if nm(lb) not in G.nodes:
# realized but unseen?
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
def _tree(lazydata, prefix=""):
if type(lazydata).__name__ == "LazyBuffer":
return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
def _tree(lazydata, cycles, cnt, prefix=""):
cnt[0] += 1
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"]
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
childs = [_tree(c) for c in lazydata.src[:]]
childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]]
for c in childs[:-1]: lines += [f"{c[0]}"] + [f"{l}" for l in c[1:]]
return lines + [""+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata, {}, [-1]))]))
def graph_uops(uops:List[UOp]):
import networkx as nx