From befc1e800c8faefe783ee78ece235206a90c595f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 12 Feb 2026 12:33:46 +0800 Subject: [PATCH] assembly/amd: disasm is test only (#14694) * assembly/amd: disasm is test only * viz uses str --- extra/assembly/amd/dsl.py | 4 +- extra/assembly/amd/sqttmap.py | 54 ------------- extra/assembly/amd/{ => test}/disasm.py | 0 extra/assembly/amd/test/test_custom_kernel.py | 3 +- extra/assembly/amd/test/test_handwritten.py | 7 +- extra/assembly/amd/test/test_llvm.py | 2 +- extra/assembly/amd/test/test_rdna3_asm.py | 3 +- extra/assembly/amd/test/test_roundtrip.py | 3 +- extra/assembly/amd/test/test_sqttmap.py | 79 +++++++++++++++++++ extra/gemm/amd_asm_matmul.py | 3 +- extra/gemm/asm/cdna/asm.py | 3 +- .../external_test_am_fault_recovery.py | 3 +- test/external/external_test_gpu_crash.py | 4 +- tinygrad/viz/serve.py | 6 +- 14 files changed, 103 insertions(+), 71 deletions(-) rename extra/assembly/amd/{ => test}/disasm.py (100%) create mode 100644 extra/assembly/amd/test/test_sqttmap.py diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 1eed49aae5..87e3be4b30 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -403,9 +403,7 @@ class Inst: @classmethod def _size(cls) -> int: return cls._base_size def size(self) -> int: return self._base_size - def disasm(self) -> str: - from extra.assembly.amd.disasm import disasm - return disasm(self) + def disasm(self) -> str: raise NotImplementedError("disasm is no longer supported") def to_bytes(self) -> bytes: return self._raw.to_bytes(self._base_size, 'little') diff --git a/extra/assembly/amd/sqttmap.py b/extra/assembly/amd/sqttmap.py index addb481eec..efbf7d9ee1 100644 --- a/extra/assembly/amd/sqttmap.py +++ b/extra/assembly/amd/sqttmap.py @@ -66,57 +66,3 @@ def map_insts(data:bytes, lib:bytes, target:int) -> Iterator[tuple[PacketType, I # for all other packets (VMEMEXEC, ALUEXEC, etc.), yield with None yield (p, None) -# test to compare every packet with the rocprof decoder - -def test_rocprof_inst_traces_match(sqtt, prg, target): - from tinygrad.viz.serve import amd_decode - from extra.sqtt.roc import decode as roc_decode, InstExec - addr_table = amd_decode(prg.lib, target) - disasm = {addr+prg.base:(inst.disasm(), inst.size()) for addr,inst in addr_table.items()} - rctx = roc_decode([sqtt], {prg.tag:disasm}) - rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), []) - rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit - for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts()) - - passed_insts = 0 - for pkt, info in map_insts(sqtt.blob, prg.lib, target): - if DEBUG >= 2: print_packets([pkt]) - if info is None: continue - if DEBUG >= 2: print(f"{' '*29}{info.inst.disasm()}") - rocprof_inst = next(rwaves_iter[info.wave][0]) - ref_pc = rocprof_inst.pc-prg.base - # always check pc matches - assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm[rocprof_inst.pc][0]} != {info.pc}:{info.inst.disasm()}" - # special handling for s_endpgm, it marks the wave completion. - if info.inst == s_endpgm(): - completed_wave = list(rwaves_iter[info.wave].pop(0)) - assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}" - # otherwise the packet timestamp is time + "stall" - else: - assert pkt._time == rocprof_inst.time+rocprof_inst.stall - passed_insts += 1 - - for k,v in rwaves_iter.items(): - assert len(v) == 0, f"incomplete wave {k}" - - if len(rwaves): - print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units") - -if __name__ == "__main__": - import argparse, pickle, pathlib - from tinygrad.helpers import temp, DEBUG - parser = argparse.ArgumentParser() - parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)', - default=pathlib.Path(temp("profile.pkl", append_user=True))) - parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)') - args = parser.parse_args() - with open(args.profile, "rb") as f: - data = pickle.load(f) - sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"] - kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"} - target = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"))).props["gfx_target_version"] - for e in sqtt_events: - if args.kernel is not None and args.kernel != e.kern: continue - if not e.itrace: continue - print(f"==== {e.kern}") - test_rocprof_inst_traces_match(e, kern_events[e.kern], target) diff --git a/extra/assembly/amd/disasm.py b/extra/assembly/amd/test/disasm.py similarity index 100% rename from extra/assembly/amd/disasm.py rename to extra/assembly/amd/test/disasm.py diff --git a/extra/assembly/amd/test/test_custom_kernel.py b/extra/assembly/amd/test/test_custom_kernel.py index 2884246f2c..ea65972524 100644 --- a/extra/assembly/amd/test/test_custom_kernel.py +++ b/extra/assembly/amd/test/test_custom_kernel.py @@ -7,10 +7,11 @@ from tinygrad.runtime.support.compiler_amd import HIPCompiler from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import s, v, Inst +from extra.assembly.amd.test.disasm import disasm as disasm_inst def assemble_insts(insts:list[Inst], name:str, arch:str, kernarg_size:int=8) -> tuple[UOp, UOp]: kd = {"kernarg_size":kernarg_size, "user_sgpr_kernarg_segment_ptr":1, "next_free_vgpr":8, "next_free_sgpr":8, "wavefront_size32":1} - disasm = "\n".join([inst.disasm() for inst in insts]) + disasm = "\n".join([disasm_inst(inst) for inst in insts]) hsasrc = f".text\n.globl {name}\n.p2align 8\n.type fn_name,@function\n{name}:\n{disasm}\ns_code_end\n" hsasrc += f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n"+"\n".join([f".amdhsa_{k} {v}" for k,v in kd.items()])+"\n.end_amdhsa_kernel" binary = HIPCompiler(arch).compile(hsasrc) diff --git a/extra/assembly/amd/test/test_handwritten.py b/extra/assembly/amd/test/test_handwritten.py index dbb40dcb04..6317c2c680 100644 --- a/extra/assembly/amd/test/test_handwritten.py +++ b/extra/assembly/amd/test/test_handwritten.py @@ -5,6 +5,7 @@ import unittest, struct from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import Inst from extra.assembly.amd.test.test_roundtrip import compile_asm +from extra.assembly.amd.test.disasm import disasm class IntegrationTestBase(unittest.TestCase): inst: Inst @@ -12,7 +13,7 @@ class IntegrationTestBase(unittest.TestCase): def tearDown(self): if not hasattr(self, 'inst'): return b = self.inst.to_bytes() - st = self.inst.disasm() + st = disasm(self.inst) # Test that the instruction can be compiled by LLVM and produces the same bytes desc = f"{st:25s} {self.inst} {b!r}" self.assertEqual(b, compile_asm(st, arch=self.arch), desc) @@ -160,9 +161,9 @@ class TestRegisterSliceSyntax(unittest.TestCase): # Round-trip: DSL -> disasm -> DSL should preserve register count reg = s[4:7] # 4 registers in AMD convention inst = s_load_b128(reg, s[0:1], NULL, 0) - disasm = inst.disasm() + d = disasm(inst) # Disasm shows s[4:7] - user should be able to copy this back - self.assertIn("s[4:7]", disasm) + self.assertIn("s[4:7]", d) # And s[4:7] in DSL should give the same 4 registers reg_from_disasm = s[4:7] self.assertEqual(reg_from_disasm.sz, 4, "s[4:7] from disasm should give 4 registers") diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index 332c755173..60da5a9793 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -10,7 +10,7 @@ Only compute-relevant instruction formats are tested. Graphics-only formats not """ import unittest, re, subprocess, functools from tinygrad.helpers import fetch -from extra.assembly.amd.disasm import disasm +from extra.assembly.amd.test.disasm import disasm from extra.assembly.amd import decode_inst, detect_format from extra.assembly.amd.test.helpers import get_llvm_mc, get_target, get_mattr diff --git a/extra/assembly/amd/test/test_rdna3_asm.py b/extra/assembly/amd/test/test_rdna3_asm.py index 74c1f145c5..c157773311 100644 --- a/extra/assembly/amd/test/test_rdna3_asm.py +++ b/extra/assembly/amd/test/test_rdna3_asm.py @@ -2,6 +2,7 @@ import unittest, subprocess from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.test.helpers import get_llvm_mc +from extra.assembly.amd.test.disasm import disasm def llvm_assemble(asm: str) -> bytes: """Assemble using llvm-mc and return bytes.""" @@ -67,7 +68,7 @@ global_store_b32 v[0:1], v2, off s_endpgm """ expected = llvm_assemble(asm) - for inst,rt in zip(program, asm.strip().split("\n")): print(f"{inst.disasm():50s} {rt}") + for inst,rt in zip(program, asm.strip().split("\n")): print(f"{disasm(inst):50s} {rt}") actual = b''.join(inst.to_bytes() for inst in program) self.assertEqual(actual, expected) diff --git a/extra/assembly/amd/test/test_roundtrip.py b/extra/assembly/amd/test/test_roundtrip.py index 50cf95faae..9aa8a2f4a1 100644 --- a/extra/assembly/amd/test/test_roundtrip.py +++ b/extra/assembly/amd/test/test_roundtrip.py @@ -4,6 +4,7 @@ import unittest, io, sys, re, subprocess, os from extra.assembly.amd.dsl import Inst from extra.assembly.amd import decode_inst, detect_format from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr +from extra.assembly.amd.test.disasm import disasm def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: """Disassemble ELF binary and return list of (instruction_text, machine_code_bytes).""" @@ -113,7 +114,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): size = decoded.size() # actual size including literal orig_bytes = remaining[:size] reencoded = decoded.to_bytes() - our_disasm = decoded.disasm() + our_disasm = disasm(decoded) decode_ok = reencoded == orig_bytes decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}" decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)) diff --git a/extra/assembly/amd/test/test_sqttmap.py b/extra/assembly/amd/test/test_sqttmap.py new file mode 100644 index 0000000000..f9861fb16b --- /dev/null +++ b/extra/assembly/amd/test/test_sqttmap.py @@ -0,0 +1,79 @@ +# test to compare every packet with the rocprof decoder +import unittest, pickle +from typing import Iterator +from pathlib import Path +from tinygrad.helpers import DEBUG +from extra.assembly.amd.sqtt import print_packets +from extra.assembly.amd.sqttmap import map_insts +from extra.assembly.amd.autogen.rdna3.ins import s_endpgm +from extra.assembly.amd.test.disasm import disasm + +EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples" + +def rocprof_inst_traces_match(sqtt, prg, target): + from tinygrad.viz.serve import amd_decode + from extra.sqtt.roc import decode as roc_decode, InstExec + addr_table = amd_decode(prg.lib, target) + disasm_map = {addr+prg.base:(disasm(inst), inst.size()) for addr,inst in addr_table.items()} + rctx = roc_decode([sqtt], {prg.tag:disasm_map}) + rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), []) + rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit + for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts()) + + if not rwaves: return 0, 0, 0 + + passed_insts = 0 + for pkt, info in map_insts(sqtt.blob, prg.lib, target): + if DEBUG >= 2: print_packets([pkt]) + if info is None: continue + if DEBUG >= 2: print(f"{' '*29}{disasm(info.inst)}") + rocprof_inst = next(rwaves_iter[info.wave][0]) + ref_pc = rocprof_inst.pc-prg.base + # always check pc matches + assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc][0]} != {info.pc}:{disasm(info.inst)}" + # special handling for s_endpgm, it marks the wave completion. + if info.inst == s_endpgm(): + completed_wave = list(rwaves_iter[info.wave].pop(0)) + assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}" + # otherwise the packet timestamp is time + "stall" + else: + assert pkt._time == rocprof_inst.time+rocprof_inst.stall + passed_insts += 1 + + for k,v in rwaves_iter.items(): + assert len(v) == 0, f"incomplete wave {k}" + + return passed_insts, len(rwaves), len(rwaves_iter) + +class TestSQTTMapBase(unittest.TestCase): + target: str + + @classmethod + def setUpClass(cls): + if cls is TestSQTTMapBase: raise unittest.SkipTest("base class") + cls.examples = {} + for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")): + with open(pkl_path, "rb") as f: + data = pickle.load(f) + sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"] + kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"} + dev = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD")), None) + if sqtt_events and kern_events and dev: + cls.examples[pkl_path.stem] = (sqtt_events, kern_events, dev.props["gfx_target_version"]) + + def test_rocprof_inst_traces_match(self): + for name, (events, kern_events, target) in self.examples.items(): + for event in events: + if not event.itrace: continue + if event.kern not in kern_events: continue + with self.subTest(example=name, kern=event.kern): + passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target) + if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units") + +class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100" + +@unittest.skip("this doesn't work") +class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200" + +if __name__ == "__main__": + unittest.main() diff --git a/extra/gemm/amd_asm_matmul.py b/extra/gemm/amd_asm_matmul.py index c2dc9f7456..d36dea8525 100644 --- a/extra/gemm/amd_asm_matmul.py +++ b/extra/gemm/amd_asm_matmul.py @@ -192,7 +192,8 @@ class Kernel: inst.simm16 = offset_dwords # TODO: replace this with direct ELF - body = ['\t' + inst.disasm() for inst in self.instructions] + from extra.assembly.amd.test.disasm import disasm + body = ['\t' + disasm(inst) for inst in self.instructions] # limit wave occupancy by using more LDS lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)) diff --git a/extra/gemm/asm/cdna/asm.py b/extra/gemm/asm/cdna/asm.py index ed05646579..db8de34e7f 100644 --- a/extra/gemm/asm/cdna/asm.py +++ b/extra/gemm/asm/cdna/asm.py @@ -73,7 +73,8 @@ class Kernel: lines, pos = [], 0 for inst in self.instructions: if (label := self.label_at_pos.get(pos)) is not None: lines.append(f"{label}:") - lines.append(f" {inst.disasm()}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}") + from extra.assembly.amd.test.disasm import disasm + lines.append(f" {disasm(inst)}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}") pos += inst.size() return "\n".join(lines) diff --git a/test/external/external_test_am_fault_recovery.py b/test/external/external_test_am_fault_recovery.py index b637582e5d..68e5a5caba 100644 --- a/test/external/external_test_am_fault_recovery.py +++ b/test/external/external_test_am_fault_recovery.py @@ -5,7 +5,8 @@ from extra.assembly.amd.dsl import s, v, Inst, NULL def assemble_kernel(insts:list[Inst], name:str="test") -> str: kd = {"next_free_vgpr": 8, "next_free_sgpr": 8, "wavefront_size32": 1, "user_sgpr_kernarg_segment_ptr": 1, "kernarg_size": 8} - disasm = "\n".join(inst.disasm() for inst in insts) + from extra.assembly.amd.test.disasm import disasm as _disasm + disasm = "\n".join(_disasm(inst) for inst in insts) hsasrc = f".text\n.globl {name}\n.p2align 8\n.type {name},@function\n{name}:\n{disasm}\n" return hsasrc + f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n" + "\n".join(f".amdhsa_{k} {v}" for k, v in kd.items()) + "\n.end_amdhsa_kernel" diff --git a/test/external/external_test_gpu_crash.py b/test/external/external_test_gpu_crash.py index f1ed7c87b3..f6e3e96b20 100644 --- a/test/external/external_test_gpu_crash.py +++ b/test/external/external_test_gpu_crash.py @@ -36,7 +36,9 @@ class TestGPUCrash(unittest.TestCase): prg = AMDProgram(self.dev, "test", self.compiler.compile(assemble(code))) prg(self.dev.allocator.alloc(64), global_size=(1,1,1), local_size=(1,1,1), wait=True) - def _run_insts(self, insts: list[Inst]): self._run("\n".join(i.disasm() for i in insts)) + def _run_insts(self, insts: list[Inst]): + from extra.assembly.amd.test.disasm import disasm + self._run("\n".join(disasm(i) for i in insts)) def _assert_gpu_fault(self, func): """Assert that func raises a RuntimeError indicating a GPU fault (not a setup error).""" diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index a30cd2a447..8d7687e5df 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -316,7 +316,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]: def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None, info:InstructionInfo|None=None) -> None: if hasattr(p, "wave"): wave = p.wave rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}")) - key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=info.inst.disasm() if info is not None else None) + key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=str(info.inst) if info is not None else None) ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width))) for p, info in map_insts(data, lib, target): if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break @@ -344,7 +344,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[ from extra.sqtt.roc import decode base = unwrap(p.base) addr_table = amd_decode(unwrap(p.lib), amdgpu_targets[p.device]) - disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()} + disasm:dict[int, tuple[str, int]] = {addr+base:(str(inst), inst.size()) for addr, inst in addr_table.items()} rctx = decode(data, {p.tag:disasm}) cu_events:dict[str, list[ProfileEvent]] = {} # * INST waves @@ -470,7 +470,7 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict: blocks:dict[int, list[int]] = {} paths:dict[int, dict[int, int]] = {} lines:list[str] = [] - disasm = {pc:inst.disasm() for pc,inst in pc_table.items()} + disasm = {pc:str(inst) for pc,inst in pc_table.items()} asm_width = max(len(asm) for asm in disasm.values()) for pc, inst in pc_table.items(): # skip instructions only used for padding