viz: no global state (#15705)

* start viz data

* get_full_rewrites also moves

* update ref_map

* work

* update consumers

* cleaner cli

* linter

* cleanup tests

* back

* better

* sqtt tests
This commit is contained in:
qazal
2026-04-13 15:35:20 +03:00
committed by GitHub
parent 4c1fb18a09
commit ac027055ef
5 changed files with 111 additions and 99 deletions

View File

@@ -136,7 +136,8 @@ def print_data(data:dict) -> None:
def main() -> None:
import tinygrad.viz.serve as viz
viz.ctxs = []
from tinygrad.uop.ops import RewriteTrace
data = viz.VizData()
parser = argparse.ArgumentParser()
parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
@@ -147,24 +148,24 @@ def main() -> None:
with args.profile.open("rb") as f: profile = pickle.load(f)
viz.get_profile(profile)
viz.get_profile(profile, data=data)
# List all kernels
if args.kernel is None:
for c in viz.ctxs:
for c in data.ctxs:
print(c["name"])
for s in c["steps"]: print(" "+s["name"])
return None
# Find kernel trace
trace = next((c for c in viz.ctxs if c["name"] == f"Exec {args.kernel}"), None)
trace = next((c for c in data.ctxs if c["name"] == f"SQTT {args.kernel}"), None)
if not trace: raise RuntimeError(f"no matching trace for {args.kernel}")
n = 0
for s in trace["steps"]:
if "PKTS" in s["name"]: continue
print(s["name"])
data = viz.get_render(s["query"])
print_data(data)
ret = viz.get_render(data, s["query"])
print_data(ret)
n += 1
if n > args.n: break

View File

@@ -52,16 +52,16 @@ def get(data:dict, key:str):
raise RuntimeError(f'item "{key}" not found in list'+(f", did you mean {match[0]!r}?" if match else ''))
def main(args) -> None:
viz.trace = viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {}))
viz.ctxs = viz.get_rewrites(viz.trace)
viz.data = viz.VizData(viz.load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
viz.load_rewrites(viz.data)
def format_colored(s:str) -> str: return ansistrip(s) if args.no_color else s
if args.profile:
events:list = viz.load_pickle(args.profile_path, default=[])
if (profile_bytes:=viz.get_profile(events)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
if (profile_bytes:=viz.get_profile(events, data=viz.data)) is None: raise RuntimeError(f"empty profile in {args.profile_path}")
profile = decode_profile(profile_bytes)
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.ctxs
profile["layout"].update([(f'{c["name"][5:]}{" SQTT" if s["name"].endswith("PKTS") else ""} {s["name"]}', s["data"]) for c in viz.data.ctxs
if c["name"].startswith("SQTT") for s in c["steps"] if s["name"].endswith(("PMC", "PKTS"))])
if args.src is None:
for k in profile["layout"]:
@@ -142,7 +142,7 @@ def main(args) -> None:
return None
# ** Graph rewrites printer
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.ctxs if c.get("steps")}
rewrites = {c["name"]:{s["name"]:s for s in c["steps"]} for c in viz.data.ctxs if c.get("steps")}
if args.src is None:
for k in rewrites: print(f" {format_colored(k)}")
return None
@@ -150,7 +150,7 @@ def main(args) -> None:
if args.item is None:
for k,v in steps.items(): print(" "*v["depth"]+k+(f" - {v['match_count']}" if v.get('match_count', 0) else ''))
else:
data = viz.get_render(get(steps, args.item)["query"])
data = viz.get_render(viz.data, get(steps, args.item)["query"])
if isinstance(data.get("value"), Iterator):
for m in data["value"]:
if m.get("uop"): print(f"Input UOp:\n{m['uop']}")

View File

@@ -1,15 +1,16 @@
import unittest, contextlib
from tinygrad import Device, Tensor, Context, TinyJit
from tinygrad.device import Compiled, ProfileProgramEvent, ProfileDeviceEvent
from tinygrad.viz.serve import load_amd_counters
from tinygrad.viz.serve import load_amd_counters, VizData
@contextlib.contextmanager
def save_sqtt():
yield (ret:=[])
data = VizData()
yield data.ctxs
Device[Device.DEFAULT].synchronize()
Device[Device.DEFAULT]._at_profile_finalize()
load_amd_counters(ret, Compiled.profile_events)
ret[:] = [r for r in ret if r["name"].startswith("SQTT")]
load_amd_counters(data, Compiled.profile_events)
data.ctxs[:] = [r for r in data.ctxs if r["name"].startswith("SQTT")]
@unittest.skipUnless(Device.DEFAULT == "AMD", "only runs on AMD")
class TestSQTTProfiler(unittest.TestCase):

View File

@@ -10,7 +10,7 @@ from tinygrad.helpers import VIZ, cpu_profile, ProfilePointEvent, unwrap
from tinygrad.device import Buffer
from tinygrad.uop.ops import tracked_keys, tracked_ctxs, uop_fields, active_rewrites, active_group, _name_cnt, RewriteTrace
from tinygrad.viz.serve import get_rewrites, get_full_rewrite, uop_to_json
from tinygrad.viz.serve import load_rewrites, get_full_rewrite, uop_to_json, VizData
@track_rewrites(name=True)
def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=None) -> UOp:
@@ -21,19 +21,19 @@ def exec_rewrite(sink:UOp, pm_lst:list[PatternMatcher], names:None|list[str]=Non
# small container class for the viz server module
class VizTrace:
# loader init
def __init__(self): self._trace:RewriteTrace|None = None
def __init__(self): self._data:VizData|None = None
@property
def trace(self) -> RewriteTrace: return unwrap(self._trace)
def set_trace(self) -> None:
self._trace = RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy())
import tinygrad.viz.serve as serve_module
serve_module.trace = self._trace
def data(self) -> VizData: return unwrap(self._data)
def set_data(self) -> None:
data = VizData(RewriteTrace(tracked_keys.copy(), tracked_ctxs.copy(), uop_fields.copy()))
load_rewrites(data)
self._data = data
# the API
def list_items(self) -> list[dict]: return get_rewrites(self.trace)
def list_items(self) -> list[dict]:
return self.data.ctxs
def get_details(self, rewrite_idx:int, step:int) -> Generator[dict, None, None]:
lst = self.list_items()
assert len(lst) > rewrite_idx, f"only loaded {len(lst)} traces, expecting at least {rewrite_idx}"
return get_full_rewrite(self.trace.rewrites[rewrite_idx][step])
assert len(self.data.trace.rewrites) > rewrite_idx, f"only loaded {len(self.data.trace.rewrites)} traces, expecting at least {rewrite_idx}"
return get_full_rewrite(self.data, self.data.trace.rewrites[rewrite_idx][step])
@contextlib.contextmanager
def save_viz():
@@ -52,7 +52,7 @@ def save_viz():
try:
yield viz
finally:
viz.set_trace()
viz.set_data()
TRACK_MATCH_STATS.value = prev_tms
PROFILE.value = prev_profile
VIZ.value = prev_viz
@@ -194,7 +194,7 @@ class TestViz(unittest.TestCase):
class TestStruct:
colored_field: str
a = UOp(Ops.CUSTOM, arg=TestStruct(colored("xyz", "magenta")+colored("12345", "blue")))
a2 = uop_to_json(a)[id(a)]
a2 = uop_to_json(a, VizData())[id(a)]
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
def test_colored_label_multiline(self):
@@ -217,11 +217,11 @@ class TestViz(unittest.TestCase):
# use smaller stack limit for faster test (default is 250000)
with Context(REWRITE_STACK_LIMIT=100): self.assertRaises(RuntimeError, exec_rewrite, a, [pm])
graphs = flatten(x["graph"].values() for x in viz.get_details(0, 0))
self.assertEqual(graphs[0], uop_to_json(a)[id(a)])
self.assertEqual(graphs[1], uop_to_json(b)[id(b)])
self.assertEqual(graphs[0], uop_to_json(a, VizData())[id(a)])
self.assertEqual(graphs[1], uop_to_json(b, VizData())[id(b)])
# fallback to NOOP with the error message
nop = UOp(Ops.NOOP, arg="infinite loop in fixed_point_rewrite")
self.assertEqual(graphs[2], uop_to_json(nop)[id(nop)])
self.assertEqual(graphs[2], uop_to_json(nop, VizData())[id(nop)])
def test_const_node_visibility(self):
with save_viz() as viz:
@@ -241,7 +241,7 @@ class TestViz(unittest.TestCase):
c = UOp.const(dtypes.float, 1.0, device="CPU", shape=(3,4)) # creates CONST->RESHAPE->EXPAND chain
a = UOp(Ops.DEFINE_VAR, dtypes.float, arg=("a", 0.0, 10.0))
alu = a + c
graph = uop_to_json(alu)
graph = uop_to_json(alu, VizData())
# the RESHAPE and EXPAND nodes from the const should not appear in the graph
labels = {v["label"].split("\n")[0] for v in graph.values()}
self.assertNotIn("RESHAPE", labels)

