mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: refactor sqtt timeline builder (#15494)
* viz: refactor sqtt timeline builder
* barrier maps to waves
* clean up cli
This commit is contained in:
@@ -90,7 +90,7 @@ def main():
|
||||
pkt_idxs:dict[str, itertools.count] = {}
|
||||
dispatch_to_pc:dict[str, int] = {}
|
||||
for e in viz.sqtt_timeline(*sqtt_data):
|
||||
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg; continue
|
||||
if isinstance(e, ProfilePointEvent) and e.key == 'pcMap': pc_map = e.arg
|
||||
if not isinstance(e, ProfileRangeEvent): continue
|
||||
op_name, info = e.name.display_name, e.name.ret or ""
|
||||
color = next((c for p, c in WAVE_COLORS if any(x in op_name for x in p)), None)
|
||||
|
||||
@@ -344,29 +344,35 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
||||
pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}
|
||||
row_ends:dict[str, Decimal] = {}
|
||||
row_counts:dict[str, itertools.count] = {}
|
||||
curr_barrier:dict[str, ProfileRangeEvent] = {}
|
||||
curr_barrier:dict[int, ProfileRangeEvent] = {}
|
||||
exec_pending:dict[str, list[str]] = {}
|
||||
NS_PER_TICK = 10 # 100MHz
|
||||
prev_pair:tuple[int, int]|None = None # (shader, realtime)
|
||||
is_cdna = target.startswith("gfx9")
|
||||
dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALU1":"VALU", "VALUT":"VALU", "VALUB":"VALU", "VALUINST":"VALU", "VINTERP":"VALU",
|
||||
"SGMEM":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", "SMEM":"SALU", "VMEM":"VMEM"}
|
||||
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]:
|
||||
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
|
||||
# default length is 1 cycle
|
||||
duration = 1
|
||||
# exec links to dispatch, dispatch links to PC
|
||||
link = f"PC:{info.pc}" if info else None
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src):
|
||||
link = f"LINK:{exec_pending[name].pop(0)}"
|
||||
# queue inst dispatches
|
||||
idx = next(row_counts.setdefault(row, itertools.count(0)))
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
|
||||
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
|
||||
# construct and yield the event for this packet
|
||||
if row not in row_ends: yield ProfilePointEvent(row, "JSON", "pcMap", pc_map, ts=Decimal(0))
|
||||
# barrier on this row extends to fill the time our wave was waiting
|
||||
if (barrier:=curr_barrier.pop(row, None)) is not None: barrier.en = Decimal(p._time)
|
||||
e = ProfileRangeEvent(row, TracingKey(op or name, ret=f"PC:{info.pc}" if info else None), Decimal(p._time), Decimal(p._time+1))
|
||||
yield (e:=ProfileRangeEvent(row, TracingKey(op or name, ret=link), Decimal(p._time), Decimal(p._time+duration)))
|
||||
# allow CDNA packets to overlap, NOT allowed on RDNA.
|
||||
if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
|
||||
row_ends[row] = unwrap(e.en)
|
||||
idx = next(row_counts.setdefault(row, itertools.count(0)))
|
||||
if name == "BARRIER": curr_barrier[row] = e
|
||||
# queue for exec linking
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
|
||||
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}")
|
||||
yield e
|
||||
# barrier on this wave extends to fill the time it was waiting
|
||||
if wave is not None:
|
||||
if (barrier:=curr_barrier.pop(wave, None)) is not None: barrier.en = Decimal(p._time)
|
||||
if name == "BARRIER": curr_barrier[wave] = e
|
||||
NS_PER_TICK = 10 # 100MHz
|
||||
prev_pair:tuple[int, int]|None = None # (shader, realtime)
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker:
|
||||
pair = (p._time, p.delta)
|
||||
@@ -384,8 +390,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
||||
if isinstance(p, WAVERDY):
|
||||
for wave in range(16):
|
||||
if p.mask & (1 << wave):
|
||||
row = f"WAVE:{wave}"
|
||||
if row in curr_barrier: yield from add("WAVERDY", p, wave=wave)
|
||||
if wave in curr_barrier: yield from add("WAVERDY", p, wave=wave)
|
||||
if isinstance(p, (VMEMEXEC, ALUEXEC)):
|
||||
name = str(p.src).split('.')[1]
|
||||
if name == "VALU_SALU":
|
||||
|
||||
Reference in New Issue
Block a user