mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: show ALT/OTHER packets on second lane (#14192)
* viz: show dimmer ALT/OTHER packets * remove todo comment * work * current vmem is gray
This commit is contained in:
@@ -170,19 +170,18 @@ function formatMicroseconds(ts, showUs=true) {
|
||||
const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
|
||||
|
||||
const WAVE_COLORS = {VALU:"#ffffc0", SALU:"#cef263", LOAD:"#ffc0c0", STORE:"#4fa3cc", IMMEDIATE:"#f3b44a", BARRIER:"#d00000", JUMP:"#ffb703",
|
||||
JUMP_NO:"#fb8500", MESSAGE:"#90dbf4"};
|
||||
JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9"};
|
||||
const waveColor = (op) => {
|
||||
const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU"
|
||||
const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU" : op.includes("VMEM") ? "VMEM"
|
||||
: op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op;
|
||||
ret = WAVE_COLORS[cat] ?? "#ffffff";
|
||||
// TODO: OTHER packets need to go on the second row
|
||||
if (op.includes("OTHER_")) { ret = darkenHex(ret, 75) }
|
||||
if (op.includes("OTHER_") || op.includes("_ALT")) { ret = darkenHex(ret, 75) }
|
||||
return ret
|
||||
};
|
||||
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
|
||||
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
|
||||
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
|
||||
WAVE:waveColor, VMEMEXEC:["#f4978e"], ALUEXEC:["#f72585"]}
|
||||
WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor}
|
||||
const cycleColors = (lst, i) => lst[i%lst.length];
|
||||
|
||||
const rescaleTrack = (source, tid, k) => {
|
||||
@@ -800,7 +799,7 @@ async function main() {
|
||||
}
|
||||
// timeline with cycles on the x axis
|
||||
if (ret instanceof ArrayBuffer) {
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), stepColors:!step.name.includes("Packets")};
|
||||
opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), stepColors:!step.name.includes("PKTS")};
|
||||
return renderProfiler(ckey, "clk", opts);
|
||||
}
|
||||
metadata.innerHTML = "";
|
||||
|
||||
@@ -283,21 +283,17 @@ def sqtt_timeline(e) -> list[ProfileEvent]:
|
||||
from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC
|
||||
ret:list[ProfileEvent] = []
|
||||
rows:dict[str, None] = {}
|
||||
def add(name:str, p:PacketType, idx=0, width=1) -> None:
|
||||
def add(name:str, p:PacketType, idx=0, width=1, op_name=None) -> None:
|
||||
rows.setdefault(r:=(f"WAVE:{p.wave}" if hasattr(p, "wave") else f"{p.__class__.__name__}:0 {name}"))
|
||||
ret.append(ProfileRangeEvent(r, f"{name} OP:{idx}", Decimal(p._time), Decimal(p._time+width)))
|
||||
ret.append(ProfileRangeEvent(r, f"{op_name if op_name is not None else name} OP:{idx}", Decimal(p._time), Decimal(p._time+width)))
|
||||
for p in decode(e.blob):
|
||||
if len(ret) > 50_000: break
|
||||
if isinstance(p, INST):
|
||||
op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}"
|
||||
# skip OTHER_* packets
|
||||
if "OTHER" in op_name: continue
|
||||
name, width = (op_name, 10 if "BARRIER" in op_name else 1)
|
||||
# Wave SALU op, not to be confused with the global SALU
|
||||
if op_name == "SALU": name = "WAVE_SALU"
|
||||
add(name, p, width=width)
|
||||
add(name, p, width=width, idx=int("OTHER" in name))
|
||||
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)): add(str(p.src).split('.')[1], p)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)): add((name:=str(p.src).split('.')[1]).replace("_ALT", ""), p, idx=int("ALT" in name), op_name=name)
|
||||
return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
Reference in New Issue
Block a user