From ac027055ef2cf442f8ce55cc2ddb2f433f742bce Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:35:20 +0300 Subject: [PATCH] viz: no global state (#15705) * start viz data * get_full_rewrites also moves * update ref_map * work * update consumers * cleaner cli * linter * cleanup tests * back * better * sqtt tests --- extra/sqtt/roc.py | 13 +-- extra/viz/cli.py | 12 +-- test/amd/test_sqtt_profiler.py | 9 ++- test/null/test_viz.py | 34 ++++---- tinygrad/viz/serve.py | 142 ++++++++++++++++++--------------- 5 files changed, 111 insertions(+), 99 deletions(-) diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 8771d09a42..21a2d676d6 100755 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -136,7 +136,8 @@ def print_data(data:dict) -> None: def main() -> None: import tinygrad.viz.serve as viz - viz.ctxs = [] + from tinygrad.uop.ops import RewriteTrace + data = viz.VizData() parser = argparse.ArgumentParser() parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)', @@ -147,24 +148,24 @@ def main() -> None: with args.profile.open("rb") as f: profile = pickle.load(f) - viz.get_profile(profile) + viz.get_profile(profile, data=data) # List all kernels if args.kernel is None: - for c in viz.ctxs: + for c in data.ctxs: print(c["name"]) for s in c["steps"]: print(" "+s["name"]) return None # Find kernel trace - trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None) + trace = next((c for c in data.ctxs if c["name"] == f"SQTT {args.kernel}"), None) if not trace: raise RuntimeError(f"no matching trace for {args.kernel}") n = 0 for s in trace["steps"]: if "PKTS" in s["name"]: continue print(s["name"]) - data = viz.get_render(s["query"]) - print_data(data) + ret = viz.get_render(data, s["query"]) + print_data(ret) n += 1 if n > args.n: break diff --git a/extra/viz/cli.py b/extra/viz/cli.py index d1b8337c06..4b27b3f85d 100755 --- a/extra/viz/cli.py +++ b/extra/viz/cli.py @@ -52,16 +52,16 @@ def get(data:dict, key:str): raise RuntimeError(f'item "{key}" not found in list'+(f", did you mean {match[0]!r}?" if match else '')) def main(args) -> None: - viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})) - viz.ctxs = viz.get_rewrites(viz.trace) + viz.data = viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))) + viz.load_rewrites(viz.data) def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s if args.profile: events:list = viz.load_pickle(args.profile_path, default=[]) - if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") + if (profile_bytes:=viz.get_profile(events, data=viz.data)) is None: raise RuntimeError(f"empty profile in {args.profile_path}") profile = decode_profile(profile_bytes) - profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.ctxs + profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.data.ctxs if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))]) if args.src is None: for k in profile["layout"]: @@ -142,7 +142,7 @@ def main(args) -> None: return None # ** Graph rewrites printer - rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.ctxs if c.get("steps")} + rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.data.ctxs if c.get("steps")} if args.src is None: for k in rewrites: print(f" {format_colored(k)}") return None @@ -150,7 +150,7 @@ def main(args) -> None: if args.item is None: for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else '')) else: - data = viz.get_render(get(steps, args.item)["query"]) + data = viz.get_render(viz.data, get(steps, args.item)["query"]) if isinstance(data.get("value"), Iterator): for m in data["value"]: if m.get("uop"): print(f"Input UOp:\n{m['uop']}") diff --git a/test/amd/test_sqtt_profiler.py b/test/amd/test_sqtt_profiler.py index 65bc97edc4..5f8334b89a 100644 --- a/test/amd/test_sqtt_profiler.py +++ b/test/amd/test_sqtt_profiler.py @@ -1,15 +1,16 @@ import unittest, contextlib from tinygrad import Device, Tensor, Context, TinyJit from tinygrad.device import Compiled, ProfileProgramEvent, ProfileDeviceEvent -from tinygrad.viz.serve import load_amd_counters +from tinygrad.viz.serve import load_amd_counters, VizData @contextlib.contextmanager def save_sqtt(): - yield (ret:=[]) + data = VizData() + yield data.ctxs Device[Device.DEFAULT].synchronize() Device[Device.DEFAULT]._at_profile_finalize() - load_amd_counters(ret, Compiled.profile_events) - ret[:] = [r for r in ret if r["name"].startswith("SQTT")] + load_amd_counters(data, Compiled.profile_events) + data.ctxs[:] = [r for r in data.ctxs if r["name"].startswith("SQTT")] @unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD") class TestSQTTProfiler(unittest.TestCase): diff --git a/test/null/test_viz.py b/test/null/test_viz.py index 9a1fd425b4..37acde3e07 100644 --- a/test/null/test_viz.py +++ b/test/null/test_viz.py @@ -10,7 +10,7 @@ from tinygrad.helpers import VIZ, cpu_profile, ProfilePointEvent, unwrap from tinygrad.device import Buffer from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace -from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json +from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData @track_rewrites(name=True) def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp: @@ -21,19 +21,19 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non # small container class for the viz server module class VizTrace: # loader init - def __init__(self): self._trace:RewriteTrace|None = None + def __init__(self): self._data:VizData|None = None @property - def trace(self) -> RewriteTrace: return unwrap(self._trace) - def set_trace(self) -> None: - self._trace = RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy()) - import tinygrad.viz.serve as serve_module - serve_module.trace = self._trace + def data(self) -> VizData: return unwrap(self._data) + def set_data(self) -> None: + data = VizData(RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy())) + load_rewrites(data) + self._data = data # the API - def list_items(self) -> list[dict]: return get_rewrites(self.trace) + def list_items(self) -> list[dict]: + return self.data.ctxs def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]: - lst = self.list_items() - assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx}" - return get_full_rewrite(self.trace.rewrites[rewrite_idx][step]) + assert len(self.data.trace.rewrites) > rewrite_idx, f"only loaded {len(self.data.trace.rewrites)} traces, expecting at least {rewrite_idx}" + return get_full_rewrite(self.data, self.data.trace.rewrites[rewrite_idx][step]) @contextlib.contextmanager def save_viz(): @@ -52,7 +52,7 @@ def save_viz(): try: yield viz finally: - viz.set_trace() + viz.set_data() TRACK_MATCH_STATS.value = prev_tms PROFILE.value = prev_profile VIZ.value = prev_viz @@ -194,7 +194,7 @@ class TestViz(unittest.TestCase): class TestStruct: colored_field: str a = UOp(Ops.CUSTOM, arg=TestStruct(colored("xyz", "magenta")+colored("12345", "blue"))) - a2 = uop_to_json(a)[id(a)] + a2 = uop_to_json(a, VizData())[id(a)] self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')") def test_colored_label_multiline(self): @@ -217,11 +217,11 @@ class TestViz(unittest.TestCase): # use smaller stack limit for faster test (default is 250000) with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm]) graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0)) - self.assertEqual(graphs[0], uop_to_json(a)[id(a)]) - self.assertEqual(graphs[1], uop_to_json(b)[id(b)]) + self.assertEqual(graphs[0], uop_to_json(a, VizData())[id(a)]) + self.assertEqual(graphs[1], uop_to_json(b, VizData())[id(b)]) # fallback to NOOP with the error message nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite") - self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)]) + self.assertEqual(graphs[2], uop_to_json(nop, VizData())[id(nop)]) def test_const_node_visibility(self): with save_viz() as viz: @@ -241,7 +241,7 @@ class TestViz(unittest.TestCase): c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0)) alu = a + c - graph = uop_to_json(alu) + graph = uop_to_json(alu, VizData()) # the RESHAPE and EXPAND nodes from the const should not appear in the graph labels = {v["label"].split("\n")[0] for v in graph.values()} self.assertNotIn("RESHAPE", labels) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 3142d173b4..f36ff07be1 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 -import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct, re +import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, codecs, io, struct, re import pathlib, traceback, itertools, socketserver from contextlib import redirect_stdout, redirect_stderr, contextmanager from decimal import Decimal +from dataclasses import dataclass, field from urllib.parse import parse_qs, urlparse from http.server import BaseHTTPRequestHandler from typing import Any, TypedDict, TypeVar, Generator, Callable @@ -61,21 +62,27 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", def create_step(name:str, query:tuple[str, int, int], data=None, depth:int=0, **kwargs) -> dict: return {"name":name, "query":f"{query[0]}?ctx={query[1]}&step={query[2]}", "data":data, "depth":depth, **kwargs} -# ** list all saved rewrites +@dataclass(frozen=True) +class VizData: + trace:RewriteTrace = field(default_factory=lambda: RewriteTrace([], [], {})) + ctxs:list[dict] = field(default_factory=list) + ref_map:dict[Any, int] = field(default_factory=dict) + all_uops:dict[int, UOp] = field(default_factory=dict) -ref_map:dict[Any, int] = {} -def get_rewrites(t:RewriteTrace) -> list[dict]: - ret = [] - for i,(k,v) in enumerate(zip(t.keys, t.rewrites)): +# ** load all saved rewrites + +def load_rewrites(data:VizData) -> None: + assert not data.ctxs and not data.ref_map, "load_rewrites called multiple times" + for i,k in enumerate(data.trace.keys): + v = data.trace.rewrites[i] steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc), trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)] - if (p:=get_prg_uop(i)) is not None: + if (p:=get_prg_uop(data, i)) is not None: steps.append(create_step("View UOp List", ("/uops", i, len(steps)))) steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg)) steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg))) - for key in k.keys: ref_map[key] = i - ret.append({"name":k.display_name, "steps":steps}) - return ret + for key in k.keys: data.ref_map[key] = i + data.ctxs.append({"name":k.display_name, "steps":steps}) # ** get the complete UOp graphs for one rewrite @@ -93,9 +100,7 @@ def pystr(u:UOp) -> str: try: return pyrender(u) except Exception: return str(u) -# all the trace points, initialized after the trace loads -ctxs:list[dict] = [] -def uop_to_json(x:UOp) -> dict[int, dict]: +def uop_to_json(x:UOp, data:VizData) -> dict[int, dict]: assert isinstance(x, UOp) graph: dict[int, dict] = {} excluded: set[UOp] = set() @@ -138,7 +143,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]: label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs]) except Exception: label += "\n" - if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None and ctxs: label += f"\ncodegen@{ctxs[ref]['name']}" + if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}" # NOTE: kernel already has metadata in arg if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata) # limit SOURCE labels line count @@ -148,28 +153,30 @@ def uop_to_json(x:UOp) -> dict[int, dict]: "ref":ref, "tag":repr(u.tag) if u.tag is not None else None} return graph -@functools.cache -def _reconstruct(a:int, depth:int|None=None): - op, dtype, src, arg, *rest = trace.uop_fields[a] +def _reconstruct(data:VizData, a:int, depth:int|None=None): + if depth is None and a in data.all_uops: return data.all_uops[a] + op, dtype, src, arg, *rest = data.trace.uop_fields[a] if depth is not None and depth <= 0: return UOp(op, dtype, (), arg, *rest) - return UOp(op, dtype, tuple(_reconstruct(s, None if depth is None else depth-1) for s in src), arg, *rest) + ret = UOp(op, dtype, tuple(_reconstruct(data, s, None if depth is None else depth-1) for s in src), arg, *rest) + if depth is None: data.all_uops[a] = ret + return ret -def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: - next_sink = _reconstruct(ctx.sink) - yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None} +def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: + next_sink = _reconstruct(data, ctx.sink) + yield {"graph":uop_to_json(next_sink, data), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches): - replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num) + replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num) try: new_sink = next_sink.substitute(replaces) except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc) - yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json], + yield {"graph":(sink_json:=uop_to_json(new_sink, data)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json], "diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)} if not ctx.bottom_up: next_sink = new_sink -def get_prg_uop(i:int) -> UOp|None: - s = next((s for s in trace.rewrites[i] if s.name == "View Program"), None) - return _reconstruct(s.sink, depth=1) if s is not None else None +def get_prg_uop(data:VizData, i:int) -> UOp|None: + s = next((s for s in data.trace.rewrites[i] if s.name == "View Program"), None) + return _reconstruct(data, s.sink, depth=1) if s is not None else None # encoder helpers @@ -187,31 +194,30 @@ def rel_ts(ts:int|Decimal, start_ts:int, ctx:str="") -> int: # Profiler API -device_ts_diffs:dict[str, Decimal] = {} -def cpu_ts_diff(device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0)) +def cpu_ts_diff(device_ts_diffs:dict[str, Decimal], device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0)) DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent -def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]: +def flatten_events(profile:list[ProfileEvent], device_ts_diffs:dict[str, Decimal]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]: for e in profile: - if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device)), (e.en if e.en is not None else e.st)+diff, e) + if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(device_ts_diffs, e.device)), (e.en if e.en is not None else e.st)+diff, e) elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e) elif isinstance(e, ProfileGraphEvent): cpu_ts = [] - for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device)), e.sigs[ent.en_id]+diff] + for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(device_ts_diffs, ent.device)), e.sigs[ent.en_id]+diff] yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et) for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent) # normalize event timestamps and attach kernel metadata -def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None: +def timeline_layout(data:VizData, dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None: events:list[bytes] = [] exec_points:dict[str, ProfilePointEvent] = {} for st,et,dur,e in dev_events: if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e if dur == 0: continue name, fmt, key = e.name, [], None - if (ref:=ref_map.get(name)) is not None and ctxs: - name = ctxs[ref]["name"] - if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None: + if (ref:=data.ref_map.get(name)) is not None and ref < len(data.ctxs): + name = data.ctxs[ref]["name"] + if (p:=get_prg_uop(data, ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None: flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6) membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS", @@ -222,7 +228,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts: key = ei.key elif isinstance(e.name, TracingKey): name = e.name.display_name - ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None) + ref = next((v for k in e.name.keys if (v:=data.ref_map.get(k)) is not None), None) if isinstance(e.name.ret, str): fmt.append(e.name.ret) elif isinstance(e.name.ret, int): membw = (nbytes:=e.name.ret) / (dur * 1e-6) @@ -313,7 +319,7 @@ def unpack_pmc(e) -> dict: # ** on startup, list all the performance counter traces -def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None: +def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None: from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent counter_events:dict[tuple[int, int], dict] = {} durations:dict[str, list[float]] = {} @@ -326,22 +332,23 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None: if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e if isinstance(e, ProfileDeviceEvent) and e.device.startswith("AMD"): arch = f"gfx{unwrap(e.props)['gfx_target_version']//1000}" if len(counter_events) == 0: return None - ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]}) + data.ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(data.ctxs), 0), (durations, all_counters:={}))]}) run_number = {n:0 for n,_ in counter_events} for (k, tag),v in counter_events.items(): # use the colored name if it exists - name = unwrap(get_prg_uop(r)).src[0].arg.name if (r:=ref_map.get(pname:=prg_events[k].name)) is not None else pname + name = unwrap(get_prg_uop(data, r)).src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname run_number[k] += 1 steps:list[dict] = [] if (pmc:=v.get(ProfilePMCEvent)): - steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc)) + steps.append(create_step("PMC", ("/prg-pmc", len(data.ctxs), len(steps)), pmc)) all_counters[(name, run_number[k], pname)] = pmc[0] # to decode a SQTT trace, we need the raw stream, program binary and device properties if (sqtt:=v.get(ProfileSQTTEvent)): for e in sqtt: - if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch))) - steps.append(create_step("OCC", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch))) - ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) + if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(data.ctxs), len(steps)), + data=(e.blob, prg_events[k].lib, arch))) + steps.append(create_step("OCC", ("/prg-sqtt", len(data.ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch))) + data.ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc", **{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000", @@ -457,23 +464,25 @@ def device_sort_fn(k:str) -> tuple: dev_base = p[0] if len(p) < 2 or not p[1].isdigit() else f"{p[0]}:{p[1]}" return (is_memory, special.get(p[0], special['ALLDEVS']), dev_base, k) -def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None: +def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn, data:VizData|None=None) -> bytes|None: + if data is None: data = VizData(RewriteTrace([], [], {})) # start by getting the time diffs - device_decoders:dict[str, Callable[[list[dict], list[ProfileEvent]], None]] = {} + device_ts_diffs:dict[str, Decimal] = {} + device_decoders:dict[str, Callable[[VizData, list[ProfileEvent]], None]] = {} for ev in profile: if isinstance(ev, ProfileDeviceEvent): device_ts_diffs[ev.device] = ev.tdiff if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_amd_counters if d == "NV": device_decoders[d] = load_nv_counters # load device specific counters - for fxn in device_decoders.values(): fxn(ctxs, profile) + for fxn in device_decoders.values(): fxn(data, profile) # map events per device dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {} markers:list[ProfilePointEvent] = [] ext_data:dict[str, Any] = {} start_ts:int|None = None end_ts:int|None = None - for ts,en,e in flatten_events(profile): + for ts,en,e in flatten_events(profile, device_ts_diffs): dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e)) if start_ts is None or st < start_ts: start_ts = st if end_ts is None or et > end_ts: end_ts = et @@ -487,7 +496,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ dtype_size:dict[str, int] = {} for k,v in dev_events.items(): v.sort(key=lambda e:e[0]) - layout[k] = timeline_layout(v, start_ts, scache) + layout[k] = timeline_layout(data, v, start_ts, scache) layout.update([graph_layout(k, v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)]) sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn) ret = [b"".join([struct.pack(" None: +def load_nv_counters(data:VizData, profile:list) -> None: steps:list[dict] = [] sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None} run_number:dict[str, int] = {} for e in profile: if type(e).__name__ == "ProfilePMAEvent": run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1 - steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)), + steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(data.ctxs), len(steps)), data=(e.blob, sm_version[e.device]))) - if steps: ctxs.append({"name":"All Counters", "steps":steps}) + if steps: data.ctxs.append({"name":"All Counters", "steps":steps}) def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]: from extra.nv_pma.decode import decode, decode_tpc_id @@ -586,10 +595,10 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict: from tinygrad.renderer.amd.dsl import Reg for pc, inst in pc_table.items(): pc_tokens[pc] = tokens = [] - for name, field in inst._fields: + for name, f in inst._fields: if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1}) elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0}) - elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1}) + elif name != "encoding" and val != f.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1}) # show a smaller view for repeated instructions in the graph lines:list[str] = [] disasm = {pc:str(inst) for pc,inst in pc_table.items()} @@ -616,12 +625,12 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict: # ** Main render function to get the complete details about a trace event -def get_render(query:str) -> dict: +def get_render(viz_data:VizData, query:str) -> dict: url = urlparse(query) i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/") - data = ctxs[i]["steps"][j]["data"] - if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"} - if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(trace.rewrites[i][j-1].sink).src[2].src)), "lang":"txt"} + data = viz_data.ctxs[i]["steps"][j]["data"] + if fmt == "graph-rewrites": return {"value":get_full_rewrite(viz_data, viz_data.trace.rewrites[i][j]), "content_type":"text/event-stream"} + if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(viz_data, viz_data.trace.rewrites[i][j-1].sink).src[2].src))} if fmt == "code": return {"src":data, "lang":"cpp"} if fmt == "asm": ret:dict = {} @@ -643,13 +652,13 @@ def get_render(query:str) -> dict: if fmt.startswith("prg-pkts"): ret = {} with soft_err(lambda err:ret.update(err)): - if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple)): + if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple, data=viz_data)): ret = {"value":events, "content_type":"application/octet-stream"} else: ret = {"src":"No SQTT trace on this SE."} return ret if fmt == "prg-sqtt": ret = {} - if len((steps:=ctxs[i]["steps"])[j+1:]) == 0: + if len((steps:=viz_data.ctxs[i]["steps"])[j+1:]) == 0: with soft_err(lambda err: ret.update(err)): cu_events, units, wave_insts = unpack_sqtt(*data) for cu in sorted(cu_events, key=row_tuple): @@ -658,7 +667,7 @@ def get_render(query:str) -> dict: for k in sorted(wave_insts.get(cu, []), key=row_tuple): steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data)) return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]} - if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple), "content_type":"application/octet-stream"} + if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple, data=viz_data), "content_type":"application/octet-stream"} if fmt == "sqtt-insts": columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"] inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"] @@ -685,11 +694,11 @@ def get_render(query:str) -> dict: prev_instr = max(prev_instr, e.time + e.dur) summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu}, {"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}] - return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)} + return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)} if fmt == "prg-pma-pkts": ret = {} with soft_err(lambda err:ret.update(err)): - if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"} + if (events:=get_profile(pma_timeline(*data), row_tuple, data=viz_data)): ret = {"value":events, "content_type":"application/octet-stream"} else: ret = {"src":"No PMA samples found."} return ret return data @@ -712,11 +721,11 @@ class Handler(HTTPRequestHandler): except FileNotFoundError: status_code = 404 elif url.path == "/ctxs": - lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs] + lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in data.ctxs] ret, content_type = json.dumps(lst).encode(), "application/json" elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream" else: - if not (render_src:=get_render(self.path)): status_code = 404 + if not (render_src:=get_render(data, self.path)): status_code = 404 else: if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"] else: ret, content_type = json.dumps(render_src).encode(), "application/json" @@ -754,8 +763,9 @@ if __name__ == "__main__": st = time.perf_counter() print("*** viz is starting") - ctxs = get_rewrites(trace:=load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))) - profile_ret = get_profile(load_pickle(args.profile_path, default=[])) + data = VizData(load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))) + load_rewrites(data) + profile_ret = get_profile(load_pickle(args.profile_path, default=[]), data=data) server = TCPServerWithReuse(('', PORT), Handler) reloader_thread = threading.Thread(target=reloader)