|
|
|
|
@@ -1,8 +1,9 @@
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct, re
|
|
|
|
|
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, codecs, io, struct, re
|
|
|
|
|
import pathlib, traceback, itertools, socketserver
|
|
|
|
|
from contextlib import redirect_stdout, redirect_stderr, contextmanager
|
|
|
|
|
from decimal import Decimal
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
|
from http.server import BaseHTTPRequestHandler
|
|
|
|
|
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
|
|
|
|
@@ -61,21 +62,27 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|
|
|
|
def create_step(name:str, query:tuple[str, int, int], data=None, depth:int=0, **kwargs) -> dict:
|
|
|
|
|
return {"name":name, "query":f"{query[0]}?ctx={query[1]}&step={query[2]}", "data":data, "depth":depth, **kwargs}
|
|
|
|
|
|
|
|
|
|
# ** list all saved rewrites
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class VizData:
|
|
|
|
|
trace:RewriteTrace = field(default_factory=lambda: RewriteTrace([], [], {}))
|
|
|
|
|
ctxs:list[dict] = field(default_factory=list)
|
|
|
|
|
ref_map:dict[Any, int] = field(default_factory=dict)
|
|
|
|
|
all_uops:dict[int, UOp] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
ref_map:dict[Any, int] = {}
|
|
|
|
|
def get_rewrites(t:RewriteTrace) -> list[dict]:
|
|
|
|
|
ret = []
|
|
|
|
|
for i,(k,v) in enumerate(zip(t.keys, t.rewrites)):
|
|
|
|
|
# ** load all saved rewrites
|
|
|
|
|
|
|
|
|
|
def load_rewrites(data:VizData) -> None:
|
|
|
|
|
assert not data.ctxs and not data.ref_map, "load_rewrites called multiple times"
|
|
|
|
|
for i,k in enumerate(data.trace.keys):
|
|
|
|
|
v = data.trace.rewrites[i]
|
|
|
|
|
steps = [create_step(s.name, ("/graph-rewrites", i, j), loc=s.loc, match_count=len(s.matches), code_line=printable(s.loc),
|
|
|
|
|
trace=k.tb if j==0 else None, depth=s.depth) for j,s in enumerate(v)]
|
|
|
|
|
if (p:=get_prg_uop(i)) is not None:
|
|
|
|
|
if (p:=get_prg_uop(data, i)) is not None:
|
|
|
|
|
steps.append(create_step("View UOp List", ("/uops", i, len(steps))))
|
|
|
|
|
steps.append(create_step("View Source", ("/code", i, len(steps)), p.src[3].arg))
|
|
|
|
|
steps.append(create_step("View Disassembly", ("/asm", i, len(steps)), (k.ret, p.src[4].arg)))
|
|
|
|
|
for key in k.keys: ref_map[key] = i
|
|
|
|
|
ret.append({"name":k.display_name, "steps":steps})
|
|
|
|
|
return ret
|
|
|
|
|
for key in k.keys: data.ref_map[key] = i
|
|
|
|
|
data.ctxs.append({"name":k.display_name, "steps":steps})
|
|
|
|
|
|
|
|
|
|
# ** get the complete UOp graphs for one rewrite
|
|
|
|
|
|
|
|
|
|
@@ -93,9 +100,7 @@ def pystr(u:UOp) -> str:
|
|
|
|
|
try: return pyrender(u)
|
|
|
|
|
except Exception: return str(u)
|
|
|
|
|
|
|
|
|
|
# all the trace points, initialized after the trace loads
|
|
|
|
|
ctxs:list[dict] = []
|
|
|
|
|
def uop_to_json(x:UOp) -> dict[int, dict]:
|
|
|
|
|
def uop_to_json(x:UOp, data:VizData) -> dict[int, dict]:
|
|
|
|
|
assert isinstance(x, UOp)
|
|
|
|
|
graph: dict[int, dict] = {}
|
|
|
|
|
excluded: set[UOp] = set()
|
|
|
|
|
@@ -138,7 +143,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|
|
|
|
label += "\n"+' '.join([f"{range_str(s, color=True)}({s.vmax+1})" for s in trngs])
|
|
|
|
|
except Exception:
|
|
|
|
|
label += "\n<ISSUE GETTING LABEL>"
|
|
|
|
|
if (ref:=ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None and ctxs: label += f"\ncodegen@{ctxs[ref]['name']}"
|
|
|
|
|
if (ref:=data.ref_map.get(u.src[0]) if u.op is Ops.CALL else None) is not None: label += f"\ncodegen@{data.ctxs[ref]['name']}"
|
|
|
|
|
# NOTE: kernel already has metadata in arg
|
|
|
|
|
if TRACEMETA >= 2 and u.metadata is not None and u.op is not Ops.CALL: label += "\n"+str(u.metadata)
|
|
|
|
|
# limit SOURCE labels line count
|
|
|
|
|
@@ -148,28 +153,30 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|
|
|
|
"ref":ref, "tag":repr(u.tag) if u.tag is not None else None}
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
@functools.cache
|
|
|
|
|
def _reconstruct(a:int, depth:int|None=None):
|
|
|
|
|
op, dtype, src, arg, *rest = trace.uop_fields[a]
|
|
|
|
|
def _reconstruct(data:VizData, a:int, depth:int|None=None):
|
|
|
|
|
if depth is None and a in data.all_uops: return data.all_uops[a]
|
|
|
|
|
op, dtype, src, arg, *rest = data.trace.uop_fields[a]
|
|
|
|
|
if depth is not None and depth <= 0: return UOp(op, dtype, (), arg, *rest)
|
|
|
|
|
return UOp(op, dtype, tuple(_reconstruct(s, None if depth is None else depth-1) for s in src), arg, *rest)
|
|
|
|
|
ret = UOp(op, dtype, tuple(_reconstruct(data, s, None if depth is None else depth-1) for s in src), arg, *rest)
|
|
|
|
|
if depth is None: data.all_uops[a] = ret
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
|
|
|
|
next_sink = _reconstruct(ctx.sink)
|
|
|
|
|
yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
|
|
|
|
def get_full_rewrite(data:VizData, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]:
|
|
|
|
|
next_sink = _reconstruct(data, ctx.sink)
|
|
|
|
|
yield {"graph":uop_to_json(next_sink, data), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None}
|
|
|
|
|
replaces: dict[UOp, UOp] = {}
|
|
|
|
|
for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches):
|
|
|
|
|
replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num)
|
|
|
|
|
replaces[u0:=_reconstruct(data, u0_num)] = u1 = _reconstruct(data, u1_num)
|
|
|
|
|
try: new_sink = next_sink.substitute(replaces)
|
|
|
|
|
except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e))
|
|
|
|
|
match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc)
|
|
|
|
|
yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
|
|
|
|
yield {"graph":(sink_json:=uop_to_json(new_sink, data)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json],
|
|
|
|
|
"diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)}
|
|
|
|
|
if not ctx.bottom_up: next_sink = new_sink
|
|
|
|
|
|
|
|
|
|
def get_prg_uop(i:int) -> UOp|None:
|
|
|
|
|
s = next((s for s in trace.rewrites[i] if s.name == "View Program"), None)
|
|
|
|
|
return _reconstruct(s.sink, depth=1) if s is not None else None
|
|
|
|
|
def get_prg_uop(data:VizData, i:int) -> UOp|None:
|
|
|
|
|
s = next((s for s in data.trace.rewrites[i] if s.name == "View Program"), None)
|
|
|
|
|
return _reconstruct(data, s.sink, depth=1) if s is not None else None
|
|
|
|
|
|
|
|
|
|
# encoder helpers
|
|
|
|
|
|
|
|
|
|
@@ -187,31 +194,30 @@ def rel_ts(ts:int|Decimal, start_ts:int, ctx:str="") -> int:
|
|
|
|
|
|
|
|
|
|
# Profiler API
|
|
|
|
|
|
|
|
|
|
device_ts_diffs:dict[str, Decimal] = {}
|
|
|
|
|
def cpu_ts_diff(device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
|
|
|
|
|
def cpu_ts_diff(device_ts_diffs:dict[str, Decimal], device:str) -> Decimal: return device_ts_diffs.get(device, Decimal(0))
|
|
|
|
|
|
|
|
|
|
DevEvent = ProfileRangeEvent|ProfileGraphEntry|ProfilePointEvent
|
|
|
|
|
def flatten_events(profile:list[ProfileEvent]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
|
|
|
|
def flatten_events(profile:list[ProfileEvent], device_ts_diffs:dict[str, Decimal]) -> Generator[tuple[Decimal, Decimal, DevEvent], None, None]:
|
|
|
|
|
for e in profile:
|
|
|
|
|
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(e.device)), (e.en if e.en is not None else e.st)+diff, e)
|
|
|
|
|
if isinstance(e, ProfileRangeEvent): yield (e.st+(diff:=cpu_ts_diff(device_ts_diffs, e.device)), (e.en if e.en is not None else e.st)+diff, e)
|
|
|
|
|
elif isinstance(e, ProfilePointEvent): yield (e.ts, e.ts, e)
|
|
|
|
|
elif isinstance(e, ProfileGraphEvent):
|
|
|
|
|
cpu_ts = []
|
|
|
|
|
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(ent.device)), e.sigs[ent.en_id]+diff]
|
|
|
|
|
for ent in e.ents: cpu_ts += [e.sigs[ent.st_id]+(diff:=cpu_ts_diff(device_ts_diffs, ent.device)), e.sigs[ent.en_id]+diff]
|
|
|
|
|
yield (st:=min(cpu_ts)), (et:=max(cpu_ts)), ProfileRangeEvent(f"{e.ents[0].device.split(':')[0]} Graph", f"batched {len(e.ents)}", st, et)
|
|
|
|
|
for i,ent in enumerate(e.ents): yield (cpu_ts[i*2], cpu_ts[i*2+1], ent)
|
|
|
|
|
|
|
|
|
|
# normalize event timestamps and attach kernel metadata
|
|
|
|
|
def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
|
|
|
|
|
def timeline_layout(data:VizData, dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, scache:dict[str, int]) -> bytes|None:
|
|
|
|
|
events:list[bytes] = []
|
|
|
|
|
exec_points:dict[str, ProfilePointEvent] = {}
|
|
|
|
|
for st,et,dur,e in dev_events:
|
|
|
|
|
if isinstance(e, ProfilePointEvent) and e.name == "exec": exec_points[e.arg["name"]] = e
|
|
|
|
|
if dur == 0: continue
|
|
|
|
|
name, fmt, key = e.name, [], None
|
|
|
|
|
if (ref:=ref_map.get(name)) is not None and ctxs:
|
|
|
|
|
name = ctxs[ref]["name"]
|
|
|
|
|
if (p:=get_prg_uop(ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
|
|
|
|
|
if (ref:=data.ref_map.get(name)) is not None and ref < len(data.ctxs):
|
|
|
|
|
name = data.ctxs[ref]["name"]
|
|
|
|
|
if (p:=get_prg_uop(data, ref)) is not None and (ei:=exec_points.get(p.src[0].arg.name)) is not None:
|
|
|
|
|
flops = sym_infer((estimates:=p.src[0].arg.estimates).ops, var_vals:=ei.arg['var_vals'])/(t:=dur*1e-6)
|
|
|
|
|
membw, ldsbw = sym_infer(estimates.mem, var_vals)/t, sym_infer(estimates.lds, var_vals)/t
|
|
|
|
|
fmt = [f"{flops*1e-9:.0f} GFLOPS" if flops < 1e14 else f"{flops*1e-12:.0f} TFLOPS",
|
|
|
|
|
@@ -222,7 +228,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:
|
|
|
|
|
key = ei.key
|
|
|
|
|
elif isinstance(e.name, TracingKey):
|
|
|
|
|
name = e.name.display_name
|
|
|
|
|
ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None)
|
|
|
|
|
ref = next((v for k in e.name.keys if (v:=data.ref_map.get(k)) is not None), None)
|
|
|
|
|
if isinstance(e.name.ret, str): fmt.append(e.name.ret)
|
|
|
|
|
elif isinstance(e.name.ret, int):
|
|
|
|
|
membw = (nbytes:=e.name.ret) / (dur * 1e-6)
|
|
|
|
|
@@ -313,7 +319,7 @@ def unpack_pmc(e) -> dict:
|
|
|
|
|
|
|
|
|
|
# ** on startup, list all the performance counter traces
|
|
|
|
|
|
|
|
|
|
def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
|
|
|
|
def load_amd_counters(data:VizData, profile:list[ProfileEvent]) -> None:
|
|
|
|
|
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
|
|
|
|
|
counter_events:dict[tuple[int, int], dict] = {}
|
|
|
|
|
durations:dict[str, list[float]] = {}
|
|
|
|
|
@@ -326,22 +332,23 @@ def load_amd_counters(ctxs:list[dict], profile:list[ProfileEvent]) -> None:
|
|
|
|
|
if isinstance(e, ProfileProgramEvent) and e.tag is not None: prg_events[e.tag] = e
|
|
|
|
|
if isinstance(e, ProfileDeviceEvent) and e.device.startswith("AMD"): arch = f"gfx{unwrap(e.props)['gfx_target_version']//1000}"
|
|
|
|
|
if len(counter_events) == 0: return None
|
|
|
|
|
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
|
|
|
|
|
data.ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(data.ctxs), 0), (durations, all_counters:={}))]})
|
|
|
|
|
run_number = {n:0 for n,_ in counter_events}
|
|
|
|
|
for (k, tag),v in counter_events.items():
|
|
|
|
|
# use the colored name if it exists
|
|
|
|
|
name = unwrap(get_prg_uop(r)).src[0].arg.name if (r:=ref_map.get(pname:=prg_events[k].name)) is not None else pname
|
|
|
|
|
name = unwrap(get_prg_uop(data, r)).src[0].arg.name if (r:=data.ref_map.get(pname:=prg_events[k].name)) is not None else pname
|
|
|
|
|
run_number[k] += 1
|
|
|
|
|
steps:list[dict] = []
|
|
|
|
|
if (pmc:=v.get(ProfilePMCEvent)):
|
|
|
|
|
steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
|
|
|
|
|
steps.append(create_step("PMC", ("/prg-pmc", len(data.ctxs), len(steps)), pmc))
|
|
|
|
|
all_counters[(name, run_number[k], pname)] = pmc[0]
|
|
|
|
|
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
|
|
|
|
if (sqtt:=v.get(ProfileSQTTEvent)):
|
|
|
|
|
for e in sqtt:
|
|
|
|
|
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e.blob, prg_events[k].lib,arch)))
|
|
|
|
|
steps.append(create_step("OCC", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
|
|
|
|
ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
|
|
|
|
if e.itrace: steps.append(create_step(f"SE:{e.se} PKTS", (f"/prg-pkts-{e.se}", len(data.ctxs), len(steps)),
|
|
|
|
|
data=(e.blob, prg_events[k].lib, arch)))
|
|
|
|
|
steps.append(create_step("OCC", ("/prg-sqtt", len(data.ctxs), len(steps)), ((k, tag), sqtt, prg_events[k], arch)))
|
|
|
|
|
data.ctxs.append({"name":f"SQTT {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
|
|
|
|
|
|
|
|
|
|
wave_colors = {"WMMA": "#1F7857", **{x:"#ffffc0" for x in ["VALU", "VINTERP"]}, "SALU": "#cef263", "SMEM": "#ffc0c0", "STORE": "#4fa3cc",
|
|
|
|
|
**{x:"#b2b7c9" for x in ["VMEM", "SGMEM"]}, "LDS": "#9fb4a6", "IMMEDIATE": "#f3b44a", "BARRIER": "#d00000",
|
|
|
|
|
@@ -457,23 +464,25 @@ def device_sort_fn(k:str) -> tuple:
|
|
|
|
|
dev_base = p[0] if len(p) < 2 or not p[1].isdigit() else f"{p[0]}:{p[1]}"
|
|
|
|
|
return (is_memory, special.get(p[0], special['ALLDEVS']), dev_base, k)
|
|
|
|
|
|
|
|
|
|
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn) -> bytes|None:
|
|
|
|
|
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_sort_fn, data:VizData|None=None) -> bytes|None:
|
|
|
|
|
if data is None: data = VizData(RewriteTrace([], [], {}))
|
|
|
|
|
# start by getting the time diffs
|
|
|
|
|
device_decoders:dict[str, Callable[[list[dict], list[ProfileEvent]], None]] = {}
|
|
|
|
|
device_ts_diffs:dict[str, Decimal] = {}
|
|
|
|
|
device_decoders:dict[str, Callable[[VizData, list[ProfileEvent]], None]] = {}
|
|
|
|
|
for ev in profile:
|
|
|
|
|
if isinstance(ev, ProfileDeviceEvent):
|
|
|
|
|
device_ts_diffs[ev.device] = ev.tdiff
|
|
|
|
|
if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_amd_counters
|
|
|
|
|
if d == "NV": device_decoders[d] = load_nv_counters
|
|
|
|
|
# load device specific counters
|
|
|
|
|
for fxn in device_decoders.values(): fxn(ctxs, profile)
|
|
|
|
|
for fxn in device_decoders.values(): fxn(data, profile)
|
|
|
|
|
# map events per device
|
|
|
|
|
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
|
|
|
|
|
markers:list[ProfilePointEvent] = []
|
|
|
|
|
ext_data:dict[str, Any] = {}
|
|
|
|
|
start_ts:int|None = None
|
|
|
|
|
end_ts:int|None = None
|
|
|
|
|
for ts,en,e in flatten_events(profile):
|
|
|
|
|
for ts,en,e in flatten_events(profile, device_ts_diffs):
|
|
|
|
|
dev_events.setdefault(e.device,[]).append((st:=int(ts), et:=int(en), float(en-ts), e))
|
|
|
|
|
if start_ts is None or st < start_ts: start_ts = st
|
|
|
|
|
if end_ts is None or et > end_ts: end_ts = et
|
|
|
|
|
@@ -487,7 +496,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|
|
|
|
dtype_size:dict[str, int] = {}
|
|
|
|
|
for k,v in dev_events.items():
|
|
|
|
|
v.sort(key=lambda e:e[0])
|
|
|
|
|
layout[k] = timeline_layout(v, start_ts, scache)
|
|
|
|
|
layout[k] = timeline_layout(data, v, start_ts, scache)
|
|
|
|
|
layout.update([graph_layout(k, v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)])
|
|
|
|
|
sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn)
|
|
|
|
|
ret = [b"".join([struct.pack("<B", len(k)), k.encode(), unwrap(layout[k])]) for k in sorted_layout]
|
|
|
|
|
@@ -498,16 +507,16 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
|
|
|
|
|
|
|
|
|
|
# ** PMA counters
|
|
|
|
|
|
|
|
|
|
def load_nv_counters(ctxs:list[dict], profile:list) -> None:
|
|
|
|
|
def load_nv_counters(data:VizData, profile:list) -> None:
|
|
|
|
|
steps:list[dict] = []
|
|
|
|
|
sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None}
|
|
|
|
|
run_number:dict[str, int] = {}
|
|
|
|
|
for e in profile:
|
|
|
|
|
if type(e).__name__ == "ProfilePMAEvent":
|
|
|
|
|
run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1
|
|
|
|
|
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)),
|
|
|
|
|
steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(data.ctxs), len(steps)),
|
|
|
|
|
data=(e.blob, sm_version[e.device])))
|
|
|
|
|
if steps: ctxs.append({"name":"All Counters", "steps":steps})
|
|
|
|
|
if steps: data.ctxs.append({"name":"All Counters", "steps":steps})
|
|
|
|
|
|
|
|
|
|
def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]:
|
|
|
|
|
from extra.nv_pma.decode import decode, decode_tpc_id
|
|
|
|
|
@@ -586,10 +595,10 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
|
|
|
|
|
from tinygrad.renderer.amd.dsl import Reg
|
|
|
|
|
for pc, inst in pc_table.items():
|
|
|
|
|
pc_tokens[pc] = tokens = []
|
|
|
|
|
for name, field in inst._fields:
|
|
|
|
|
for name, f in inst._fields:
|
|
|
|
|
if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1})
|
|
|
|
|
elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0})
|
|
|
|
|
elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
|
|
|
|
|
elif name != "encoding" and val != f.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1})
|
|
|
|
|
# show a smaller view for repeated instructions in the graph
|
|
|
|
|
lines:list[str] = []
|
|
|
|
|
disasm = {pc:str(inst) for pc,inst in pc_table.items()}
|
|
|
|
|
@@ -616,12 +625,12 @@ def amdgpu_cfg(lib:bytes, target:str) -> dict:
|
|
|
|
|
|
|
|
|
|
# ** Main render function to get the complete details about a trace event
|
|
|
|
|
|
|
|
|
|
def get_render(query:str) -> dict:
|
|
|
|
|
def get_render(viz_data:VizData, query:str) -> dict:
|
|
|
|
|
url = urlparse(query)
|
|
|
|
|
i, j, fmt = get_int(qs:=parse_qs(url.query), "ctx"), get_int(qs, "step"), url.path.lstrip("/")
|
|
|
|
|
data = ctxs[i]["steps"][j]["data"]
|
|
|
|
|
if fmt == "graph-rewrites": return {"value":get_full_rewrite(trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
|
|
|
|
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(trace.rewrites[i][j-1].sink).src[2].src)), "lang":"txt"}
|
|
|
|
|
data = viz_data.ctxs[i]["steps"][j]["data"]
|
|
|
|
|
if fmt == "graph-rewrites": return {"value":get_full_rewrite(viz_data, viz_data.trace.rewrites[i][j]), "content_type":"text/event-stream"}
|
|
|
|
|
if fmt == "uops": return {"src":get_stdout(lambda: print_uops(_reconstruct(viz_data, viz_data.trace.rewrites[i][j-1].sink).src[2].src))}
|
|
|
|
|
if fmt == "code": return {"src":data, "lang":"cpp"}
|
|
|
|
|
if fmt == "asm":
|
|
|
|
|
ret:dict = {}
|
|
|
|
|
@@ -643,13 +652,13 @@ def get_render(query:str) -> dict:
|
|
|
|
|
if fmt.startswith("prg-pkts"):
|
|
|
|
|
ret = {}
|
|
|
|
|
with soft_err(lambda err:ret.update(err)):
|
|
|
|
|
if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple)):
|
|
|
|
|
if (events:=get_profile(list(itertools.islice(sqtt_timeline(*data), getenv("MAX_SQTT_PKTS", 50_000))), sort_fn=row_tuple, data=viz_data)):
|
|
|
|
|
ret = {"value":events, "content_type":"application/octet-stream"}
|
|
|
|
|
else: ret = {"src":"No SQTT trace on this SE."}
|
|
|
|
|
return ret
|
|
|
|
|
if fmt == "prg-sqtt":
|
|
|
|
|
ret = {}
|
|
|
|
|
if len((steps:=ctxs[i]["steps"])[j+1:]) == 0:
|
|
|
|
|
if len((steps:=viz_data.ctxs[i]["steps"])[j+1:]) == 0:
|
|
|
|
|
with soft_err(lambda err: ret.update(err)):
|
|
|
|
|
cu_events, units, wave_insts = unpack_sqtt(*data)
|
|
|
|
|
for cu in sorted(cu_events, key=row_tuple):
|
|
|
|
|
@@ -658,7 +667,7 @@ def get_render(query:str) -> dict:
|
|
|
|
|
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
|
|
|
|
|
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", i, len(steps)), loc=(data:=wave_insts[cu][k])["loc"], depth=2, data=data))
|
|
|
|
|
return {**ret, "steps":[{k:v for k,v in s.items() if k != "data"} for s in steps[j+1:]]}
|
|
|
|
|
if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple), "content_type":"application/octet-stream"}
|
|
|
|
|
if fmt == "cu-sqtt": return {"value":get_profile(data, sort_fn=row_tuple, data=viz_data), "content_type":"application/octet-stream"}
|
|
|
|
|
if fmt == "sqtt-insts":
|
|
|
|
|
columns = ["PC", "Instruction", "Hits", "Cycles", "Stall", "Type"]
|
|
|
|
|
inst_columns = ["N", "Clk", "Idle", "Dur", "Stall"]
|
|
|
|
|
@@ -685,11 +694,11 @@ def get_render(query:str) -> dict:
|
|
|
|
|
prev_instr = max(prev_instr, e.time + e.dur)
|
|
|
|
|
summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu},
|
|
|
|
|
{"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}]
|
|
|
|
|
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)}
|
|
|
|
|
return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":viz_data.ref_map.get(data["prg"].name)}
|
|
|
|
|
if fmt == "prg-pma-pkts":
|
|
|
|
|
ret = {}
|
|
|
|
|
with soft_err(lambda err:ret.update(err)):
|
|
|
|
|
if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"}
|
|
|
|
|
if (events:=get_profile(pma_timeline(*data), row_tuple, data=viz_data)): ret = {"value":events, "content_type":"application/octet-stream"}
|
|
|
|
|
else: ret = {"src":"No PMA samples found."}
|
|
|
|
|
return ret
|
|
|
|
|
return data
|
|
|
|
|
@@ -712,11 +721,11 @@ class Handler(HTTPRequestHandler):
|
|
|
|
|
except FileNotFoundError: status_code = 404
|
|
|
|
|
|
|
|
|
|
elif url.path == "/ctxs":
|
|
|
|
|
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs]
|
|
|
|
|
lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in data.ctxs]
|
|
|
|
|
ret, content_type = json.dumps(lst).encode(), "application/json"
|
|
|
|
|
elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
|
|
|
|
|
else:
|
|
|
|
|
if not (render_src:=get_render(self.path)): status_code = 404
|
|
|
|
|
if not (render_src:=get_render(data, self.path)): status_code = 404
|
|
|
|
|
else:
|
|
|
|
|
if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"]
|
|
|
|
|
else: ret, content_type = json.dumps(render_src).encode(), "application/json"
|
|
|
|
|
@@ -754,8 +763,9 @@ if __name__ == "__main__":
|
|
|
|
|
st = time.perf_counter()
|
|
|
|
|
print("*** viz is starting")
|
|
|
|
|
|
|
|
|
|
ctxs = get_rewrites(trace:=load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
|
|
|
|
|
profile_ret = get_profile(load_pickle(args.profile_path, default=[]))
|
|
|
|
|
data = VizData(load_pickle(args.rewrites_path, default=RewriteTrace([], [], {})))
|
|
|
|
|
load_rewrites(data)
|
|
|
|
|
profile_ret = get_profile(load_pickle(args.profile_path, default=[]), data=data)
|
|
|
|
|
|
|
|
|
|
server = TCPServerWithReuse(('', PORT), Handler)
|
|
|
|
|
reloader_thread = threading.Thread(target=reloader)
|
|
|
|
|
|