diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 415aaa742e..4a51ec787e 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -149,7 +149,7 @@ function renderDag(graph, additions, recenter, layoutOpts) { // ** profiler graph -function formatTime(ts, dur=ts) { +function formatMicroseconds(ts, dur=ts) { if (dur<=1e3) return `${ts.toFixed(2)}us`; if (dur<=1e6) return `${(ts*1e-3).toFixed(2)}ms`; return `${(ts*1e-6).toFixed(2)}s`; @@ -158,7 +158,7 @@ const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit; const colorScheme = {TINY:["#1b5745", "#354f52", "#354f52", "#1d2e62", "#63b0cd"], DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"], - BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], + BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:["#3600f0"], CATEGORICAL:["#ff8080", "#F4A261", "#C8F9D4", "#8D99AE", "#F4A261", "#ffffa2", "#ffffc0", "#87CEEB"],} const cycleColors = (lst, i) => lst[i%lst.length]; @@ -198,13 +198,15 @@ function focusShape(shape) { return metadata.replaceChildren(shapeMetadata.get(focusedShape) ?? ""); } -async function renderProfiler() { +async function renderProfiler(path, unit) { displaySelection("#profiler"); metadata.replaceChildren(shapeMetadata.get(focusedShape) ?? ""); // layout once! - if (data != null) return updateProgress({ start:false }); + if (data != null && data.path === path) return updateProgress({ start:false }); + // support non realtime x axis units + const formatTime = unit === "realtime" ? formatMicroseconds : (s) => `${s} ${unit}`; const profiler = d3.select("#profiler").html(""); - const buf = await (await fetch("/get_profile")).arrayBuffer(); + const buf = await (await fetch(path)).arrayBuffer(); const view = new DataView(buf); let offset = 0; const u8 = () => { const ret = view.getUint8(offset); offset += 1; return ret; } @@ -227,7 +229,7 @@ async function renderProfiler() { const colorMap = new Map(); // map shapes by event key const shapeMap = new Map(); - data = {tracks:new Map(), axes:{}}; + data = {tracks:new Map(), axes:{}, path}; const heightScale = d3.scaleLinear().domain([0, tracePeak]).range([4,maxheight=100]); for (let i=0; i render(e.transform)); d3.select(canvas).call(canvasZoom); document.addEventListener("contextmenu", e => e.ctrlKey && e.preventDefault()); @@ -691,13 +694,14 @@ async function main() { if (url.pathname+url.search !== ckey) e.close(); else if (e.readyState === EventSource.OPEN) activeSrc = e; } - if (ctx.name === "Profiler") return renderProfiler(); + if (ctx.name === "Profiler") return renderProfiler("/get_profile", "realtime"); if (workerUrl == null) await initWorker(); if (ckey in cache) { ret = cache[ckey]; } // ** Disassembly view if (ckey.startsWith("/render")) { + if (step.fmt === "timeline") return renderProfiler(ckey, "clk"); // cycles on the x axis if (!(ckey in cache)) cache[ckey] = ret = await (await fetch(ckey)).json(); displaySelection("#custom"); metadata.innerHTML = ""; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index fa6ba95683..ed1aed51f2 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -215,10 +215,12 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: except Exception: return err("DECODER ERROR") if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded") steps:list[dict] = [] + units:set[str] = set() for name,waves in rctx.inst_execs.items(): + events:list[ProfileEvent] = [] prg = trace.keys[r].ret if (r:=ref_map.get(name)) else None - steps.append({"name":prg.name if prg is not None else name, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters", - "depth":0, "data":{"src":prg.src if prg is not None else name, "lang":"cpp"}}) + steps.append(first:={"name":prg.name if prg is not None else name, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters", + "depth":0, "fmt":"timeline"}) # Idle: The total time gap between the completion of previous instruction and the beginning of the current instruction. # The idle time can be caused by: @@ -228,14 +230,18 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: # Stall: The total number of cycles the hardware pipe couldn't issue an instruction. # Duration: Total latency in cycles, defined as "Stall time + Issue time" for gfx9 or "Stall time + Execute time" for gfx10+. for w in waves: + units.add(row:=f"SIMD:{w.simd} CU:{w.cu} SE:{w.se}") + events.append(ProfileRangeEvent(row, wave_name:=f"wave {w.wave_id}", Decimal(w.begin_time), Decimal(w.end_time))) rows, prev_instr = [], w.begin_time for i,e in enumerate(w.insts): rows.append((e.inst, e.time, max(0, e.time-prev_instr), e.dur, e.stall, str(e.typ).split("_")[-1])) prev_instr = max(prev_instr, e.time + e.dur) summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SIMD", "value":w.simd}, {"label":"CU", "value":w.cu}, {"label":"SE", "value":w.se}] - steps.append({"name":f"Wave {w.wave_id}", "depth":1, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters", + steps.append({"name":wave_name, "depth":1, "query":f"/render?ctx={len(ctxs)}&step={len(steps)}&fmt=counters", "data":{"rows":rows, "cols":["Instruction", "Clk", "Idle", "Duration", "Stall", "Type"], "summary":summary}}) + events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events + first["data"] = {"value":get_profile(events), "content_type":"application/octet-stream"} ctxs.append({"name":"Counters", "steps":steps}) def get_profile(profile:list[ProfileEvent]) -> bytes|None: @@ -302,9 +308,9 @@ def get_stdout(f: Callable) -> str: except Exception: traceback.print_exc(file=buf) return buf.getvalue() -def get_render(i:int, j:int, fmt:str) -> dict|None: +def get_render(i:int, j:int, fmt:str) -> dict: if fmt == "counters": return ctxs[i]["steps"][j]["data"] - if not isinstance(prg:=trace.keys[i].ret, ProgramSpec): return None + if not isinstance(prg:=trace.keys[i].ret, ProgramSpec): return {} if fmt == "uops": return {"src":get_stdout(lambda: print_uops(prg.uops or [])), "lang":"txt"} if fmt == "src": return {"src":prg.src, "lang":"cpp"} compiler = Device[prg.device].compiler @@ -336,11 +342,14 @@ class Handler(BaseHTTPRequestHandler): elif (query:=parse_qs(url.query)): if url.path == "/render": render_src = get_render(get_int(query, "ctx"), get_int(query, "step"), query["fmt"][0]) - ret, content_type = json.dumps(render_src).encode(), "application/json" + if "content_type" in render_src: ret, content_type = render_src["value"], render_src["content_type"] + else: ret, content_type = json.dumps(render_src).encode(), "application/json" else: try: return self.stream_json(get_full_rewrite(trace.rewrites[i:=get_int(query, "ctx")][get_int(query, "idx")], i)) except (KeyError, IndexError): status_code = 404 - elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json" + elif url.path == "/ctxs": + lst = [{**c, "steps":[{k:v for k, v in s.items() if k != "data"} for s in c["steps"]]} for c in ctxs] + ret, content_type = json.dumps(lst).encode(), "application/json" elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream" else: status_code = 404