mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz/cli: ux cleanups, show user python (#15817)
* small fixes * print python trace * jsonl * cleanup fmt, fix tqdm * print mode * types * less * keep those * fix * everyone can print json * pmc p2
This commit is contained in:
@@ -12,7 +12,7 @@ By default, VIZ CLI automatically loads the latest trace files.
|
||||
|
||||
Use `extra/viz/cli.py --profile -s ALL` to inspect the complete timing data of kernels, JIT, codegen and scheduling.
|
||||
|
||||
- Add DEBUG=3 to see AST, DEGUG=4 to also see source code.
|
||||
- Add DEBUG=3 to see AST, DEBUG=4 to also see source code.
|
||||
- Make sure to add NO_COLOR=1 to disable colored output.
|
||||
- Add --jsonl to see JSON output
|
||||
|
||||
@@ -22,12 +22,6 @@ DEBUG=3 extra/viz/cli.py --profile -s ALL > asts.txt
|
||||
|
||||
# Get kernel timing information in JSONL format
|
||||
extra/viz/cli.py --profile -s ALL --jsonl
|
||||
|
||||
# View top 40 slowest kernels on the AMD device and their AST (DEBUG=4 to see source code)
|
||||
DEBUG=3 extra/viz/cli.py --profile -s AMD --top 40
|
||||
|
||||
# List top 10 slowest operations across all devices
|
||||
extra/viz/cli.py --profile --top 10 -s ALL
|
||||
```
|
||||
|
||||
## Inspect codegen and PatternMatcher
|
||||
|
||||
@@ -58,17 +58,19 @@ def get(data:dict, key:str):
|
||||
def main(args) -> None:
|
||||
viz.load_rewrites(viz_data:=viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))))
|
||||
|
||||
def fmt(val, to_str=str) -> str: return json.dumps(val if isinstance(val, dict) else {"value":val}) if args.jsonl else to_str(val)
|
||||
|
||||
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz_data.ctxs if c.get("steps")}
|
||||
def print_step(step:dict) -> None:
|
||||
data = viz.get_render(viz_data, step["query"])
|
||||
if isinstance(data.get("value"), Iterator):
|
||||
for m in data["value"]:
|
||||
if m.get("uop"): print(json.dumps({"ast":m["uop"]}) if args.jsonl else m["uop"])
|
||||
if m.get("uop"): print(fmt(m["uop"]))
|
||||
if m.get("diff"):
|
||||
loc = pathlib.Path(m["upat"][0][0])
|
||||
print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}")
|
||||
for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
|
||||
if data.get("src") is not None: print(json.dumps({"src":data["src"]}) if args.jsonl else data["src"])
|
||||
print(fmt(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}"))
|
||||
for line in m["diff"]: print(fmt(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)))
|
||||
if data.get("src") is not None: print(fmt(data["src"]))
|
||||
|
||||
# ** Graph rewrites printer
|
||||
if args.rewrites:
|
||||
@@ -116,7 +118,8 @@ def main(args) -> None:
|
||||
phase, delay = "EXEC", int(e.st) - dispatch_st
|
||||
if inst and phase: info = f"{phase:<8} {inst}"
|
||||
unit = e.device.replace(" ", "-")
|
||||
print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}")
|
||||
row = {"clk":int(e.st)-inst_st, "unit":unit, "op":op_name, "dur":int(unwrap(e.en)-e.st), "delay":delay or "", "info":info}
|
||||
print(fmt(row, lambda _: f"{row['clk']:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {row['dur']:<4} {str(row['delay']):<4} {info}"))
|
||||
|
||||
# ** PMC printer
|
||||
elif "PMC" in args.src:
|
||||
@@ -130,17 +133,19 @@ def main(args) -> None:
|
||||
cols = r[2]["cols"] if len(r) > 2 else cols
|
||||
pmc_data = [[x for x in cols], *[[str(x) for x in r] for r in rows]]
|
||||
widths = [max(len(r[i]) for r in pmc_data) for i in range(len(cols))]
|
||||
def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
|
||||
print(fmt(pmc_data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in pmc_data[1:]])))
|
||||
def pad(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |"
|
||||
table_str = pad(pmc_data[0])+"\n"+pad(["-"*w for w in widths])+"\n"+("\n".join([pad(row) for row in pmc_data[1:]]))
|
||||
print(fmt({"cols":cols, "rows":rows}, lambda _: table_str))
|
||||
|
||||
# ** Memory printer
|
||||
elif data is not None and data["event_type"] == 1:
|
||||
print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info")
|
||||
print(fmt({"peak":data["peak"], "cols":["ts", "event", "key", "info"]},
|
||||
lambda _: f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info"))
|
||||
for e in data["events"]:
|
||||
info = str(e.get("arg", {}))
|
||||
info = str(arg:=e.pop("arg", {}))
|
||||
if e["event"] == "free":
|
||||
info = ', '.join([f"{fmt_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in e["arg"]["users"]])
|
||||
print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}")
|
||||
info = ', '.join([f"{fmt_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in arg["users"]])
|
||||
print(fmt({**e, "info":info}, lambda _: f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}"))
|
||||
|
||||
# ** Profiler printer
|
||||
else:
|
||||
@@ -157,7 +162,7 @@ def main(args) -> None:
|
||||
items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True)
|
||||
num_rows = len(items) if args.top < 0 else args.top
|
||||
for (dev,name),(t,c,ref) in items[:num_rows]:
|
||||
display = f"{dev[:7]:7s} {fmt_colored(name)}" if args.src == "ALL" else name
|
||||
display = f"{dev[:7]:7s} {fmt_colored(name)}" if args.src == "ALL" else fmt_colored(name)
|
||||
yield {"name":display, "dur_ms":t, "count":c, "pct":t/total*100.0, "ref":ref}
|
||||
if num_rows > 0 and items[num_rows:]:
|
||||
other_t = sum(t for _,(t,_,_) in items[num_rows:])
|
||||
@@ -171,9 +176,16 @@ def main(args) -> None:
|
||||
if dev == "MARKER":
|
||||
yield {"device":dev, "name":fmt_colored(e["name"]), "et_ms":ts*1e-3, "ref":None, "ext":None}
|
||||
continue
|
||||
if e["fmt"].startswith("TB:"): e["fmt"] = "" # TODO: print python backtrace at a reasonable DEBUG level
|
||||
ext:list[str] = []
|
||||
if (fmt:=e["fmt"]).startswith("TB:"):
|
||||
tb, fmt = json.loads(e["fmt"].replace("TB:", "")), ""
|
||||
while tb:
|
||||
file, lineno, fxn, code = tb.pop()
|
||||
line = f"{file.split('/')[-1]}:{lineno} {fxn}"
|
||||
if fmt: ext.append(f"{line} {code}")
|
||||
elif not file.startswith("<") and not fxn.startswith("<"): fmt = line
|
||||
yield {"device":dev, "name":fmt_colored(e["name"]), "dur_ms":e["dur"]*1e-3,
|
||||
"et_ms":(e["st"]+e["dur"])*1e-3, "fmt":e["fmt"], "ref":e["ref"], "ext":None}
|
||||
"et_ms":(e["st"]+e["dur"])*1e-3, "fmt":fmt, "ref":e["ref"], "ext":"\n".join(ext)}
|
||||
def fmt_top(k:dict) -> str:
|
||||
return f"{fmt_colored(k['name'])}{' ' * max(0, 36-ansilen(k['name']))} {time_to_str(k['dur_ms']*1e-3, w=9)} {k['count']:7d} {k['pct']:6.2f}%"
|
||||
def fmt_all(k:dict) -> str:
|
||||
@@ -184,12 +196,12 @@ def main(args) -> None:
|
||||
return f"{name} tm {ptm}/{k['et_ms']:9.2f}ms"+(f" ({fmt_str})" if k["fmt"] else "")
|
||||
fmt_row = fmt_top if args.top else fmt_all
|
||||
for k in (produce_top_kernels if args.top else produce_all_kernels)():
|
||||
if args.jsonl: print(json.dumps(k))
|
||||
else: print(fmt_row(k))
|
||||
print(fmt(k, to_str=fmt_row))
|
||||
if k["ref"] is not None:
|
||||
steps = rewrites[viz_data.ctxs[k["ref"]]["name"]]
|
||||
if DEBUG >= 3 and (ast_step:=steps.get("View Base AST")) is not None: print_step(ast_step)
|
||||
if DEBUG >= 4 and (src_step:=steps.get("View Source")) is not None: print_step(src_step)
|
||||
elif DEBUG >= 3 and k.get("ext"): print(fmt(k["ext"]))
|
||||
|
||||
def get_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
|
||||
@@ -170,7 +170,7 @@ def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRe
|
||||
next_sink = _reconstruct(data, ctx.sink)
|
||||
yield {"graph":uop_to_json(data, next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
||||
replaces: dict[UOp, UOp] = {}
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
||||
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches, disable=not ctx.matches):
|
||||
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
|
||||
try: new_sink = next_sink.substitute(replaces)
|
||||
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
||||
|
||||
Reference in New Issue
Block a user