show sqtt decoder errs in viz (#13088)

* show sqtt decoder errs in viz

* don't touch roc.py

* give hljs a default language

* work from tinyr9

* work
This commit is contained in:
qazal
2025-11-04 22:05:06 +08:00
committed by GitHub
parent 49191ada77
commit 96417665e8
2 changed files with 18 additions and 11 deletions

View File

@@ -717,7 +717,7 @@ async function main() {
const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, s.idx)).style("width", "24px").style("height", "100%");
return [s.label.trim(), div.node()];
})).node());
} else root.appendChild(codeBlock(ret.src, ret.lang));
} else root.appendChild(codeBlock(ret.src, ret.lang || "txt"));
return document.querySelector("#custom").replaceChildren(root);
}
// ** UOp view (default)

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
import subprocess, ctypes, pathlib
import subprocess, ctypes, pathlib, traceback
from contextlib import redirect_stdout
from decimal import Decimal
from http.server import BaseHTTPRequestHandler
@@ -194,12 +194,21 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int,
return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
def load_sqtt(profile:list[ProfileEvent]) -> None:
from extra.sqtt.roc import decode
rctx = decode(profile)
steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.hit, e.lat, e.stall, str(e.typ).split("_")[-1]) for e in x[1].values()],
"cols":["Instruction", "Hit Count", "Latency", "Stall", "Type"], "summary":[]},
"query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.wave_events.items())]
if steps: ctxs.append({"name":"Counters", "steps":steps})
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None
def err(name:str, msg:str|None=None) -> None:
step = {"name":name, "data":{"src":msg or traceback.format_exc()}, "depth":0, "query":f"/render?ctx={len(ctxs)}&step=0&fmt=counters"}
return ctxs.append({"name":"Counters", "steps":[step]})
try: from extra.sqtt.roc import decode
except Exception: return err("DECODER IMPORT ISSUE")
try:
rctx = decode(profile)
steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.hit, e.lat, e.stall, str(e.typ).split("_")[-1]) for e in x[1].values()],
"cols":["Instruction", "Hit Count", "Latency", "Stall", "Type"], "summary":[]},
"query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.wave_events.items())]
if not steps: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
except Exception: return err("DECODER ERROR")
ctxs.append({"name":"Counters", "steps":steps})
def get_profile(profile:list[ProfileEvent]) -> bytes|None:
# start by getting the time diffs
@@ -210,9 +219,7 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
for device in device_ts_diffs:
d = device.split(":")[0]
if d == "AMD": device_decoders[d] = load_sqtt
for fxn in device_decoders.values():
try: fxn(profile)
except Exception: continue
for fxn in device_decoders.values(): fxn(profile)
# map events per device
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
markers:list[ProfilePointEvent] = []