mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz/cli: simplification and reordering (#15785)
* remove * work * this is all one thing * the reorder
This commit is contained in:
191
extra/viz/cli.py
191
extra/viz/cli.py
@@ -46,6 +46,8 @@ def decode_profile(data:bytes) -> dict:
|
||||
for k,rep,num,mode in [u("<IIIB") for _ in range(u("<I")[0])]]}})
|
||||
return {"dur":total_dur, "peak":global_peak, "layout":layout, "markers":markers}
|
||||
|
||||
def fmt_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s
|
||||
|
||||
def get(data:dict, key:str):
|
||||
for k,v in data.items():
|
||||
if ansistrip(k) == key: return v
|
||||
@@ -56,95 +58,102 @@ def get(data:dict, key:str):
|
||||
def main(args) -> None:
|
||||
viz.load_rewrites(viz_data:=viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))))
|
||||
|
||||
def format_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s
|
||||
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz_data.ctxs if c.get("steps")}
|
||||
def print_step(step:dict) -> None:
|
||||
data = viz.get_render(viz_data, step["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(m["uop"])
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
|
||||
for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if data.get("src") is not None: print(data["src"])
|
||||
|
||||
if args.profile:
|
||||
events:list = viz.load_pickle(args.profile_path, default=[])
|
||||
if (profile_bytes:=viz.get_profile(viz_data, events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
profile = decode_profile(profile_bytes)
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.src is None:
|
||||
print("Select a source with -s")
|
||||
for k in profile["layout"]: print(f" {format_colored(k)}")
|
||||
return None
|
||||
# ** Graph rewrites printer
|
||||
if args.rewrites:
|
||||
if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in rewrites]))
|
||||
steps = get(rewrites, args.src)
|
||||
if args.item is None:
|
||||
for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else ''))
|
||||
else: print_step(get(steps, args.item))
|
||||
return None
|
||||
|
||||
# ** SQTT printer
|
||||
data = get(profile["layout"], args.src)
|
||||
if "SQTT" in args.src:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}")
|
||||
print("-" * 100)
|
||||
pc_map:dict[int, str] = {}
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_inst:dict[str, tuple[str, int]] = {}
|
||||
inst_st:int|None = None
|
||||
for e in viz.sqtt_timeline(*data):
|
||||
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg
|
||||
if not isinstance(e, ProfileRangeEvent): continue
|
||||
if inst_st is None: inst_st = int(e.st)
|
||||
assert isinstance(e.name, TracingKey)
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((v for k,v in viz.wave_colors.items() if k in op_name), None)
|
||||
op_str = hex_colored(op_name, color) if color and not NO_COLOR else op_name
|
||||
phase, delay = None, 0
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if e.device.startswith("WAVE"):
|
||||
inst = f"0x{(pc:=int(info.replace('PC:', ''))):05x} {pc_map[pc]}" if info else f"{'':7} {op_name}"
|
||||
dispatch_to_inst[f"{e.device}-{idx}"] = (inst, int(e.st))
|
||||
phase = "DISPATCH"
|
||||
if info.startswith("LINK:"):
|
||||
inst, dispatch_st = dispatch_to_inst[info.replace("LINK:", "")]
|
||||
phase, delay = "EXEC", int(e.st) - dispatch_st
|
||||
if inst and phase: info = f"{phase:<8} {inst}"
|
||||
unit = e.device.replace(" ", "-")
|
||||
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
|
||||
return None
|
||||
events:list = viz.load_pickle(args.profile_path, default=[])
|
||||
if (profile_bytes:=viz.get_profile(viz_data, events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
|
||||
profile = decode_profile(profile_bytes)
|
||||
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs
|
||||
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
|
||||
if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]]))
|
||||
|
||||
# ** PMC printer
|
||||
if "PMC" in args.src:
|
||||
pmc = viz.unpack_pmc(data)
|
||||
cols = pmc["cols"]
|
||||
rows:list = []
|
||||
for r in pmc["rows"]:
|
||||
if args.item is None: rows.append(r[:2])
|
||||
elif args.item == r[0]:
|
||||
rows = r[2]["rows"] if len(r) > 2 else [r[:2]]
|
||||
cols = r[2]["cols"] if len(r) > 2 else cols
|
||||
data = [[x for x in cols], *[[str(x) for x in r] for r in rows]]
|
||||
widths = [max(len(r[i]) for r in data) for i in range(len(cols))]
|
||||
def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
|
||||
print(fmt(data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in data[1:]])))
|
||||
return None
|
||||
# ** SQTT printer
|
||||
data = get(profile["layout"], args.src)
|
||||
if "SQTT" in args.src:
|
||||
# modern terminals support 24-bit color
|
||||
def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m"
|
||||
print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}")
|
||||
print("-" * 100)
|
||||
pc_map:dict[int, str] = {}
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_inst:dict[str, tuple[str, int]] = {}
|
||||
inst_st:int|None = None
|
||||
for e in viz.sqtt_timeline(*data):
|
||||
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg
|
||||
if not isinstance(e, ProfileRangeEvent): continue
|
||||
if inst_st is None: inst_st = int(e.st)
|
||||
assert isinstance(e.name, TracingKey)
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((v for k,v in viz.wave_colors.items() if k in op_name), None)
|
||||
op_str = hex_colored(op_name, color) if color and not NO_COLOR else op_name
|
||||
phase, delay = None, 0
|
||||
idx = next(pkt_idxs.setdefault(e.device, itertools.count()))
|
||||
if e.device.startswith("WAVE"):
|
||||
inst = f"0x{(pc:=int(info.replace('PC:', ''))):05x} {pc_map[pc]}" if info else f"{'':7} {op_name}"
|
||||
dispatch_to_inst[f"{e.device}-{idx}"] = (inst, int(e.st))
|
||||
phase = "DISPATCH"
|
||||
if info.startswith("LINK:"):
|
||||
inst, dispatch_st = dispatch_to_inst[info.replace("LINK:", "")]
|
||||
phase, delay = "EXEC", int(e.st) - dispatch_st
|
||||
if inst and phase: info = f"{phase:<8} {inst}"
|
||||
unit = e.device.replace(" ", "-")
|
||||
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
|
||||
|
||||
# ** Memory printer
|
||||
if data["event_type"] == 1 and data.get("events", []):
|
||||
print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")
|
||||
modes = ("read","write","write+read")
|
||||
for e in data["events"]:
|
||||
info = str(e.get("arg", {}))
|
||||
if e["event"] == "free":
|
||||
info = ', '.join([f"{format_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in e["arg"]["users"]])
|
||||
print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")
|
||||
return None
|
||||
# ** PMC printer
|
||||
elif "PMC" in args.src:
|
||||
pmc = viz.unpack_pmc(data)
|
||||
cols = pmc["cols"]
|
||||
rows:list = []
|
||||
for r in pmc["rows"]:
|
||||
if args.item is None: rows.append(r[:2])
|
||||
elif args.item == r[0]:
|
||||
rows = r[2]["rows"] if len(r) > 2 else [r[:2]]
|
||||
cols = r[2]["cols"] if len(r) > 2 else cols
|
||||
data = [[x for x in cols], *[[str(x) for x in r] for r in rows]]
|
||||
widths = [max(len(r[i]) for r in data) for i in range(len(cols))]
|
||||
def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
|
||||
print(fmt(data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in data[1:]])))
|
||||
|
||||
# ** Profiler printer
|
||||
# ** Memory printer
|
||||
elif data["event_type"] == 1:
|
||||
print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")
|
||||
for e in data["events"]:
|
||||
info = str(e.get("arg", {}))
|
||||
if e["event"] == "free":
|
||||
info = ', '.join([f"{fmt_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in e["arg"]["users"]])
|
||||
print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")
|
||||
|
||||
# ** Profiler printer
|
||||
else:
|
||||
agg:dict[str, tuple[float, int, int|None]] = {}
|
||||
total, first = 0, True
|
||||
def print_kernel(ref:int) -> None:
|
||||
if DEBUG >= 3: print(viz._reconstruct(viz_data, viz_data.trace.rewrites[ref][0].sink).pyrender())
|
||||
if DEBUG >= 4: print(viz_data.ctxs[ref]["prg"].src[3].arg)
|
||||
total = 0
|
||||
for e in data.get("events", []):
|
||||
et = e["dur"] * 1e-6
|
||||
# TODO: this shouldn't exist, replace with the DEBUG reconstructor
|
||||
if args.item is not None:
|
||||
if ansistrip(e["name"]) == args.item:
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None)
|
||||
name = e["name"] + (" " * (46 - ansilen(e["name"])))
|
||||
print(f"{format_colored(name)} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ")
|
||||
if first:
|
||||
if e["ref"] is not None: print_kernel(e["ref"])
|
||||
first = False
|
||||
print(f"{fmt_colored(name)} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ")
|
||||
else:
|
||||
t, c, ref = agg.get(e["name"], (0.0, 0, None))
|
||||
agg[e["name"]] = (t+et, c+1, e["ref"])
|
||||
@@ -153,33 +162,15 @@ def main(args) -> None:
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
num_rows = args.top
|
||||
for name,(t,c,ref) in items[:num_rows]:
|
||||
print(f"{format_colored(name)}{' ' * max(0, 36 - ansilen(name))} {time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%")
|
||||
if ref is not None: print_kernel(ref)
|
||||
print(f"{fmt_colored(name)}{' ' * max(0, 36 - ansilen(name))} {time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%")
|
||||
if ref is not None:
|
||||
steps = rewrites[viz_data.ctxs[ref]["name"]]
|
||||
if DEBUG >= 3: print_step(get(steps, "View Base AST"))
|
||||
if DEBUG >= 4: print_step(get(steps, "View Source"))
|
||||
if num_rows > 0 and items[num_rows:]:
|
||||
other_t = sum(t for _,(t,_,_) in items[num_rows:])
|
||||
other_c = sum(c for _,(_,c,_) in items[num_rows:])
|
||||
print(f"{'Other':<36} {time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%")
|
||||
return None
|
||||
|
||||
# ** Graph rewrites printer
|
||||
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz_data.ctxs if c.get("steps")}
|
||||
if args.src is None:
|
||||
for k in rewrites: print(f" {format_colored(k)}")
|
||||
return None
|
||||
steps = get(rewrites, args.src)
|
||||
if args.item is None:
|
||||
for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else ''))
|
||||
else:
|
||||
data = viz.get_render(viz_data, get(steps, args.item)["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(f"Input UOp:\n{m['uop']}")
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
|
||||
for line in m["diff"]:
|
||||
print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if data.get("src") is not None: print(data["src"])
|
||||
|
||||
def get_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
|
||||
Reference in New Issue
Block a user