mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
show sqtt decoder errs in viz (#13088)
* show sqtt decoder errs in viz * don't touch roc.py * give hljs a default language * work from tinyr9 * work
This commit is contained in:
@@ -717,7 +717,7 @@ async function main() {
|
||||
const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, s.idx)).style("width", "24px").style("height", "100%");
|
||||
return [s.label.trim(), div.node()];
|
||||
})).node());
|
||||
} else root.appendChild(codeBlock(ret.src, ret.lang));
|
||||
} else root.appendChild(codeBlock(ret.src, ret.lang || "txt"));
|
||||
return document.querySelector("#custom").replaceChildren(root);
|
||||
}
|
||||
// ** UOp view (default)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
|
||||
import subprocess, ctypes, pathlib
|
||||
import subprocess, ctypes, pathlib, traceback
|
||||
from contextlib import redirect_stdout
|
||||
from decimal import Decimal
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
@@ -194,12 +194,21 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int,
|
||||
return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
|
||||
|
||||
def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
from extra.sqtt.roc import decode
|
||||
rctx = decode(profile)
|
||||
steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.hit, e.lat, e.stall, str(e.typ).split("_")[-1]) for e in x[1].values()],
|
||||
"cols":["Instruction", "Hit Count", "Latency", "Stall", "Type"], "summary":[]},
|
||||
"query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.wave_events.items())]
|
||||
if steps: ctxs.append({"name":"Counters", "steps":steps})
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None
|
||||
def err(name:str, msg:str|None=None) -> None:
|
||||
step = {"name":name, "data":{"src":msg or traceback.format_exc()}, "depth":0, "query":f"/render?ctx={len(ctxs)}&step=0&fmt=counters"}
|
||||
return ctxs.append({"name":"Counters", "steps":[step]})
|
||||
try: from extra.sqtt.roc import decode
|
||||
except Exception: return err("DECODER IMPORT ISSUE")
|
||||
try:
|
||||
rctx = decode(profile)
|
||||
steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.hit, e.lat, e.stall, str(e.typ).split("_")[-1]) for e in x[1].values()],
|
||||
"cols":["Instruction", "Hit Count", "Latency", "Stall", "Type"], "summary":[]},
|
||||
"query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.wave_events.items())]
|
||||
if not steps: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
|
||||
except Exception: return err("DECODER ERROR")
|
||||
ctxs.append({"name":"Counters", "steps":steps})
|
||||
|
||||
def get_profile(profile:list[ProfileEvent]) -> bytes|None:
|
||||
# start by getting the time diffs
|
||||
@@ -210,9 +219,7 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None:
|
||||
for device in device_ts_diffs:
|
||||
d = device.split(":")[0]
|
||||
if d == "AMD": device_decoders[d] = load_sqtt
|
||||
for fxn in device_decoders.values():
|
||||
try: fxn(profile)
|
||||
except Exception: continue
|
||||
for fxn in device_decoders.values(): fxn(profile)
|
||||
# map events per device
|
||||
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
|
||||
markers:list[ProfilePointEvent] = []
|
||||
|
||||
Reference in New Issue
Block a user