From 7bdb3adbbf1f8c5cbeca4ac484f70913a762e8a2 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 17 Apr 2026 15:16:07 +0300 Subject: [PATCH] viz/cli: simplification and reordering (#15785) * remove * work * this is all one thing * the reorder --- extra/viz/cli.py | 191 ++++++++++++++++++++++------------------------- 1 file changed, 91 insertions(+), 100 deletions(-) diff --git a/extra/viz/cli.py b/extra/viz/cli.py index 16838eedaf..94718418cc 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -46,6 +46,8 @@ def decode_profile(data:bytes) -> dict: for k,rep,num,mode in [u(" str: return ansistrip(s) if NO_COLOR else s + def get(data:dict, key:str): for k,v in data.items(): if ansistrip(k) == key: return v @@ -56,95 +58,102 @@ def get(data:dict, key:str): def main(args) -> None: viz.load_rewrites(viz_data:=viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))) - def format_colored(s:str) -> str: return ansistrip(s) if NO_COLOR else s + rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz_data.ctxs if c.get("steps")} + def print_step(step:dict) -> None: + data = viz.get_render(viz_data, step["query"]) + if isinstance(data.get("value"), Iterator): + for m in data["value"]: + if m.get("uop"): print(m["uop"]) + if m.get("diff"): + loc = pathlib.Path(m["upat"][0][0]) + print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}") + for line in m["diff"]: print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) + if data.get("src") is not None: print(data["src"]) - if args.profile: - events:list = viz.load_pickle(args.profile_path, default=[]) - if (profile_bytes:=viz.get_profile(viz_data, events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") - profile = decode_profile(profile_bytes) - profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs - if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))]) - if args.src is None: - print("Select a source with -s") - for k in profile["layout"]: print(f" {format_colored(k)}") - return None + # ** Graph rewrites printer + if args.rewrites: + if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in rewrites])) + steps = get(rewrites, args.src) + if args.item is None: + for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else '')) + else: print_step(get(steps, args.item)) + return None - # ** SQTT printer - data = get(profile["layout"], args.src) - if "SQTT" in args.src: - # modern terminals support 24-bit color - def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m" - print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}") - print("-" * 100) - pc_map:dict[int, str] = {} - pkt_idxs:dict[str, itertools.count] = {} - dispatch_to_inst:dict[str, tuple[str, int]] = {} - inst_st:int|None = None - for e in viz.sqtt_timeline(*data): - if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg - if not isinstance(e, ProfileRangeEvent): continue - if inst_st is None: inst_st = int(e.st) - assert isinstance(e.name, TracingKey) - op_name, info = e.name.display_name, e.name.ret or "" - color = next((v for k,v in viz.wave_colors.items() if k in op_name), None) - op_str = hex_colored(op_name, color) if color and not NO_COLOR else op_name - phase, delay = None, 0 - idx = next(pkt_idxs.setdefault(e.device, itertools.count())) - if e.device.startswith("WAVE"): - inst = f"0x{(pc:=int(info.replace('PC:', ''))):05x} {pc_map[pc]}" if info else f"{'':7} {op_name}" - dispatch_to_inst[f"{e.device}-{idx}"] = (inst, int(e.st)) - phase = "DISPATCH" - if info.startswith("LINK:"): - inst, dispatch_st = dispatch_to_inst[info.replace("LINK:", "")] - phase, delay = "EXEC", int(e.st) - dispatch_st - if inst and phase: info = f"{phase:<8} {inst}" - unit = e.device.replace(" ", "-") - print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}") - return None + events:list = viz.load_pickle(args.profile_path, default=[]) + if (profile_bytes:=viz.get_profile(viz_data, events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") + profile = decode_profile(profile_bytes) + profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz_data.ctxs + if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))]) + if args.src is None: return print("Select a source with -s"+"\n"+"\n".join([f" {fmt_colored(k)}" for k in profile["layout"]])) - # ** PMC printer - if "PMC" in args.src: - pmc = viz.unpack_pmc(data) - cols = pmc["cols"] - rows:list = [] - for r in pmc["rows"]: - if args.item is None: rows.append(r[:2]) - elif args.item == r[0]: - rows = r[2]["rows"] if len(r) > 2 else [r[:2]] - cols = r[2]["cols"] if len(r) > 2 else cols - data = [[x for x in cols], *[[str(x) for x in r] for r in rows]] - widths = [max(len(r[i]) for r in data) for i in range(len(cols))] - def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |" - print(fmt(data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in data[1:]]))) - return None + # ** SQTT printer + data = get(profile["layout"], args.src) + if "SQTT" in args.src: + # modern terminals support 24-bit color + def hex_colored(st:str, color:str) -> str: return f"\x1b[38;2;{int(color[1:3],16)};{int(color[3:5],16)};{int(color[5:7],16)}m{st}\x1b[0m" + print(f"{'Clk':<12} {'Unit':<20} {'Op':<22} {'Dur':<4} {'Delay':<4} {'Info'}") + print("-" * 100) + pc_map:dict[int, str] = {} + pkt_idxs:dict[str, itertools.count] = {} + dispatch_to_inst:dict[str, tuple[str, int]] = {} + inst_st:int|None = None + for e in viz.sqtt_timeline(*data): + if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg + if not isinstance(e, ProfileRangeEvent): continue + if inst_st is None: inst_st = int(e.st) + assert isinstance(e.name, TracingKey) + op_name, info = e.name.display_name, e.name.ret or "" + color = next((v for k,v in viz.wave_colors.items() if k in op_name), None) + op_str = hex_colored(op_name, color) if color and not NO_COLOR else op_name + phase, delay = None, 0 + idx = next(pkt_idxs.setdefault(e.device, itertools.count())) + if e.device.startswith("WAVE"): + inst = f"0x{(pc:=int(info.replace('PC:', ''))):05x} {pc_map[pc]}" if info else f"{'':7} {op_name}" + dispatch_to_inst[f"{e.device}-{idx}"] = (inst, int(e.st)) + phase = "DISPATCH" + if info.startswith("LINK:"): + inst, dispatch_st = dispatch_to_inst[info.replace("LINK:", "")] + phase, delay = "EXEC", int(e.st) - dispatch_st + if inst and phase: info = f"{phase:<8} {inst}" + unit = e.device.replace(" ", "-") + print(f"{int(e.st)-inst_st:<12} {unit:<20} {op_str}{' '*(22-ansilen(op_str))} {int(unwrap(e.en)-e.st):<4} {str(delay or ''):<4} {info}") - # ** Memory printer - if data["event_type"] == 1 and data.get("events", []): - print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info") - modes = ("read","write","write+read") - for e in data["events"]: - info = str(e.get("arg", {})) - if e["event"] == "free": - info = ', '.join([f"{format_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in e["arg"]["users"]]) - print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}") - return None + # ** PMC printer + elif "PMC" in args.src: + pmc = viz.unpack_pmc(data) + cols = pmc["cols"] + rows:list = [] + for r in pmc["rows"]: + if args.item is None: rows.append(r[:2]) + elif args.item == r[0]: + rows = r[2]["rows"] if len(r) > 2 else [r[:2]] + cols = r[2]["cols"] if len(r) > 2 else cols + data = [[x for x in cols], *[[str(x) for x in r] for r in rows]] + widths = [max(len(r[i]) for r in data) for i in range(len(cols))] + def fmt(r): return "| "+" | ".join(x+" "*(w-len(x)) for x,w in zip(r, widths))+" |" + print(fmt(data[0])+"\n"+fmt(["-"*w for w in widths])+"\n"+("\n".join([fmt(row) for row in data[1:]]))) - # ** Profiler printer + # ** Memory printer + elif data["event_type"] == 1: + print(f"Peak: {data['peak']}"+"\n"+f"{'TS':<10} {'Event':<6} {'Key':>8} Info") + for e in data["events"]: + info = str(e.get("arg", {})) + if e["event"] == "free": + info = ', '.join([f"{fmt_colored(kernel)} {['read','write','write+read'][mode]}@data{num}" for _,kernel,num,mode in e["arg"]["users"]]) + print(f"{e['ts']:<10} {e['event']:<6} {e.get('key', ''):>8} {info}") + + # ** Profiler printer + else: agg:dict[str, tuple[float, int, int|None]] = {} - total, first = 0, True - def print_kernel(ref:int) -> None: - if DEBUG >= 3: print(viz._reconstruct(viz_data, viz_data.trace.rewrites[ref][0].sink).pyrender()) - if DEBUG >= 4: print(viz_data.ctxs[ref]["prg"].src[3].arg) + total = 0 for e in data.get("events", []): et = e["dur"] * 1e-6 + # TODO: this shouldn't exist, replace with the DEBUG reconstructor if args.item is not None: if ansistrip(e["name"]) == args.item: ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) name = e["name"] + (" " * (46 - ansilen(e["name"]))) - print(f"{format_colored(name)} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ") - if first: - if e["ref"] is not None: print_kernel(e["ref"]) - first = False + print(f"{fmt_colored(name)} {ptm}/{et*1e3:9.2f}ms " + e.get("fmt", "").replace("\n", " | ") + " ") else: t, c, ref = agg.get(e["name"], (0.0, 0, None)) agg[e["name"]] = (t+et, c+1, e["ref"]) @@ -153,33 +162,15 @@ def main(args) -> None: items = sorted(agg.items(), key=lambda kv:kv[1][0], reverse=True) num_rows = args.top for name,(t,c,ref) in items[:num_rows]: - print(f"{format_colored(name)}{' ' * max(0, 36 - ansilen(name))} {time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%") - if ref is not None: print_kernel(ref) + print(f"{fmt_colored(name)}{' ' * max(0, 36 - ansilen(name))} {time_to_str(t, w=9)} {c:7d} {t/total*100.0:6.2f}%") + if ref is not None: + steps = rewrites[viz_data.ctxs[ref]["name"]] + if DEBUG >= 3: print_step(get(steps, "View Base AST")) + if DEBUG >= 4: print_step(get(steps, "View Source")) if num_rows > 0 and items[num_rows:]: other_t = sum(t for _,(t,_,_) in items[num_rows:]) other_c = sum(c for _,(_,c,_) in items[num_rows:]) print(f"{'Other':<36} {time_to_str(other_t, w=9)} {other_c:7d} {other_t/total*100.0:6.2f}%") - return None - - # ** Graph rewrites printer - rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz_data.ctxs if c.get("steps")} - if args.src is None: - for k in rewrites: print(f" {format_colored(k)}") - return None - steps = get(rewrites, args.src) - if args.item is None: - for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else '')) - else: - data = viz.get_render(viz_data, get(steps, args.item)["query"]) - if isinstance(data.get("value"), Iterator): - for m in data["value"]: - if m.get("uop"): print(f"Input UOp:\n{m['uop']}") - if m.get("diff"): - loc = pathlib.Path(m["upat"][0][0]) - print(f"Rewrite at {loc.parent.name}/{loc.name}:{m['upat'][0][1]}\n{m['upat'][1]}") - for line in m["diff"]: - print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None)) - if data.get("src") is not None: print(data["src"]) def get_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(add_help=False)