assembly/amd: disasm is test only (#14694)

* assembly/amd: disasm is test only

* viz uses str
This commit is contained in:
George Hotz
2026-02-12 12:33:46 +08:00
committed by GitHub
parent c331798201
commit befc1e800c
14 changed files with 103 additions and 71 deletions

View File

@@ -403,9 +403,7 @@ class Inst:
@classmethod
def _size(cls) -> int: return cls._base_size
def size(self) -> int: return self._base_size
def disasm(self) -> str:
from extra.assembly.amd.disasm import disasm
return disasm(self)
# Stub kept on Inst after this change: it always raises NotImplementedError.
# Call sites touched by this change use disasm(inst) imported from extra.assembly.amd.test.disasm, or str(inst) in viz.
def disasm(self) -> str: raise NotImplementedError("disasm is no longer supported")
def to_bytes(self) -> bytes: return self._raw.to_bytes(self._base_size, 'little')

View File

@@ -66,57 +66,3 @@ def map_insts(data:bytes, lib:bytes, target:int) -> Iterator[tuple[PacketType, I
# for all other packets (VMEMEXEC, ALUEXEC, etc.), yield with None
yield (p, None)
# test to compare every packet with the rocprof decoder
# Cross-checks, for one SQTT event, every instruction yielded by map_insts against the rocprof decoder's trace.
#   sqtt:   event object; .blob, .kern and .exec_tag are read here
#   prg:    program object; .lib, .base and .tag are read here
#   target: forwarded unchanged to amd_decode and map_insts
# Raises AssertionError on a pc mismatch, a timestamp mismatch, or if any wave trace is left unconsumed.
# Returns nothing; prints a one-line summary when rocprof reported at least one wave.
def test_rocprof_inst_traces_match(sqtt, prg, target):
from tinygrad.viz.serve import amd_decode
from extra.sqtt.roc import decode as roc_decode, InstExec
addr_table = amd_decode(prg.lib, target)
# absolute address (offset + prg.base) -> (disassembly text, instruction size), the table handed to roc_decode
disasm = {addr+prg.base:(inst.disasm(), inst.size()) for addr,inst in addr_table.items()}
rctx = roc_decode([sqtt], {prg.tag:disasm})
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
passed_insts = 0
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
if DEBUG >= 2: print_packets([pkt])
# packets without instruction info have nothing to compare
if info is None: continue
if DEBUG >= 2: print(f"{' '*29}{info.inst.disasm()}")
# take the next reference instruction from the first pending trace on this wave unit
# NOTE(review): when rwaves is empty, rwaves_iter has no entry for info.wave and this lookup raises KeyError - confirm whether that is intended
rocprof_inst = next(rwaves_iter[info.wave][0])
# rocprof pcs are absolute; subtract the program base to compare with info.pc
ref_pc = rocprof_inst.pc-prg.base
# always check pc matches
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm[rocprof_inst.pc][0]} != {info.pc}:{info.inst.disasm()}"
# special handling for s_endpgm, it marks the wave completion.
if info.inst == s_endpgm():
# drop the finished trace; anything still left in it was never matched
completed_wave = list(rwaves_iter[info.wave].pop(0))
assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}"
# otherwise the packet timestamp is time + "stall"
else:
assert pkt._time == rocprof_inst.time+rocprof_inst.stall
passed_insts += 1
# every trace list must have been emptied by the s_endpgm handling above
for k,v in rwaves_iter.items():
assert len(v) == 0, f"incomplete wave {k}"
if len(rwaves):
print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units")
# Command-line entry point: load a pickled profile and run test_rocprof_inst_traces_match on each SQTT event with an instruction trace.
if __name__ == "__main__":
import argparse, pickle, pathlib
from tinygrad.helpers import temp, DEBUG
parser = argparse.ArgumentParser()
parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
default=pathlib.Path(temp("profile.pkl", append_user=True)))
parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)')
args = parser.parse_args()
# NOTE(review): pickle.load can execute arbitrary code from the file - only point --profile at trusted profiles
with open(args.profile, "rb") as f:
data = pickle.load(f)
# events are selected by class name rather than isinstance, so the event classes are not imported here
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
# next() has no default here: a profile without an AMD ProfileDeviceEvent raises StopIteration
target = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"))).props["gfx_target_version"]
for e in sqtt_events:
# honour the optional --kernel filter
if args.kernel is not None and args.kernel != e.kern: continue
# events without an instruction trace are skipped
if not e.itrace: continue
print(f"==== {e.kern}")
# NOTE(review): kern_events is keyed by .tag but indexed with e.kern - presumably the two match; confirm
test_rocprof_inst_traces_match(e, kern_events[e.kern], target)

