# mirror of https://github.com/tinygrad/tinygrad.git
# synced 2026-02-11 15:15:13 -05:00
# 84 lines, 3.8 KiB, Python
import os, atexit, itertools
|
|
try:
|
|
import networkx as nx # type: ignore
|
|
except ImportError:
|
|
nx = None # graph won't work
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Optional, TYPE_CHECKING
|
|
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, Op, OpType, LazyOp
|
|
from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
|
|
from tinygrad.runtime.lib import RawConst
|
|
|
|
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
|
|
|
|
# **** debugging and graphing ****
|
|
|
|
# Module-wide graph state: the networkx graph that log_op fills in (None when networkx failed to import),
# and a per-op-class counter that log_op increments once per logged op.
G = None if nx is None else nx.DiGraph()
cnts: Dict[OpType, int] = defaultdict(int)
|
|
if DEBUG >= 2:
  def print_globalcounters():
    """Print average throughput and totals accumulated in GlobalCounters.

    Registered as an atexit hook when DEBUG >= 2. Does nothing when no kernel time was recorded,
    which also avoids dividing by zero below.
    """
    gcs = GlobalCounters
    if gcs.time_sum_s == 0: return
    gops, gmem, secs = gcs.global_ops*1e-9, gcs.global_mem*1e-9, gcs.time_sum_s
    print(f"avg: {gops/secs:8.2f} GFLOPS {gmem/secs:8.2f} GB/s",
          f"{' '*10}total: {gcs.kernel_count:5d} kernels {gops:8.2f} GOPS {gmem:8.2f} GB {secs*1e3:8.2f} ms")
  atexit.register(print_globalcounters)
|
|
if GRAPH:
  def save_graph_exit():
    """At interpreter exit, save the op graph.

    Prints the per-op-class counts and prunes the graph when PRUNEGRAPH is set. Writes the graph to
    `{GRAPHPATH}.dot`, then asks graphviz's `dot` to render `{GRAPHPATH}.svg`.
    """
    for k,v in cnts.items(): print(k, v)
    if nx is None:
      # networkx could not be imported, so G is None and there is nothing to prune or write
      print("GRAPH is set but networkx is not installed; not saving graph")
      return
    if PRUNEGRAPH: prune_graph()
    print("saving", G)
    nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
    # -Gnslimit=100 can make it finish, but you won't like results
    import subprocess
    # Pass argv as a list, with no shell, so a GRAPHPATH containing spaces or shell metacharacters is handled safely.
    try:
      subprocess.run(['dot', '-Tsvg', f'{GRAPHPATH}.dot', '-o', f'{GRAPHPATH}.svg'], check=False)
    except FileNotFoundError:
      # graphviz is not installed: keep the .dot file and tell the user rather than raising inside an atexit hook
      print("graphviz 'dot' not found; only wrote", f'{GRAPHPATH}.dot')
  atexit.register(save_graph_exit)
|
|
|
|
node_count = 0
def nm(x):
  """Return a stable integer node id for x.

  On first use, x is tagged with the next free id as the attribute `node_id`. Later calls return that
  same id, so ids are assigned in first-seen order.
  """
  global node_count
  if hasattr(x, 'node_id'): return x.node_id
  x.node_id = node_count
  node_count += 1
  return x.node_id
|
|
|
|
def get_sop(op: List[Op]):
  """Build a short edge label for a list of ops.

  Each op contributes the second '.'-separated field of str(op), and the fields are joined with '.'
  in reverse order:
  - 1-2 ops: full names.
  - 3-4 ops: names truncated to 3 characters.
  - more than 4 ops: just the op count.
  """
  if len(op) > 4: return str(len(op))
  keep = None if len(op) <= 2 else 3  # a slice end of None keeps the whole name
  return '.'.join(str(o).split(".")[1][:keep] for o in reversed(op))
|
|
|
|
def str_dtype(dtyp):
  """Return the dtype suffix used in graph node labels.

  The suffix is empty when the dtype name is 'float'; otherwise it is a newline followed by the name.
  The name is str(dtyp) with its first 7 characters dropped. This presumably strips a "dtypes."
  prefix — TODO confirm against the DType string representation.
  """
  suffix = str(dtyp)[7:]
  return f"\n{suffix}" if suffix != 'float' else ""
|
|
|
|
def log_op(ret: 'LazyBuffer', ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
  """Record one lazy op for DEBUG printing and, optionally, in the networkx graph G.

  Args:
    ret: the buffer produced by the op; it becomes (or updates) a node in G.
    ast: the op tree for ret. Its ops come from ast.get_lazyops() and its inputs from ast.buffers.
    show_graph: whether to add the op to G. None means "follow the GRAPH setting".
    phantom: draw the edges and node in translucent colours.

  Returns early when neither DEBUG nor show_graph is on. Otherwise it increments cnts[optype], where
  optype is the class of whichever op comes first in `oporder` below.
  """
  if show_graph is None: show_graph = bool(GRAPH)
  if not DEBUG and not show_graph: return
  op: List[Op] = [x.op for x in ast.get_lazyops()]
  # input buffers already realized as RawConst are hidden unless GRAPH > 1
  inp: List['LazyBuffer'] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
  oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
  # min() returns the first op with the smallest oporder index, exactly like the stable sorted(...)[0], without sorting.
  # NOTE(review): assumes op is non-empty (an empty list raises ValueError here).
  optype = type(min(op, key=lambda x: oporder.index(type(x))))
  cnts[optype] += 1
  if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
  if not show_graph: return

  top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#ff8080"}
  # Draw ret dashed when it is a LoadOps result that has a `_backing` attribute, or when ret.st is not contiguous.
  dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore

  # These are the same for every input edge; compute them once.
  sop, edge_color = get_sop(op), ('#00000060' if phantom else 'black')
  for x in inp:
    # nm(x) is evaluated before nm(ret), preserving the original node-id assignment order
    G.add_edge(nm(x), nm(ret), label=sop, color=edge_color)
    if 'label' not in G.nodes[nm(x)]:
      # Label an input with its own shape and dtype. The original used ret.dtype here, which mislabels inputs of ops
      # whose output dtype differs.
      G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(x.dtype)
  if nm(ret) not in G.nodes: G.add_node(nm(ret))

  label = str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape)
  if optype in top_colors:
    # the two-hex-digit suffix is the fill alpha: more translucent for phantom ops, semi for dashed
    fillcolor = top_colors[optype] + ('60' if phantom else ('80' if dashed else ''))
  else:
    fillcolor = "#ffffff"
  G.nodes[nm(ret)].update(
    label=label+str_dtype(ret.dtype),
    fillcolor=fillcolor,
    color='white' if phantom else 'black',
    style='filled, dashed' if dashed else 'filled',
    prunable=optype in [LoadOps, MovementOps],  # consumed by prune_graph
  )
|
|
|
|
# prune movementops and loadops
def prune_graph():
  """Remove every node of G whose 'prunable' attribute is truthy, splicing around it.

  Each predecessor of a pruned node gets an edge to each of its successors. All bypass edges are
  added before any node is removed. New bypass edges are added without attributes; an edge that
  already exists keeps its attributes.
  """
  doomed = []
  for node in G.nodes:
    if not G.nodes[node].get('prunable'): continue
    preds = [src for src, _ in G.in_edges(node)]
    succs = [dst for _, dst in G.out_edges(node)]
    G.add_edges_from(itertools.product(preds, succs))
    doomed.append(node)
  G.remove_nodes_from(doomed)
|