View File

@@ -1,8 +1,9 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct, re
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, codecs, io, struct, re
import pathlib, traceback, itertools, socketserver
from contextlib import redirect_stdout, redirect_stderr, contextmanager
from decimal import Decimal
from dataclasses import dataclass, field
from urllib.parse import parse_qs, urlparse
from http.server import BaseHTTPRequestHandler
from typing import Any, TypedDict, TypeVar, Generator, Callable
@@ -61,21 +62,27 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
# Build the dict describing a single step of a trace context.
# `query` is a (path, ctx index, step index) tuple; it is flattened into a "path?ctx=N&step=M" string.
# `data` and `depth` are stored as given, and any extra keyword arguments become additional keys of the step.
def create_step(name:str, query:tuple[str, int, int], data=None, depth:int=0, **kwargs) -> dict:
return {"name":name, "query":f"{query[0]}?ctx={query[1]}&step={query[2]}", "data":data, "depth":depth, **kwargs}
# ** list all saved rewrites
# Per-trace state for the viz functions, passed to them explicitly as an argument.
# NOTE: frozen=True only prevents rebinding the attributes; the list/dict fields below are still mutated in place.
@dataclass(frozen=True)
class VizData:
# the rewrite trace being visualized, defaults to an empty RewriteTrace
trace:RewriteTrace = field(default_factory=lambda: RewriteTrace([], [], {}))
# one dict per context, each holding a "name" and a list of "steps"
ctxs:list[dict] = field(default_factory=list)
# maps a tracked key to the index of its context in ctxs
ref_map:dict[Any, int] = field(default_factory=dict)
# cache of fully reconstructed UOps (no depth limit), keyed by their index in trace.uop_fields
all_uops:dict[int, UOp] = field(default_factory=dict)
ref_map:dict[Any, int] = {}
def get_rewrites(t:RewriteTrace) -> list[dict]:
ret = []
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
# ** load all saved rewrites
def load_rewrites(data:VizData) -> None:
assert not data.ctxs and not data.ref_map, "load_rewrites called multiple times"
for i,k in enumerate(data.trace.keys):
v = data.trace.rewrites[i]
steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)]
if (p:=get_prg_uop(i)) is not None:
if (p:=get_prg_uop(data, i)) is not None:
steps.append(create_step("View UOp List", ("/uops", i, len(steps))))
steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg))
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg)))
for key in k.keys: ref_map[key] = i
ret.append({"name":k.display_name, "steps":steps})
return ret
for key in k.keys: data.ref_map[key] = i
data.ctxs.append({"name":k.display_name, "steps":steps})
# ** get the complete UOp graphs for one rewrite
@@ -93,9 +100,7 @@ def pystr(u:UOp) -> str:
try: return pyrender(u)
except Exception: return str(u)
# all the trace points, initialized after the trace loads
ctxs:list[dict] = []
def uop_to_json(x:UOp) -> dict[int, dict]:
def uop_to_json(x:UOp, data:VizData) -> dict[int, dict]:
assert isinstance(x, UOp)
graph: dict[int, dict] = {}
excluded: set[UOp] = set()
@@ -138,7 +143,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
except Exception:
label += "\n<ISSUE GETTING LABEL>"
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None and ctxs: label += f"\ncodegen@{ctxs[ref]['name']}"
if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
# NOTE: kernel already has metadata in arg
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
# limit SOURCE labels line count
@@ -148,28 +153,30 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
return graph
@functools.cache
def _reconstruct(a:int, depth:int|None=None):
op, dtype, src, arg, *rest = trace.uop_fields[a]
def _reconstruct(data:VizData, a:int, depth:int|None=None):
if depth is None and a in data.all_uops: return data.all_uops[a]
op, dtype, src, arg, *rest = data.trace.uop_fields[a]
if depth is not None and depth <= 0: return UOp(op, dtype, (), arg, *rest)
return UOp(op, dtype, tuple(_reconstruct(s, None if depth is None else depth-1) for s in src), arg, *rest)
ret = UOp(op, dtype, tuple(_reconstruct(data, s, None if depth is None else depth-1) for s in src), arg, *rest)
if depth is None: data.all_uops[a] = ret
return ret
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(ctx.sink)
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
next_sink = _reconstruct(data, ctx.sink)
yield {"graph":uop_to_json(next_sink, data), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
replaces: dict[UOp, UOp] = {}
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
try: new_sink = next_sink.substitute(replaces)
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
yield {"graph":(sink_json:=uop_to_json(new_sink, data)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
if not ctx.bottom_up: next_sink = new_sink
def get_prg_uop(i:int) -> UOp|None:
s = next((s for s in trace.rewrites[i] if s.name == "View Program"), None)
return _reconstruct(s.sink, depth=1) if s is not None else None
def get_prg_uop(data:VizData, i:int) -> UOp|None:
s = next((s for s in data.trace.rewrites[i] if s.name == "View Program"), None)
return _reconstruct(data, s.sink, depth=1) if s is not None else None
# encoder helpers
@@ -187,31 +194,30 @@ def rel_ts(ts:int|Decimal, start_ts:int, ctx:str="") -> int:
# Profiler API
device_ts_diffs:dict[str, Decimal] = {}
def cpu_ts_diff(device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
def cpu_ts_diff(device_ts_diffs:dict[str, Decimal], device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
def flatten_events(profile:list[ProfileEvent], device_ts_diffs:dict[str, Decimal]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
for e in profile:
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device)), (e.en if e.en is not None else e.st)+diff, e)
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(device_ts_diffs, e.device)), (e.en if e.en is not None else e.st)+diff, e)
elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
elif isinstance(e, ProfileGraphEvent):
cpu_ts = []
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device)), e.sigs[ent.en_id]+diff]
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(device_ts_diffs, ent.device)), e.sigs[ent.en_id]+diff]
yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
# normalize event timestamps and attach kernel metadata
def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
def timeline_layout(data:VizData, dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
events:list[bytes] = []
exec_points:dict[str, ProfilePointEvent] = {}
for st,et,dur,e in dev_events:
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
if dur == 0: continue
name, fmt, key = e.name, [], None
if (ref:=ref_map.get(name)) is not None and ctxs:
name = ctxs[ref]["name"]
if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
if (ref:=data.ref_map.get(name)) is not None and ref < len(data.ctxs):
name = data.ctxs[ref]["name"]
if (p:=get_prg_uop(data, ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t
fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
@@ -222,7 +228,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
key = ei.key
elif isinstance(e.name, TracingKey):
name = e.name.display_name
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
ref = next((v for k in e.name.keys if (v:=data.ref_map.get(k)) is not None), None)
if isinstance(e.name.ret, str): fmt.append(e.name.ret)
elif isinstance(e.name.ret, int):
membw = (nbytes:=e.name.ret) / (dur * 1e-6)
@@ -313,7 +319,7 @@ def unpack_pmc(e) -> dict:
# ** on startup, list all the performance counter traces
def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
counter_events:dict[tuple[int, int], dict] = {}
durations:dict[str, list[float]] = {}
@@ -326,22 +332,23 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
if isinstance(e, ProfileDeviceEvent) and e.device.startswith("AMD"): arch = f"gfx{unwrap(e.props)['gfx_target_version']//1000}"
if len(counter_events) == 0: return None
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
data.ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(data.ctxs), 0), (durations, all_counters:={}))]})
run_number = {n:0 for n,_ in counter_events}
for (k, tag),v in counter_events.items():
# use the colored name if it exists
name = unwrap(get_prg_uop(r)).src[0].arg.name if (r:=ref_map.get(pname:=prg_events[k].name)) is not None else pname
name = unwrap(get_prg_uop(data, r)).src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
run_number[k] += 1
steps:list[dict] = []
if (pmc:=v.get(ProfilePMCEvent)):
steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
steps.append(create_step("PMC", ("/prg-pmc", len(data.ctxs), len(steps)), pmc))
all_counters[(name, run_number[k], pname)] = pmc[0]
# to decode a SQTT trace, we need the raw stream, program binary and device properties
if (sqtt:=v.get(ProfileSQTTEvent)):
for e in sqtt:
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch)))
steps.append(create_step("OCC", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(data.ctxs), len(steps)),
data=(e.blob, prg_events[k].lib, arch)))
steps.append(create_step("OCC", ("/prg-sqtt", len(data.ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
data.ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
**{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000",
@@ -457,23 +464,25 @@ def device_sort_fn(k:str) -> tuple:
dev_base = p[0] if len(p) < 2 or not p[1].isdigit() else f"{p[0]}:{p[1]}"
return (is_memory, special.get(p[0], special['ALLDEVS']), dev_base, k)
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn, data:VizData|None=None) -> bytes|None:
if data is None: data = VizData(RewriteTrace([], [], {}))
# start by getting the time diffs
device_decoders:dict[str, Callable[[list[dict], list[ProfileEvent]], None]] = {}
device_ts_diffs:dict[str, Decimal] = {}
device_decoders:dict[str, Callable[[VizData, list[ProfileEvent]], None]] = {}
for ev in profile:
if isinstance(ev, ProfileDeviceEvent):
device_ts_diffs[ev.device] = ev.tdiff
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_amd_counters
if d == "NV": device_decoders[d] = load_nv_counters
# load device specific counters
for fxn in device_decoders.values(): fxn(ctxs, profile)
for fxn in device_decoders.values(): fxn(data, profile)
# map events per device
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
markers:list[ProfilePointEvent] = []
ext_data:dict[str, Any] = {}
start_ts:int|None = None
end_ts:int|None = None
for ts,en,e in flatten_events(profile):
for ts,en,e in flatten_events(profile, device_ts_diffs):
dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
if start_ts is None or st < start_ts: start_ts = st
if end_ts is None or et > end_ts: end_ts = et
@@ -487,7 +496,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
dtype_size:dict[str, int] = {}
for k,v in dev_events.items():
v.sort(key=lambda e:e[0])
layout[k] = timeline_layout(v, start_ts, scache)
layout[k] = timeline_layout(data, v, start_ts, scache)
layout.update([graph_layout(k, v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)])
sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn)
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), unwrap(layout[k])]) for k in sorted_layout]
@@ -498,16 +507,16 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
# ** PMA counters
def load_nv_counters(ctxs:list[dict], profile:list) -> None:
def load_nv_counters(data:VizData, profile:list) -> None:
steps:list[dict] = []
sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None}
run_number:dict[str, int] = {}
for e in profile:
if type(e).__name__ == "ProfilePMAEvent":
run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)),
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(data.ctxs), len(steps)),
data=(e.blob, sm_version[e.device])))
if steps: ctxs.append({"name":"All Counters", "steps":steps})
if steps: data.ctxs.append({"name":"All Counters", "steps":steps})
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
from extra.nv_pma.decode import decode, decode_tpc_id
@@ -586,10 +595,10 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
from tinygrad.renderer.amd.dsl import Reg
for pc, inst in pc_table.items():
pc_tokens[pc] = tokens = []
for name, field in inst._fields:
for name, f in inst._fields:
if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1})
elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0})
elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
elif name != "encoding" and val != f.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
# show a smaller view for repeated instructions in the graph
lines:list[str] = []
disasm = {pc:str(inst) for pc,inst in pc_table.items()}
@@ -616,12 +625,12 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
# ** Main render function to get the complete details about a trace event
def get_render(query:str) -> dict:
def get_render(viz_data:VizData, query:str) -> dict:
url = urlparse(query)
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
data = ctxs[i]["steps"][j]["data"]
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(trace.rewrites[i][j-1].sink).src[2].src)), "lang":"txt"}
data = viz_data.ctxs[i]["steps"][j]["data"]
if fmt == "graph-rewrites": return {"value":get_full_rewrite(viz_data, viz_data.trace.rewrites[i][j]), "content_type":"text/event-stream"}
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(viz_data, viz_data.trace.rewrites[i][j-1].sink).src[2].src))}
if fmt == "code": return {"src":data, "lang":"cpp"}
if fmt == "asm":
ret:dict = {}
@@ -643,13 +652,13 @@ def get_render(query:str) -> dict:
if fmt.startswith("prg-pkts"):
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple)):
if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple, data=viz_data)):
ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No SQTT trace on this SE."}
return ret
if fmt == "prg-sqtt":
ret = {}
if len((steps:=ctxs[i]["steps"])[j+1:]) == 0:
if len((steps:=viz_data.ctxs[i]["steps"])[j+1:]) == 0:
with soft_err(lambda err: ret.update(err)):
cu_events, units, wave_insts = unpack_sqtt(*data)
for cu in sorted(cu_events, key=row_tuple):
@@ -658,7 +667,7 @@ def get_render(query:str) -> dict:
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data))
return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]}
if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple), "content_type":"application/octet-stream"}
if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple, data=viz_data), "content_type":"application/octet-stream"}
if fmt == "sqtt-insts":
columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"]
inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"]
@@ -685,11 +694,11 @@ def get_render(query:str) -> dict:
prev_instr = max(prev_instr, e.time + e.dur)
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)}
if fmt == "prg-pma-pkts":
ret = {}
with soft_err(lambda err:ret.update(err)):
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
if (events:=get_profile(pma_timeline(*data), row_tuple, data=viz_data)): ret = {"value":events, "content_type":"application/octet-stream"}
else: ret = {"src":"No PMA samples found."}
return ret
return data
@@ -712,11 +721,11 @@ class Handler(HTTPRequestHandler):
except FileNotFoundError: status_code = 404
elif url.path == "/ctxs":
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in data.ctxs]
ret, content_type = json.dumps(lst).encode(), "application/json"
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
else:
if not (render_src:=get_render(self.path)): status_code = 404
if not (render_src:=get_render(data, self.path)): status_code = 404
else:
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
@@ -754,8 +763,9 @@ if __name__ == "__main__":
st = time.perf_counter()
print("*** viz is starting")
ctxs = get_rewrites(trace:=load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
profile_ret = get_profile(load_pickle(args.profile_path, default=[]))
data = VizData(load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
load_rewrites(data)
profile_ret = get_profile(load_pickle(args.profile_path, default=[]), data=data)
server = TCPServerWithReuse(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)