View File

@@ -7,10 +7,11 @@ from tinygrad.runtime.support.compiler_amd import HIPCompiler
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import s, v, Inst
from extra.assembly.amd.test.disasm import disasm as disasm_inst
def assemble_insts(insts:list[Inst], name:str, arch:str, kernarg_size:int=8) -> tuple[UOp, UOp]:
kd = {"kernarg_size":kernarg_size, "user_sgpr_kernarg_segment_ptr":1, "next_free_vgpr":8, "next_free_sgpr":8, "wavefront_size32":1}
disasm = "\n".join([inst.disasm() for inst in insts])
disasm = "\n".join([disasm_inst(inst) for inst in insts])
hsasrc = f".text\n.globl {name}\n.p2align 8\n.type fn_name,@function\n{name}:\n{disasm}\ns_code_end\n"
hsasrc += f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n"+"\n".join([f".amdhsa_{k} {v}" for k,v in kd.items()])+"\n.end_amdhsa_kernel"
binary = HIPCompiler(arch).compile(hsasrc)

View File

@@ -5,6 +5,7 @@ import unittest, struct
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.test.test_roundtrip import compile_asm
from extra.assembly.amd.test.disasm import disasm
class IntegrationTestBase(unittest.TestCase):
inst: Inst
@@ -12,7 +13,7 @@ class IntegrationTestBase(unittest.TestCase):
def tearDown(self):
if not hasattr(self, 'inst'): return
b = self.inst.to_bytes()
st = self.inst.disasm()
st = disasm(self.inst)
# Test that the instruction can be compiled by LLVM and produces the same bytes
desc = f"{st:25s} {self.inst} {b!r}"
self.assertEqual(b, compile_asm(st, arch=self.arch), desc)
@@ -160,9 +161,9 @@ class TestRegisterSliceSyntax(unittest.TestCase):
# Round-trip: DSL -> disasm -> DSL should preserve register count
reg = s[4:7] # 4 registers in AMD convention
inst = s_load_b128(reg, s[0:1], NULL, 0)
disasm = inst.disasm()
d = disasm(inst)
# Disasm shows s[4:7] - user should be able to copy this back
self.assertIn("s[4:7]", disasm)
self.assertIn("s[4:7]", d)
# And s[4:7] in DSL should give the same 4 registers
reg_from_disasm = s[4:7]
self.assertEqual(reg_from_disasm.sz, 4, "s[4:7] from disasm should give 4 registers")

View File

@@ -10,7 +10,7 @@ Only compute-relevant instruction formats are tested. Graphics-only formats not
"""
import unittest, re, subprocess, functools
from tinygrad.helpers import fetch
from extra.assembly.amd.disasm import disasm
from extra.assembly.amd.test.disasm import disasm
from extra.assembly.amd import decode_inst, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_target, get_mattr

View File

@@ -2,6 +2,7 @@
import unittest, subprocess
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.test.helpers import get_llvm_mc
from extra.assembly.amd.test.disasm import disasm
def llvm_assemble(asm: str) -> bytes:
"""Assemble using llvm-mc and return bytes."""
@@ -67,7 +68,7 @@ global_store_b32 v[0:1], v2, off
s_endpgm
"""
expected = llvm_assemble(asm)
for inst,rt in zip(program, asm.strip().split("\n")): print(f"{inst.disasm():50s} {rt}")
for inst,rt in zip(program, asm.strip().split("\n")): print(f"{disasm(inst):50s} {rt}")
actual = b''.join(inst.to_bytes() for inst in program)
self.assertEqual(actual, expected)

View File

