From a3d516c4b5448263da128beba9f322630589fa51 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 18 Feb 2026 13:22:32 +0800 Subject: [PATCH] viz: start displaying pma (#14848) * viz: start displaying pma * s * work * colors * cleaner * max packets * fine * work * pma * diff cleanup --- tinygrad/viz/js/index.js | 6 ++++-- tinygrad/viz/serve.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 9bc7cbb92b..5cb3d56025 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -192,7 +192,9 @@ const waveColor = (op) => { const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]), DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"], BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]), - WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor} + GPC:new Map([["NONE","#1a7a2e"],["MEMORY_DEPENDENCY","#8b1a00"],["EXEC_DEPENDENCY","#006b6b"],["INST_FETCH","#7a7a00"],["SYNC","#6b006b"], + ["PIPE_BUSY","#7a4a00"],["MEMORY_THROTTLE","#5c0000"],["CONSTANT_MEMORY","#1a3d7a"],["NOT_SELECTED","#2e2e3a"],["OTHER","#4a4a55"], + ["SLEEPING","#1a1a2a"],["DEFAULT","#3a3a45"]]), WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor} const cycleColors = (lst, i) => lst[i%lst.length]; const rescaleTrack = (source, tid, k) => { @@ -826,7 +828,7 @@ async function main() { } // timeline with cycles on the x axis if (ret instanceof ArrayBuffer) { - opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), colorByName:step.name.includes("PKTS")}; + opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")}; return renderProfiler(ckey, "clk", opts); } metadata.innerHTML = ""; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 01228ce1cc..6231de862f 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -404,6 +404,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ if (d:=ev.device.split(":")[0]) == "AMD": device_decoders[d] = load_counters amdgpu_targets[d] = f"gfx{unwrap(ev.props)['gfx_target_version']//1000}" + if d == "NV": device_decoders[d] = load_pma_counters # load device specific counters for fxn in device_decoders.values(): fxn(profile) # map events per device @@ -431,6 +432,35 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ index = json.dumps({"strings":list(scache), "dtypeSize":dtype_size, "markers":[{"ts":rel_ts(e.ts, start_ts), **e.arg} for e in markers]}).encode() return struct.pack(" None: + steps:list[dict] = [] + sm_version = {e.device:e.props.get("sm_version", 0x800) for e in profile if isinstance(e, ProfileDeviceEvent) and e.props is not None} + run_number:dict[str, int] = {} + for e in profile: + if type(e).__name__ == "ProfilePMAEvent": + run_number[e.kern] = run_num = run_number.get(e.kern, 0)+1 + steps.append(create_step(f"PMA {e.kern}"+(f"n{run_num}" if run_num>1 else ""), ("/prg-pma-pkts", len(ctxs), len(steps)), + data=(e.blob, sm_version[e.device]))) + if steps: ctxs.append({"name":"All Counters", "steps":steps}) + +def pma_timeline(blob:bytes, sm_version:int) -> list[ProfileEvent]: + from extra.nv_pma.decode import decode, decode_tpc_id + ret:list[ProfileEvent] = [] + rows:dict[str, None] = {} + tpc_count:dict[int, int] = {} + # assume every sample is 32 cycles + cycles_per_sample = 32 + for s, tpc_id in decode(blob, sm_version): + if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break + gpc, tpc, sm = decode_tpc_id(tpc_id) + tpc_count[tpc_id] = (n:=tpc_count.get(tpc_id,0)) + 1 + rows.setdefault(row:=f"GPC:{gpc} TPC:{tpc} SM:{sm} WAVE:{s.wave_id}") + ret.append(ProfileRangeEvent(row, TracingKey(s.stall_reason.name, ret=f"pc=0x{s.pc_offset:06x} active={s.active}"), + Decimal(n*cycles_per_sample), Decimal((n+1)*cycles_per_sample))) + return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret + # ** Assembly static analyzers def get_stdout(f: Callable) -> str: @@ -585,6 +615,12 @@ def get_render(query:str) -> dict: summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu}, {"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}] return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "ref":ref_map.get(data["prg"].name)} + if fmt == "prg-pma-pkts": + ret = {} + with soft_err(lambda err:ret.update(err)): + if (events:=get_profile(pma_timeline(*data), sort_fn=row_tuple)): ret = {"value":events, "content_type":"application/octet-stream"} + else: ret = {"src":"No PMA samples found."} + return ret return data # ** HTTP server