mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -94,6 +94,8 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
# sopk/immediates don't get ALU/MEM EXEC
|
||||
if e.name.display_name not in {"IMMEDIATE", "IMMEDIATE_MASK", "JUMP", "JUMP_NO", "MESSAGE", "BARRIER", "BARRIER_SIGNAL",
|
||||
"WAVEEND", "WAVERDY"}: insts += 1
|
||||
# OTHER_ is its own stream, it's the INST from other SIMDs that share the same EXEC.
|
||||
elif e.device.startswith("OTHER"): continue
|
||||
else: raise Exception(f"timeline row must be INST or EXEC, got {e.device}")
|
||||
self.assertEqual(execs, insts)
|
||||
|
||||
|
||||
@@ -657,8 +657,6 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
||||
elif isinstance(p, WAVEEND):
|
||||
pc = wave_pc.pop(p.wave)
|
||||
yield (p, InstructionInfo(pc, p.wave, s_endpgm()))
|
||||
# skip OTHER_ instructions, they don't belong to this unit
|
||||
elif isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): pass
|
||||
elif isinstance(p, IMMEDIATE_MASK):
|
||||
# immediate mask may yield multiple times per packet
|
||||
for wave in range(16):
|
||||
@@ -668,7 +666,8 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
||||
assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
|
||||
wave_pc[wave] += inst.size()
|
||||
yield (p, InstructionInfo(pc, wave, inst))
|
||||
elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
|
||||
# map INST events on this SIMD to the program counter, we know the waves
|
||||
elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)) and not (isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_")):
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# s_delay_alu, s_wait_alu and s_barrier_wait instructions are skipped
|
||||
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU", "S_BARRIER_WAIT"}:
|
||||
@@ -684,7 +683,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
||||
else:
|
||||
wave_pc[p.wave] += inst.size()
|
||||
yield (p, InstructionInfo(pc, p.wave, inst))
|
||||
# for all other packets (VMEMEXEC, ALUEXEC, etc.), yield with None
|
||||
# for all other packets (VMEMEXEC, ALUEXEC, OTHER_ INST, etc.), yield with None
|
||||
else: yield (p, None)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@@ -349,21 +349,23 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
||||
is_cdna = target.startswith("gfx9")
|
||||
dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALU1":"VALU", "VALUT":"VALU", "VALUB":"VALU", "VALUINST":"VALU", "VINTERP":"VALU",
|
||||
"SGMEM":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", "SMEM":"SALU", "VMEM":"VMEM"}
|
||||
def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]:
|
||||
row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}"
|
||||
def add(name:str, p:PacketType, wave:int|None=None, info:InstructionInfo|None=None) -> Generator[ProfileEvent, None, None]:
|
||||
row = "OTHER" if name.startswith("OTHER_") else f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None \
|
||||
else f"{p.__class__.__name__}:0 {name.replace('_ALT', '')}"
|
||||
# default length is 1 cycle
|
||||
duration = 1
|
||||
# exec links to dispatch, dispatch links to PC
|
||||
link = f"PC:{info.pc}" if info else None
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src):
|
||||
if isinstance(p, (ALUEXEC, VMEMEXEC)):
|
||||
link = f"LINK:{exec_pending[name].pop(0)}"
|
||||
# queue inst dispatches
|
||||
idx = next(row_counts.setdefault(row, itertools.count(0)))
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None:
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.replace("OTHER_", "").split("_")[0])) is not None:
|
||||
if name.startswith("OTHER_"): exec_type = f"{exec_type}_ALT"
|
||||
exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}")
|
||||
# construct and yield the event for this packet
|
||||
if row not in row_ends: yield ProfilePointEvent(row, "JSON", "pcMap", pc_map, ts=Decimal(0))
|
||||
yield (e:=ProfileRangeEvent(row, TracingKey(op or name, ret=link), Decimal(p._time), Decimal(p._time+duration)))
|
||||
yield (e:=ProfileRangeEvent(row, TracingKey(name, ret=link), Decimal(p._time), Decimal(p._time+duration)))
|
||||
# allow CDNA packets to overlap, NOT allowed on RDNA.
|
||||
if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.")
|
||||
row_ends[row] = unwrap(e.en)
|
||||
@@ -397,7 +399,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> Generator[ProfileEvent,
|
||||
yield from add("VALU", p)
|
||||
yield from add("SALU", p)
|
||||
else:
|
||||
yield from add(name.replace("_ALT", ""), p, op=name)
|
||||
yield from add(name, p)
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
|
||||
Reference in New Issue
Block a user