@@ -4,6 +4,7 @@ import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd import decode_inst, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr
from extra.assembly.amd.test.disasm import disasm
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
@@ -113,7 +114,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
size = decoded.size() # actual size including literal
orig_bytes = remaining[:size]
reencoded = decoded.to_bytes()
our_disasm = decoded.disasm()
our_disasm = disasm(decoded)
decode_ok = reencoded == orig_bytes
decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))

View File

@@ -0,0 +1,79 @@
# test to compare every packet with the rocprof decoder
import unittest, pickle
from typing import Iterator
from pathlib import Path
from tinygrad.helpers import DEBUG
from extra.assembly.amd.sqtt import print_packets
from extra.assembly.amd.sqttmap import map_insts
from extra.assembly.amd.autogen.rdna3.ins import s_endpgm
from extra.assembly.amd.test.disasm import disasm
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
# Cross-checks, for one SQTT event, every instruction yielded by map_insts against the rocprof decoder's trace.
#   sqtt:   event object; .blob, .kern and .exec_tag are read here
#   prg:    program object; .lib, .base and .tag are read here
#   target: forwarded unchanged to amd_decode and map_insts
# Returns (matched instruction count, number of rocprof waves, number of wave units); (0, 0, 0) when rocprof reports no waves.
# Raises AssertionError on a pc mismatch, a timestamp mismatch, or if any wave trace is left unconsumed.
def rocprof_inst_traces_match(sqtt, prg, target):
from tinygrad.viz.serve import amd_decode
from extra.sqtt.roc import decode as roc_decode, InstExec
addr_table = amd_decode(prg.lib, target)
# absolute address (offset + prg.base) -> (disassembly text, instruction size), the table handed to roc_decode
disasm_map = {addr+prg.base:(disasm(inst), inst.size()) for addr,inst in addr_table.items()}
rctx = roc_decode([sqtt], {prg.tag:disasm_map})
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
# nothing to compare against: return before the per-wave lookups below
if not rwaves: return 0, 0, 0
passed_insts = 0
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
if DEBUG >= 2: print_packets([pkt])
# packets without instruction info have nothing to compare
if info is None: continue
if DEBUG >= 2: print(f"{' '*29}{disasm(info.inst)}")
# take the next reference instruction from the first pending trace on this wave unit
rocprof_inst = next(rwaves_iter[info.wave][0])
# rocprof pcs are absolute; subtract the program base to compare with info.pc
ref_pc = rocprof_inst.pc-prg.base
# always check pc matches
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc][0]} != {info.pc}:{disasm(info.inst)}"
# special handling for s_endpgm, it marks the wave completion.
if info.inst == s_endpgm():
# drop the finished trace; anything still left in it was never matched
completed_wave = list(rwaves_iter[info.wave].pop(0))
assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}"
# otherwise the packet timestamp is time + "stall"
else:
assert pkt._time == rocprof_inst.time+rocprof_inst.stall
passed_insts += 1
# every trace list must have been emptied by the s_endpgm handling above
for k,v in rwaves_iter.items():
assert len(v) == 0, f"incomplete wave {k}"
return passed_insts, len(rwaves), len(rwaves_iter)
# Shared test case: loads every pickled example under EXAMPLES_DIR/<target> once per class and
# runs rocprof_inst_traces_match on each SQTT event that carries an instruction trace.
# Subclasses only set `target`.
class TestSQTTMapBase(unittest.TestCase):
# sub-directory name under EXAMPLES_DIR; declared here, assigned by subclasses
target: str
@classmethod
def setUpClass(cls):
# the base class assigns no target value, so it is skipped and only subclasses run
if cls is TestSQTTMapBase: raise unittest.SkipTest("base class")
# example file stem -> (sqtt events, program events keyed by tag, gfx target version)
cls.examples = {}
for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")):
# NOTE(review): pickle.load can execute arbitrary code - the example files are assumed to be trusted
with open(pkl_path, "rb") as f:
data = pickle.load(f)
# events are selected by class name rather than isinstance, so the event classes are not imported here
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
dev = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD")), None)
# examples missing any of the three event kinds are left out
if sqtt_events and kern_events and dev:
cls.examples[pkl_path.stem] = (sqtt_events, kern_events, dev.props["gfx_target_version"])
def test_rocprof_inst_traces_match(self):
for name, (events, kern_events, target) in self.examples.items():
for event in events:
# events without an instruction trace are skipped
if not event.itrace: continue
# NOTE(review): kern_events is keyed by .tag but looked up with event.kern - presumably the two match; confirm
if event.kern not in kern_events: continue
# one subTest per (example, kernel)
with self.subTest(example=name, kern=event.kern):
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")
# Runs the shared SQTT map checks on the examples stored under the "gfx1100" directory.
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
# Same checks for the "gfx1200" examples; unconditionally skipped by the decorator below.
@unittest.skip("this doesn't work")
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
# Run the unittest test cases in this module when it is executed as a script.
if __name__ == "__main__":
unittest.main()

