viz: refactor sqtt timeline builder (#15494)

* viz: refactor sqtt timeline builder

* barrier maps to waves

* clean up cli
This commit is contained in:
qazal
2026-03-26 14:16:15 +02:00
committed by GitHub
parent 313937ad6d
commit ec5b7a249e
2 changed files with 21 additions and 16 deletions

View File

@@ -90,7 +90,7 @@ def main():
pkt_idxs:dict[str, itertools.count] = {}
dispatch_to_pc:dict[str, int] = {}
for e in viz.sqtt_timeline(*sqtt_data):
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg; continue
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg
if not isinstance(e, ProfileRangeEvent): continue
op_name, info = e.name.display_name, e.name.ret or ""
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)

View File

@@ -344,29 +344,35 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}
row_ends:dict[str, Decimal] = {}
row_counts:dict[str, itertools.count] = {}
curr_barrier:dict[str, ProfileRangeEvent] = {}
curr_barrier:dict[int, ProfileRangeEvent] = {}
exec_pending:dict[str, list[str]] = {}
NS_PER_TICK = 10 # 100MHz
prev_pair:tuple[int, int]|None = None # (shader, realtime)
is_cdna = target.startswith("gfx9")
dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALU1":"VALU", "VALUT":"VALU", "VALUB":"VALU", "VALUINST":"VALU", "VINTERP":"VALU",
"SGMEM":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", "SMEM":"SALU", "VMEM":"VMEM"}
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]:
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
# default length is 1 cycle
duration = 1
# exec links to dispatch, dispatch links to PC
link = f"PC:{info.pc}" if info else None
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src):
link = f"LINK:{exec_pending[name].pop(0)}"
# queue inst dispatches
idx = next(row_counts.setdefault(row, itertools.count(0)))
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
# construct and yield the event for this packet
if row not in row_ends: yield ProfilePointEvent(row, "JSON", "pcMap", pc_map, ts=Decimal(0))
# barrier on this row extends to fill the time our wave was waiting
if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time)
e = ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1))
yield (e:=ProfileRangeEvent(row, TracingKey(op or name, ret=link), Decimal(p._time), Decimal(p._time+duration)))
# allow CDNA packets to overlap, NOT allowed on RDNA.
if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
row_ends[row] = unwrap(e.en)
idx = next(row_counts.setdefault(row, itertools.count(0)))
if name == "BARRIER": curr_barrier[row] = e
# queue for exec linking
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}")
yield e
# barrier on this wave extends to fill the time it was waiting
if wave is not None:
if (barrier:=curr_barrier.pop(wave, None)) is not None: barrier.en = Decimal(p._time)
if name == "BARRIER": curr_barrier[wave] = e
NS_PER_TICK = 10 # 100MHz
prev_pair:tuple[int, int]|None = None # (shader, realtime)
for p, info in map_insts(data, lib, target):
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
pair = (p._time, p.delta)
@@ -384,8 +390,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
if isinstance(p, WAVERDY):
for wave in range(16):
if p.mask & (1 << wave):
row = f"WAVE:{wave}"
if row in curr_barrier: yield from add("WAVERDY", p, wave=wave)
if wave in curr_barrier: yield from add("WAVERDY", p, wave=wave)
if isinstance(p, (VMEMEXEC, ALUEXEC)):
name = str(p.src).split('.')[1]
if name == "VALU_SALU":