diff --git a/extra/assembly/amd/sqttmap.py b/extra/assembly/amd/sqttmap.py new file mode 100644 index 0000000000..d4942de191 --- /dev/null +++ b/extra/assembly/amd/sqttmap.py @@ -0,0 +1,134 @@ +# maps SQTT trace packets to instructions. +from dataclasses import dataclass +from typing import Iterator + +from tinygrad.runtime.support.elf import elf_loader +from tinygrad.helpers import DEBUG, colored + +from extra.assembly.amd.sqtt import decode, print_packets, INST, VALUINST, IMMEDIATE, WAVESTART, WAVEEND, InstOp, PacketType, IMMEDIATE_MASK +from extra.assembly.amd.dsl import Inst +from extra.assembly.amd.decode import decode_inst +from extra.assembly.amd.autogen.rdna3.ins import SOPP, s_endpgm +from extra.assembly.amd.autogen.rdna3.enum import SOPPOp + +@dataclass(frozen=True) +class InstructionInfo: + pc: int + wave: int + inst: Inst + +def map_insts(data:bytes, lib:bytes) -> Iterator[tuple[PacketType, InstructionInfo|None]]: + """maps SQTT packets to instructions, yields (packet, instruction_info or None)""" + # map pcs to insts + pc_map:dict[int, Inst] = {} + image, sections, _ = elf_loader(lib) + text = next((sh for sh in sections if sh.name == ".text"), None) + assert text is not None, "no .text section found" + text_off, text_size = text.header.sh_addr, text.header.sh_size + offset = text_off + while offset < text_off + text_size: + inst = decode_inst(image[offset:]) + pc_map[offset-text_off] = inst + offset += inst.size() + + wave_pc:dict[int, int] = {} + # only processing packets on one [CU, SIMD] unit + def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0 + for p in decode(data): + if not simd_select(p): continue + if DEBUG >= 2: print_packets([p]) + if isinstance(p, WAVESTART): + assert p.wave not in wave_pc, "only one inflight wave per unit" + wave_pc[p.wave] = 0 + continue + if isinstance(p, WAVEEND): + pc = wave_pc.pop(p.wave) + yield (p, InstructionInfo(pc, p.wave, s_endpgm())) + continue + # skip OTHER_ instructions, they don't belong to this unit + if isinstance(p, INST) and p.op.name.startswith("OTHER_"): continue + if isinstance(p, IMMEDIATE_MASK): + # immediate mask may yield multiple times per packet + for wave in range(16): + if p.mask & (1 << wave): + inst = pc_map[pc:=wave_pc[wave]] + # can this assert be more strict? + assert isinstance(inst, SOPP), f"IMMEDIATE_MASK packet must map to SOPP, got {inst}" + wave_pc[wave] += inst.size() + yield (p, InstructionInfo(pc, wave, inst)) + continue + if isinstance(p, (VALUINST, INST, IMMEDIATE)): + inst = pc_map[pc:=wave_pc[p.wave]] + # s_delay_alu doesn't get a packet? + if isinstance(inst, SOPP) and inst.op in {SOPPOp.S_DELAY_ALU}: + wave_pc[p.wave] += inst.size() + if DEBUG >= 2: print(f"{' '*29}{colored(inst.disasm(), 'BLACK')}") + inst = pc_map[pc:=wave_pc[p.wave]] + # identify a branch instruction, only used for asserts + is_branch = isinstance(inst, SOPP) and "BRANCH" in inst.op_name + if is_branch: assert isinstance(p, INST) and p.op in {InstOp.JUMP_NO, InstOp.JUMP}, f"branch can only be folowed by jump packets, got {p}" + # JUMP handling + if isinstance(p, INST) and p.op is InstOp.JUMP: + assert is_branch, f"JUMP packet must map to a branch instruction, got {inst}" + x = inst.simm16 & 0xffff + wave_pc[p.wave] += inst.size() + (x - 0x10000 if x & 0x8000 else x)*4 + else: + if is_branch: assert inst.op != SOPPOp.S_BRANCH, f"S_BRANCH must have a JUMP packet, got {p}" + wave_pc[p.wave] += inst.size() + if DEBUG >= 2: print(f"{' '*29}{colored(inst.disasm(), 'WHITE')}") + yield (p, InstructionInfo(pc, p.wave, inst)) + continue + # for all other packets (VMEMEXEC, ALUEXEC, etc.), yield with None + yield (p, None) + +# test to compare every packet with the rocprof decoder + +def test_rocprof_inst_traces_match(sqtt, prg, target): + from tinygrad.viz.serve import llvm_disasm + from extra.sqtt.roc import decode as roc_decode, InstExec + disasm = {addr+prg.base:inst_disasm for addr, inst_disasm in llvm_disasm(target, prg.lib).items()} + rctx = roc_decode([sqtt], {prg.name:disasm}) + rwaves = rctx.inst_execs[(sqtt.kern, sqtt.exec_tag)] + rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit + for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts()) + rwaves_base = next(iter(disasm)) # base program counter + + passed_insts = 0 + for pkt, info in map_insts(sqtt.blob, prg.lib): + if info is None: continue + rocprof_inst = next(rwaves_iter[info.wave][0]) + ref_pc = rocprof_inst.pc-rwaves_base + # always check pc matches + assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm[rocprof_inst.pc][0]} != {info.pc}:{info.inst.disasm()}" + # special handling for s_endpgm, it marks the wave completion. + if info.inst == s_endpgm(): + completed_wave = list(rwaves_iter[info.wave].pop(0)) + assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}" + # otherwise the packet timestamp is time + "stall" + else: + assert pkt._time == rocprof_inst.time+rocprof_inst.stall + passed_insts += 1 + + for k,v in rwaves_iter.items(): + assert len(v) == 0, f"incomplete wave {k}" + + print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units") + +if __name__ == "__main__": + import argparse, pickle, pathlib + from tinygrad.helpers import temp + parser = argparse.ArgumentParser() + parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)', + default=pathlib.Path(temp("profile.pkl", append_user=True))) + parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)') + args = parser.parse_args() + with open(args.profile, "rb") as f: + data = pickle.load(f) + sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"] + kern_events = {e.name:e for e in data if type(e).__name__ == "ProfileProgramEvent"} + target = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"))).props["gfx_target_version"] + for e in sqtt_events: + if args.kernel is not None and args.kernel != e.kern: continue + if not e.itrace: continue + print(f"==== {e.kern}") + test_rocprof_inst_traces_match(e, kern_events[e.kern], target) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index e4f7bea2c2..fbb9cbd1e3 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -175,6 +175,7 @@ def timeline_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts: elif isinstance(e.name, TracingKey): name = e.name.display_name ref = next((v for k in e.name.keys if (v:=ref_map.get(k)) is not None), None) + if isinstance(e.name.ret, str): fmt.append(e.name.ret) events.append(struct.pack(" None: # to decode a SQTT trace, we need the raw stream, program binary and device properties if (sqtt:=v.get(ProfileSQTTEvent)): for e in sqtt: - if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=e)) + if e.itrace: steps.append(create_step(f"PKTS SE:{e.se}", (f"/prg-pkts-{e.se}", len(ctxs), len(steps)), data=(e, prg_events[k]))) steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k]))) ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) -def sqtt_timeline(e) -> list[ProfileEvent]: - from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC +def sqtt_timeline(data) -> list[ProfileEvent]: + from extra.assembly.amd.sqttmap import map_insts, InstructionInfo + from extra.assembly.amd.sqtt import PacketType, INST, InstOp, VALUINST, IMMEDIATE, IMMEDIATE_MASK, VMEMEXEC, ALUEXEC + e, prg = data ret:list[ProfileEvent] = [] rows:dict[str, None] = {} trace:dict[str, set[int]] = {} - def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None) -> None: + def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None, info:InstructionInfo|None=None) -> None: if hasattr(p, "wave"): wave = p.wave rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}")) - ret.append(ProfileRangeEvent(r, f"{op_name if op_name is not None else name} OP:{idx}", Decimal(p._time), Decimal(p._time+width))) - for p in decode(e.blob): + key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=info.inst.disasm() if info is not None else None) + ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width))) + for p, info in map_insts(e.blob, prg.lib): if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break if isinstance(p, INST): op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}" name, width = (op_name, 10 if "BARRIER" in op_name else 1) - add(name, p, width=width, idx=int("OTHER" in name)) - if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p) - if isinstance(p, IMMEDIATE_MASK): - for wave in range(16): - if p.mask & (1 << wave): add("IMMEDIATE", p, wave=wave) + add(name, p, width=width, idx=int("OTHER" in name), info=info) + if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p, info=info) + if isinstance(p, IMMEDIATE_MASK): add("IMMEDIATE", p, wave=unwrap(info.wave), info=info) if isinstance(p, (VMEMEXEC, ALUEXEC)): name = str(p.src).split('.')[1] if name == "VALU_SALU":