From 3170365a5ba19c5f35718169dfc6a28aaf4cc452 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 6 Jan 2026 00:53:20 -0500 Subject: [PATCH] visualize SQTT with the same cfg infrastructure (#13870) * start * rough sketch * post render dag * art * intro g key * work * custom color scale * colors * more blue * better * smaller * use for loop in test --- test/testextra/test_cfg_viz.py | 16 ++++++++++++++++ tinygrad/viz/js/index.js | 9 +++++++-- tinygrad/viz/js/worker.js | 11 ++++++++--- tinygrad/viz/serve.py | 6 ++++-- 4 files changed, 35 insertions(+), 7 deletions(-) diff --git a/test/testextra/test_cfg_viz.py b/test/testextra/test_cfg_viz.py index 93b95a1b38..5fca94112f 100644 --- a/test/testextra/test_cfg_viz.py +++ b/test/testextra/test_cfg_viz.py @@ -178,5 +178,21 @@ class TestCfg(unittest.TestCase): s_endpgm(), ]) + def test_colored_blocks(self): + N = 10 + asm = ["entry:", s_branch("init0"),] + for i in range(N): + asm += [f"init{i}:", s_mov_b32(s[1], i + 1), s_branch(loop:=f"loop{i}")] + asm += [ + f"{loop}:", + s_nop(i & 7), + s_add_u32(s[1], s[1], -1), + s_cmp_eq_i32(s[1], 0), + s_cbranch_scc0(loop), + s_branch(f"init{i+1}" if i + 1 < N else "end"), + ] + asm += ["end:", s_endpgm()] + run_asm("test_colored_blocks", asm) + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 811650d5c2..b67e56726f 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -55,8 +55,11 @@ function addTags(root) { root.selectAll("text").data(d => [d]).join("text").text(d => d).attr("dy", "0.35em"); } +const colorScale = d3.scaleSequential(t => t > 0 ? d3.interpolateLab(colorScheme.ACTIVE[1], colorScheme.ACTIVE[2])(t) : colorScheme.ACTIVE[0]).clamp(true); + const drawGraph = (data) => { const g = dagre.graphlib.json.read(data); + if (data.value.colorDomain != null) colorScale.domain(data.value.colorDomain); // draw nodes d3.select("#graph-svg").on("click", () => d3.selectAll(".highlight").classed("highlight", false)); const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g").attr("class", d => d.className ?? "node") @@ -88,7 +91,7 @@ const drawGraph = (data) => { } return [ret]; }).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan") - .attr("fill", d => d.color).text(d => d.st).attr("xml:space", "preserve").style("font-family", g.graph().font); + .attr("fill", d => typeof d.color === "string" ? d.color : colorScale(d.color)).text(d => d.st).attr("xml:space", "preserve").style("font-family", g.graph().font); addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag") .attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag)); // draw edges @@ -154,7 +157,8 @@ const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit; const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]), DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"], - BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SE:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),} + BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SE:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]), + ACTIVE:["#565f89", "#c8d3f5", "#7aa2f7"]} const cycleColors = (lst, i) => lst[i%lst.length]; const rescaleTrack = (source, tid, k) => { @@ -811,6 +815,7 @@ async function main() { } return table; } + if (ret.data != null) renderDag(ret, { recenter:true }); if (ret.cols != null) renderTable(root, ret); else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang)); return document.querySelector("#custom").replaceChildren(root.node()); diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 2856b0ac2b..1d9e6b0a5b 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -13,21 +13,26 @@ onmessage = (e) => { self.close(); } -const layoutCfg = (g, { blocks, paths, pc_table, colors }) => { +const layoutCfg = (g, { blocks, paths, pc_table, counters, colors }) => { g.setGraph({ rankdir:"TD", font:"monospace" }); ctx.font = `350 ${LINE_HEIGHT}px ${g.graph().font}`; // basic blocks render the assembly in nodes + let maxColor = 0; for (const [lead, members] of Object.entries(blocks)) { let [width, height, label] = [0, 0, []]; for (const m of members) { const text = pc_table[m][0]; + if (counters != null) { + const num = counters[m]?.hit_count || 0; + if (num > maxColor) maxColor = num; + label.push([{st:text, color:num}]); + } else { const [inst, ...operands] = text.split(" "); label.push([{st:inst+" ", color:"#7aa2f7"}, {st:operands.join(" "), color:"#9aa5ce"}]); } width = Math.max(width, ctx.measureText(text).width); height += LINE_HEIGHT; - const [inst, ...operands] = text.split(" "); - label.push([{st:inst+" ", color:"#7aa2f7"}, {st:operands.join(" "), color:"#9aa5ce"}]); } g.setNode(lead, { ...rectDims(width, height), label, id:lead, color:"#1a1b26" }); } + g.graph().colorDomain = [0, maxColor]; // paths become edges between basic blocks for (const [lead, value] of Object.entries(paths)) { for (const [id, color] of Object.entries(value)) g.setEdge(lead, id, {label:{type:"port", text:""}, color:colors[color]}); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 6da4047eeb..a7d37f59f9 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -285,7 +285,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[ n = next(inst_units[u]) if (events:=cu_events.get(w.cu_loc)) is None: cu_events[w.cu_loc] = events = [] events.append(ProfileRangeEvent(w.simd_loc, loc:=f"INST WAVE:{w.wave_id} N:{n}", Decimal(w.begin_time), Decimal(w.end_time))) - wave_insts.setdefault(w.cu_loc, {})[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "run_number":n, "loc":loc} + wave_insts.setdefault(w.cu_loc, {})[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "prg":p, "run_number":n, "loc":loc} # * OCC waves units:dict[str, itertools.count] = {} wave_start:dict[str, int] = {} @@ -490,7 +490,9 @@ def get_render(query:str) -> dict: prev_instr = max(prev_instr, e.time + e.dur) summary = [{"label":"Total Cycles", "value":w.end_time-w.begin_time}, {"label":"SE", "value":w.se}, {"label":"CU", "value":w.cu}, {"label":"SIMD", "value":w.simd}, {"label":"Wave ID", "value":w.wave_id}, {"label":"Run number", "value":data["run_number"]}] - return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary]} + cfg = amdgpu_cfg((p:=data["prg"]).lib, device_props[p.device]["gfx_target_version"])["data"] + cfg["counters"] = {pc-p.base:v for pc,v in rows.items()} + return {"rows":[tuple(v.values()) for v in rows.values()], "cols":columns, "metadata":[summary], "data":cfg} return data # ** HTTP server