mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: put alts in the same row, LDS color (#14194)
* viz: put alts in the same row, coloring work * assert if packets overlap * lds color
This commit is contained in:
@@ -180,7 +180,7 @@ function formatCycles(cycles) {
|
||||
const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
|
||||
|
||||
const WAVE_COLORS = {VALU:"#ffffc0", SALU:"#cef263", LOAD:"#ffc0c0", STORE:"#4fa3cc", IMMEDIATE:"#f3b44a", BARRIER:"#d00000", JUMP:"#ffb703",
|
||||
JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9"};
|
||||
JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9", LDS:"#9fb4a6"};
|
||||
const waveColor = (op) => {
|
||||
const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU" : op.includes("VMEM") ? "VMEM"
|
||||
: op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op;
|
||||
|
||||
@@ -283,6 +283,7 @@ def sqtt_timeline(e) -> list[ProfileEvent]:
|
||||
from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC
|
||||
ret:list[ProfileEvent] = []
|
||||
rows:dict[str, None] = {}
|
||||
trace:dict[str, set[int]] = {}
|
||||
def add(name:str, p:PacketType, idx=0, width=1, op_name=None) -> None:
|
||||
rows.setdefault(r:=(f"WAVE:{p.wave}" if hasattr(p, "wave") else f"{p.__class__.__name__}:0 {name}"))
|
||||
ret.append(ProfileRangeEvent(r, f"{op_name if op_name is not None else name} OP:{idx}", Decimal(p._time), Decimal(p._time+width)))
|
||||
@@ -293,7 +294,10 @@ def sqtt_timeline(e) -> list[ProfileEvent]:
|
||||
name, width = (op_name, 10 if "BARRIER" in op_name else 1)
|
||||
add(name, p, width=width, idx=int("OTHER" in name))
|
||||
if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)): add((name:=str(p.src).split('.')[1]).replace("_ALT", ""), p, idx=int("ALT" in name), op_name=name)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)):
|
||||
add((name:=str(p.src).split('.')[1]).replace("_ALT", ""), p, op_name=name)
|
||||
if p._time in trace.setdefault(name, set()): raise AssertionError(f"packets overlap in shared resource! {name}")
|
||||
trace[name].add(p._time)
|
||||
return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
Reference in New Issue
Block a user