diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 360bcf0409..c6173216a2 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -350,12 +350,11 @@ tr.nested-row table tr.main-row:hover { background-color: unset; } - tr.main-row.has-children > td:first-child { - white-space: pre; + tr.main-row.has-children > td:first-child > p { + display: inline-block; } tr.main-row.has-children > td:first-child::before { content: "▸ "; - display: inline-block; width: 1em; margin-left: -0.25em; } diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 275ccb0c37..9e6846307f 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -794,7 +794,7 @@ async function main() { } const td = tr.append("td").classed(ret.cols[i], true); // string format scalar values - if (!Array.isArray(value)) { td.text(typeof value === "string" ? value : ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)); continue; } + if (!Array.isArray(value)) { td.append(() => typeof value === "string" ? colored(value) : d3.create("p").text(ret.cols[i] === "Duration" ? formatMicroseconds(value) : formatUnit(value)).node()); continue; } // display arrays in a bar graph td.classed("pct-row", true); const bar = td.append("div"); @@ -804,9 +804,8 @@ async function main() { } return table; } - if (ret.cols != null) { - renderTable(root, ret); - } else root.append(() => codeBlock(ret.src, ret.lang)); + if (ret.cols != null) renderTable(root, ret); + else if (ret.src != null) root.append(() => codeBlock(ret.src, ret.lang)); ret.metadata?.forEach(m => { if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value, idx }) => { const div = d3.create("div").style("background", cycleColors(colorScheme.CATEGORICAL, idx)).style("width", "100%").style("height", "100%"); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index b98b5d939b..700de60a10 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -248,23 +248,25 @@ def load_counters(profile:list[ProfileEvent]) -> None: durations.setdefault(str(e.name), []).append(float(e.en-e.st)) if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e - ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), \ - (durations, {k:v[ProfilePMCEvent][0] for k,v in counter_events.items()}))]}) + if len(counter_events) == 0: return None + ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]}) run_number = {n:0 for n,_ in counter_events} - for k,v in counter_events.items(): - prg = trace.keys[r].ret if (r:=ref_map.get(k[0])) else None - name = prg.name if prg is not None else k[0] - run_number[k[0]] += 1 + for (k, tag),v in counter_events.items(): + # use the colored name if it exists + name = trace.keys[r].ret.name if (r:=ref_map.get(k)) is not None else k + run_number[k] += 1 steps:list[dict] = [] - if (pmc:=v.get(ProfilePMCEvent)): steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc)) + if (pmc:=v.get(ProfilePMCEvent)): + steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc)) + all_counters[(name, run_number[k], k)] = pmc[0] if (sqtt:=v.get(ProfileSQTTEvent)): # to decode a SQTT trace, we need the raw stream, program binary and device properties - steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), (k, [*sqtt, prg_events[k[0]], dev_events[sqtt[0].device]]))) + steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]]))) if getenv("SQTT_PARSE"): # run our decoder on startup, we don't use this since it only works on gfx11 from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets for e in sqtt: parse_sqtt_print_packets(e.blob) - ctxs.append({"name":f"Exec {name} n{run_number[k[0]]}", "steps":steps}) + ctxs.append({"name":f"Exec {name} n{run_number[k]}", "steps":steps}) # ** SQTT OCC only unpacks wave start, end time and SIMD location @@ -398,10 +400,10 @@ def get_render(i:int, j:int, fmt:str) -> dict: if fmt == "all-pmc": durations, pmc = data ret:dict = {"cols":{}, "rows":[]} - for (prg,_),events in pmc.items(): + for (name, n, k),events in data[1].items(): pmc_table = unpack_pmc(events) ret["cols"].update([(r[0], None) for r in pmc_table["rows"]]) - ret["rows"].append((prg, durations[prg].pop(0), *[r[1] for r in pmc_table["rows"]])) + ret["rows"].append((name, durations[k][n-1], *[r[1] for r in pmc_table["rows"]])) ret["cols"] = ["Kernel", "Duration", *ret["cols"]] return ret if fmt == "prg-pmc": return unpack_pmc(data[0])