diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index cba93c1c25..884c52f211 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -717,7 +717,7 @@ async function main() { const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, s.idx)).style("width", "24px").style("height", "100%"); return [s.label.trim(), div.node()]; })).node()); - } else root.appendChild(codeBlock(ret.src, ret.lang)); + } else root.appendChild(codeBlock(ret.src, ret.lang || "txt")); return document.querySelector("#custom").replaceChildren(root); } // ** UOp view (default) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 7bdef738e1..22a8e1e423 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct -import subprocess, ctypes, pathlib +import subprocess, ctypes, pathlib, traceback from contextlib import redirect_stdout from decimal import Decimal from http.server import BaseHTTPRequestHandler @@ -194,12 +194,21 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, return struct.pack(" None: - from extra.sqtt.roc import decode - rctx = decode(profile) - steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.hit, e.lat, e.stall, str(e.typ).split("_")[-1]) for e in x[1].values()], - "cols":["Instruction", "Hit Count", "Latency", "Stall", "Type"], "summary":[]}, - "query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.wave_events.items())] - if steps: ctxs.append({"name":"Counters", "steps":steps}) + from tinygrad.runtime.ops_amd import ProfileSQTTEvent + if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None + def err(name:str, msg:str|None=None) -> None: + step = {"name":name, "data":{"src":msg or traceback.format_exc()}, "depth":0, "query":f"/render?ctx={len(ctxs)}&step=0&fmt=counters"} + return ctxs.append({"name":"Counters", "steps":[step]}) + try: from extra.sqtt.roc import decode + except Exception: return err("DECODER IMPORT ISSUE") + try: + rctx = decode(profile) + steps = [{"name":x[0], "depth":0, "data":{"rows":[(e.inst, e.hit, e.lat, e.stall, str(e.typ).split("_")[-1]) for e in x[1].values()], + "cols":["Instruction", "Hit Count", "Latency", "Stall", "Type"], "summary":[]}, + "query":f"/render?ctx={len(ctxs)}&step={i}&fmt=counters"} for i,x in enumerate(rctx.wave_events.items())] + if not steps: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded") + except Exception: return err("DECODER ERROR") + ctxs.append({"name":"Counters", "steps":steps}) def get_profile(profile:list[ProfileEvent]) -> bytes|None: # start by getting the time diffs @@ -210,9 +219,7 @@ def get_profile(profile:list[ProfileEvent]) -> bytes|None: for device in device_ts_diffs: d = device.split(":")[0] if d == "AMD": device_decoders[d] = load_sqtt - for fxn in device_decoders.values(): - try: fxn(profile) - except Exception: continue + for fxn in device_decoders.values(): fxn(profile) # map events per device dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {} markers:list[ProfilePointEvent] = []