mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
catch cycles in print_tree (#2891)
* feat: smaller tree on references * fix: shorter line * fix: huh * fix: should be all * feat: cleaner * fix: extra imports * fix: pass by reference
This commit is contained in:
@@ -2,7 +2,7 @@ import os, atexit
|
||||
from typing import List, Any
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
|
||||
@@ -86,16 +86,19 @@ def log_lazybuffer(lb, scheduled=False):
|
||||
if nm(lb) not in G.nodes:
|
||||
# realized but unseen?
|
||||
G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{bm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
|
||||
def _tree(lazydata, prefix=""):
|
||||
if type(lazydata).__name__ == "LazyBuffer":
|
||||
return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
|
||||
|
||||
def _tree(lazydata, cycles, cnt, prefix=""):
|
||||
cnt[0] += 1
|
||||
if len(lazydata.src) == 0: return [f"━━ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
||||
if (lid := id(lazydata)) in cycles and cycles[lid][1] > (tcnt := getenv("TREE_CYCLE_CNT", 5)) and tcnt >= 0:
|
||||
return [f"━⬆︎ goto {cycles[id(lazydata)][0]}: {lazydata.op.name}"]
|
||||
cycles[lid] = (cnt[0], 1 if lid not in cycles else cycles[lid][1]+1)
|
||||
lines = [f"━┳ {prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"]
|
||||
childs = [_tree(c) for c in lazydata.src[:]]
|
||||
childs = [_tree(c, cycles, cnt) for c in lazydata.src[:]]
|
||||
for c in childs[:-1]: lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
|
||||
return lines + [" ┗"+childs[-1][0]] + [" "+l for l in childs[-1][1:]]
|
||||
|
||||
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))]))
|
||||
def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata, {}, [-1]))]))
|
||||
|
||||
def graph_uops(uops:List[UOp]):
|
||||
import networkx as nx
|
||||
|
||||
Reference in New Issue
Block a user