diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html
index 360bcf0409..c6173216a2 100644
--- a/tinygrad/viz/index.html
+++ b/tinygrad/viz/index.html
@@ -350,12 +350,11 @@
tr.nested-row table tr.main-row:hover {
background-color: unset;
}
- tr.main-row.has-children > td:first-child {
- white-space: pre;
+ tr.main-row.has-children > td:first-child > p {
+ display: inline-block;
}
tr.main-row.has-children > td:first-child::before {
content: "▸ ";
- display: inline-block;
width: 1em;
margin-left: -0.25em;
}
diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js
index 275ccb0c37..9e6846307f 100644
--- a/tinygrad/viz/js/index.js
+++ b/tinygrad/viz/js/index.js
@@ -794,7 +794,7 @@ async function main() {
}
const td = tr.append("td").classed(ret.cols[i], true);
// string format scalar values
- if (!Array.isArray(value)) { td.text(typeof value === "string" ? value : ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)); continue; }
+ if (!Array.isArray(value)) { td.append(() => typeof value === "string" ? colored(value) : d3.create("p").text(ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)).node()); continue; }
// display arrays in a bar graph
td.classed("pct-row", true);
const bar = td.append("div");
@@ -804,9 +804,8 @@ async function main() {
}
return table;
}
- if (ret.cols != null) {
- renderTable(root, ret);
- } else root.append(() => codeBlock(ret.src, ret.lang));
+ if (ret.cols != null) renderTable(root, ret);
+ else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang));
ret.metadata?.forEach(m => {
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value, idx }) => {
const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, idx)).style("width", "100%").style("height", "100%");
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index b98b5d939b..700de60a10 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -248,23 +248,25 @@ def load_counters(profile:list[ProfileEvent]) -> None:
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
- ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), \
- (durations, {k:v[ProfilePMCEvent][0] for k,v in counter_events.items()}))]})
+ if len(counter_events) == 0: return None
+ ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
run_number = {n:0 for n,_ in counter_events}
- for k,v in counter_events.items():
- prg = trace.keys[r].ret if (r:=ref_map.get(k[0])) else None
- name = prg.name if prg is not None else k[0]
- run_number[k[0]] += 1
+ for (k, tag),v in counter_events.items():
+ # use the colored name if it exists
+ name = trace.keys[r].ret.name if (r:=ref_map.get(k)) is not None else k
+ run_number[k] += 1
steps:list[dict] = []
- if (pmc:=v.get(ProfilePMCEvent)): steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
+ if (pmc:=v.get(ProfilePMCEvent)):
+ steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
+ all_counters[(name, run_number[k], k)] = pmc[0]
if (sqtt:=v.get(ProfileSQTTEvent)):
# to decode a SQTT trace, we need the raw stream, program binary and device properties
- steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), (k, [*sqtt, prg_events[k[0]], dev_events[sqtt[0].device]])))
+ steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]])))
if getenv("SQTT_PARSE"):
# run our decoder on startup, we don't use this since it only works on gfx11
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
for e in sqtt: parse_sqtt_print_packets(e.blob)
- ctxs.append({"name":f"Exec {name} n{run_number[k[0]]}", "steps":steps})
+ ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps})
# ** SQTT OCC only unpacks wave start, end time and SIMD location
@@ -398,10 +400,10 @@ def get_render(i:int, j:int, fmt:str) -> dict:
if fmt == "all-pmc":
durations, pmc = data
ret:dict = {"cols":{}, "rows":[]}
- for (prg,_),events in pmc.items():
+ for (name, n, k),events in data[1].items():
pmc_table = unpack_pmc(events)
ret["cols"].update([(r[0], None) for r in pmc_table["rows"]])
- ret["rows"].append((prg, durations[prg].pop(0), *[r[1] for r in pmc_table["rows"]]))
+ ret["rows"].append((name, durations[k][n-1], *[r[1] for r in pmc_table["rows"]]))
ret["cols"] = ["Kernel", "Duration", *ret["cols"]]
return ret
if fmt == "prg-pmc": return unpack_pmc(data[0])