View File

@@ -192,7 +192,8 @@ class Kernel:
inst.simm16 = offset_dwords
# TODO: replace this with direct ELF
body = ['\t' + inst.disasm() for inst in self.instructions]
from extra.assembly.amd.test.disasm import disasm
body = ['\t' + disasm(inst) for inst in self.instructions]
# limit wave occupancy by using more LDS
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))

View File

@@ -73,7 +73,8 @@ class Kernel:
lines, pos = [], 0
for inst in self.instructions:
if (label := self.label_at_pos.get(pos)) is not None: lines.append(f"{label}:")
lines.append(f" {inst.disasm()}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}")
from extra.assembly.amd.test.disasm import disasm
lines.append(f" {disasm(inst)}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}")
pos += inst.size()
return "\n".join(lines)

View File

@@ -5,7 +5,8 @@ from extra.assembly.amd.dsl import s, v, Inst, NULL
def assemble_kernel(insts:list[Inst], name:str="test") -> str:
kd = {"next_free_vgpr": 8, "next_free_sgpr": 8, "wavefront_size32": 1, "user_sgpr_kernarg_segment_ptr": 1, "kernarg_size": 8}
disasm = "\n".join(inst.disasm() for inst in insts)
from extra.assembly.amd.test.disasm import disasm as _disasm
disasm = "\n".join(_disasm(inst) for inst in insts)
hsasrc = f".text\n.globl {name}\n.p2align 8\n.type {name},@function\n{name}:\n{disasm}\n"
return hsasrc + f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n" + "\n".join(f".amdhsa_{k} {v}" for k, v in kd.items()) + "\n.end_amdhsa_kernel"

View File

@@ -36,7 +36,9 @@ class TestGPUCrash(unittest.TestCase):
prg = AMDProgram(self.dev, "test", self.compiler.compile(assemble(code)))
prg(self.dev.allocator.alloc(64), global_size=(1,1,1), local_size=(1,1,1), wait=True)
def _run_insts(self, insts: list[Inst]): self._run("\n".join(i.disasm() for i in insts))
# Disassembles each instruction, joins the text with newlines and passes it to self._run.
def _run_insts(self, insts: list[Inst]):
# imported inside the method; disasm comes from the test package
from extra.assembly.amd.test.disasm import disasm
self._run("\n".join(disasm(i) for i in insts))
def _assert_gpu_fault(self, func):
"""Assert that func raises a RuntimeError indicating a GPU fault (not a setup error)."""

View File

@@ -316,7 +316,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]:
def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None, info:InstructionInfo|None=None) -> None:
if hasattr(p, "wave"): wave = p.wave
rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}"))
key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=info.inst.disasm() if info is not None else None)
key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=str(info.inst) if info is not None else None)
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
@@ -344,7 +344,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
from extra.sqtt.roc import decode
base = unwrap(p.base)
addr_table = amd_decode(unwrap(p.lib), amdgpu_targets[p.device])
disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()}
disasm:dict[int, tuple[str, int]] = {addr+base:(str(inst), inst.size()) for addr, inst in addr_table.items()}
rctx = decode(data, {p.tag:disasm})
cu_events:dict[str, list[ProfileEvent]] = {}
# * INST waves
@@ -470,7 +470,7 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
blocks:dict[int, list[int]] = {}
paths:dict[int, dict[int, int]] = {}
lines:list[str] = []
disasm = {pc:inst.disasm() for pc,inst in pc_table.items()}
disasm = {pc:str(inst) for pc,inst in pc_table.items()}
asm_width = max(len(asm) for asm in disasm.values())
for pc, inst in pc_table.items():
# skip instructions only used for padding