mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assembly/amd: disasm is test only (#14694)
* assembly/amd: disasm is test only * viz uses str
This commit is contained in:
@@ -403,9 +403,7 @@ class Inst:
|
||||
@classmethod
|
||||
def _size(cls) -> int: return cls._base_size
|
||||
def size(self) -> int: return self._base_size
|
||||
def disasm(self) -> str:
|
||||
from extra.assembly.amd.disasm import disasm
|
||||
return disasm(self)
|
||||
def disasm(self) -> str: raise NotImplementedError("disasm is no longer supported")
|
||||
|
||||
def to_bytes(self) -> bytes: return self._raw.to_bytes(self._base_size, 'little')
|
||||
|
||||
|
||||
@@ -66,57 +66,3 @@ def map_insts(data:bytes, lib:bytes, target:int) -> Iterator[tuple[PacketType, I
|
||||
# for all other packets (VMEMEXEC, ALUEXEC, etc.), yield with None
|
||||
yield (p, None)
|
||||
|
||||
# test to compare every packet with the rocprof decoder
|
||||
|
||||
def test_rocprof_inst_traces_match(sqtt, prg, target):
|
||||
from tinygrad.viz.serve import amd_decode
|
||||
from extra.sqtt.roc import decode as roc_decode, InstExec
|
||||
addr_table = amd_decode(prg.lib, target)
|
||||
disasm = {addr+prg.base:(inst.disasm(), inst.size()) for addr,inst in addr_table.items()}
|
||||
rctx = roc_decode([sqtt], {prg.tag:disasm})
|
||||
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
|
||||
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
|
||||
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
|
||||
|
||||
passed_insts = 0
|
||||
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
|
||||
if DEBUG >= 2: print_packets([pkt])
|
||||
if info is None: continue
|
||||
if DEBUG >= 2: print(f"{' '*29}{info.inst.disasm()}")
|
||||
rocprof_inst = next(rwaves_iter[info.wave][0])
|
||||
ref_pc = rocprof_inst.pc-prg.base
|
||||
# always check pc matches
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm[rocprof_inst.pc][0]} != {info.pc}:{info.inst.disasm()}"
|
||||
# special handling for s_endpgm, it marks the wave completion.
|
||||
if info.inst == s_endpgm():
|
||||
completed_wave = list(rwaves_iter[info.wave].pop(0))
|
||||
assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}"
|
||||
# otherwise the packet timestamp is time + "stall"
|
||||
else:
|
||||
assert pkt._time == rocprof_inst.time+rocprof_inst.stall
|
||||
passed_insts += 1
|
||||
|
||||
for k,v in rwaves_iter.items():
|
||||
assert len(v) == 0, f"incomplete wave {k}"
|
||||
|
||||
if len(rwaves):
|
||||
print(f"passed for {passed_insts} instructions across {len(rwaves)} waves scheduled on {len(rwaves_iter)} wave units")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse, pickle, pathlib
|
||||
from tinygrad.helpers import temp, DEBUG
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--profile', type=pathlib.Path, metavar="PATH", help='Path to profile (optional file, default: latest profile)',
|
||||
default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
parser.add_argument('--kernel', type=str, default=None, metavar="NAME", help='Kernel to focus on (optional name, default: all kernels)')
|
||||
args = parser.parse_args()
|
||||
with open(args.profile, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
|
||||
target = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"))).props["gfx_target_version"]
|
||||
for e in sqtt_events:
|
||||
if args.kernel is not None and args.kernel != e.kern: continue
|
||||
if not e.itrace: continue
|
||||
print(f"==== {e.kern}")
|
||||
test_rocprof_inst_traces_match(e, kern_events[e.kern], target)
|
||||
|
||||
@@ -7,10 +7,11 @@ from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.dsl import s, v, Inst
|
||||
from extra.assembly.amd.test.disasm import disasm as disasm_inst
|
||||
|
||||
def assemble_insts(insts:list[Inst], name:str, arch:str, kernarg_size:int=8) -> tuple[UOp, UOp]:
|
||||
kd = {"kernarg_size":kernarg_size, "user_sgpr_kernarg_segment_ptr":1, "next_free_vgpr":8, "next_free_sgpr":8, "wavefront_size32":1}
|
||||
disasm = "\n".join([inst.disasm() for inst in insts])
|
||||
disasm = "\n".join([disasm_inst(inst) for inst in insts])
|
||||
hsasrc = f".text\n.globl {name}\n.p2align 8\n.type fn_name,@function\n{name}:\n{disasm}\ns_code_end\n"
|
||||
hsasrc += f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n"+"\n".join([f".amdhsa_{k} {v}" for k,v in kd.items()])+"\n.end_amdhsa_kernel"
|
||||
binary = HIPCompiler(arch).compile(hsasrc)
|
||||
|
||||
@@ -5,6 +5,7 @@ import unittest, struct
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd.test.test_roundtrip import compile_asm
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
|
||||
class IntegrationTestBase(unittest.TestCase):
|
||||
inst: Inst
|
||||
@@ -12,7 +13,7 @@ class IntegrationTestBase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
if not hasattr(self, 'inst'): return
|
||||
b = self.inst.to_bytes()
|
||||
st = self.inst.disasm()
|
||||
st = disasm(self.inst)
|
||||
# Test that the instruction can be compiled by LLVM and produces the same bytes
|
||||
desc = f"{st:25s} {self.inst} {b!r}"
|
||||
self.assertEqual(b, compile_asm(st, arch=self.arch), desc)
|
||||
@@ -160,9 +161,9 @@ class TestRegisterSliceSyntax(unittest.TestCase):
|
||||
# Round-trip: DSL -> disasm -> DSL should preserve register count
|
||||
reg = s[4:7] # 4 registers in AMD convention
|
||||
inst = s_load_b128(reg, s[0:1], NULL, 0)
|
||||
disasm = inst.disasm()
|
||||
d = disasm(inst)
|
||||
# Disasm shows s[4:7] - user should be able to copy this back
|
||||
self.assertIn("s[4:7]", disasm)
|
||||
self.assertIn("s[4:7]", d)
|
||||
# And s[4:7] in DSL should give the same 4 registers
|
||||
reg_from_disasm = s[4:7]
|
||||
self.assertEqual(reg_from_disasm.sz, 4, "s[4:7] from disasm should give 4 registers")
|
||||
|
||||
@@ -10,7 +10,7 @@ Only compute-relevant instruction formats are tested. Graphics-only formats not
|
||||
"""
|
||||
import unittest, re, subprocess, functools
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.disasm import disasm
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
from extra.assembly.amd import decode_inst, detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_target, get_mattr
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import unittest, subprocess
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
|
||||
def llvm_assemble(asm: str) -> bytes:
|
||||
"""Assemble using llvm-mc and return bytes."""
|
||||
@@ -67,7 +68,7 @@ global_store_b32 v[0:1], v2, off
|
||||
s_endpgm
|
||||
"""
|
||||
expected = llvm_assemble(asm)
|
||||
for inst,rt in zip(program, asm.strip().split("\n")): print(f"{inst.disasm():50s} {rt}")
|
||||
for inst,rt in zip(program, asm.strip().split("\n")): print(f"{disasm(inst):50s} {rt}")
|
||||
actual = b''.join(inst.to_bytes() for inst in program)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd import decode_inst, detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
|
||||
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
|
||||
@@ -113,7 +114,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
size = decoded.size() # actual size including literal
|
||||
orig_bytes = remaining[:size]
|
||||
reencoded = decoded.to_bytes()
|
||||
our_disasm = decoded.disasm()
|
||||
our_disasm = disasm(decoded)
|
||||
decode_ok = reencoded == orig_bytes
|
||||
decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
|
||||
|
||||
79
extra/assembly/amd/test/test_sqttmap.py
Normal file
79
extra/assembly/amd/test/test_sqttmap.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# test to compare every packet with the rocprof decoder
|
||||
import unittest, pickle
|
||||
from typing import Iterator
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG
|
||||
from extra.assembly.amd.sqtt import print_packets
|
||||
from extra.assembly.amd.sqttmap import map_insts
|
||||
from extra.assembly.amd.autogen.rdna3.ins import s_endpgm
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
|
||||
|
||||
def rocprof_inst_traces_match(sqtt, prg, target):
|
||||
from tinygrad.viz.serve import amd_decode
|
||||
from extra.sqtt.roc import decode as roc_decode, InstExec
|
||||
addr_table = amd_decode(prg.lib, target)
|
||||
disasm_map = {addr+prg.base:(disasm(inst), inst.size()) for addr,inst in addr_table.items()}
|
||||
rctx = roc_decode([sqtt], {prg.tag:disasm_map})
|
||||
rwaves = rctx.inst_execs.get((sqtt.kern, sqtt.exec_tag), [])
|
||||
rwaves_iter:dict[int, list[Iterator[InstExec]]] = {} # wave unit (0-15) -> list of inst trace iterators for all executions on that unit
|
||||
for w in rwaves: rwaves_iter.setdefault(w.wave_id, []).append(w.unpack_insts())
|
||||
|
||||
if not rwaves: return 0, 0, 0
|
||||
|
||||
passed_insts = 0
|
||||
for pkt, info in map_insts(sqtt.blob, prg.lib, target):
|
||||
if DEBUG >= 2: print_packets([pkt])
|
||||
if info is None: continue
|
||||
if DEBUG >= 2: print(f"{' '*29}{disasm(info.inst)}")
|
||||
rocprof_inst = next(rwaves_iter[info.wave][0])
|
||||
ref_pc = rocprof_inst.pc-prg.base
|
||||
# always check pc matches
|
||||
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc][0]} != {info.pc}:{disasm(info.inst)}"
|
||||
# special handling for s_endpgm, it marks the wave completion.
|
||||
if info.inst == s_endpgm():
|
||||
completed_wave = list(rwaves_iter[info.wave].pop(0))
|
||||
assert len(completed_wave) == 0, f"incomplete instructions in wave {info.wave}"
|
||||
# otherwise the packet timestamp is time + "stall"
|
||||
else:
|
||||
assert pkt._time == rocprof_inst.time+rocprof_inst.stall
|
||||
passed_insts += 1
|
||||
|
||||
for k,v in rwaves_iter.items():
|
||||
assert len(v) == 0, f"incomplete wave {k}"
|
||||
|
||||
return passed_insts, len(rwaves), len(rwaves_iter)
|
||||
|
||||
class TestSQTTMapBase(unittest.TestCase):
|
||||
target: str
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if cls is TestSQTTMapBase: raise unittest.SkipTest("base class")
|
||||
cls.examples = {}
|
||||
for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")):
|
||||
with open(pkl_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
kern_events = {e.tag:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
|
||||
dev = next((e for e in data if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD")), None)
|
||||
if sqtt_events and kern_events and dev:
|
||||
cls.examples[pkl_path.stem] = (sqtt_events, kern_events, dev.props["gfx_target_version"])
|
||||
|
||||
def test_rocprof_inst_traces_match(self):
|
||||
for name, (events, kern_events, target) in self.examples.items():
|
||||
for event in events:
|
||||
if not event.itrace: continue
|
||||
if event.kern not in kern_events: continue
|
||||
with self.subTest(example=name, kern=event.kern):
|
||||
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
|
||||
if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")
|
||||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
||||
@unittest.skip("this doesn't work")
|
||||
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -192,7 +192,8 @@ class Kernel:
|
||||
inst.simm16 = offset_dwords
|
||||
|
||||
# TODO: replace this with direct ELF
|
||||
body = ['\t' + inst.disasm() for inst in self.instructions]
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
body = ['\t' + disasm(inst) for inst in self.instructions]
|
||||
|
||||
# limit wave occupancy by using more LDS
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
|
||||
|
||||
@@ -73,7 +73,8 @@ class Kernel:
|
||||
lines, pos = [], 0
|
||||
for inst in self.instructions:
|
||||
if (label := self.label_at_pos.get(pos)) is not None: lines.append(f"{label}:")
|
||||
lines.append(f" {inst.disasm()}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}")
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
lines.append(f" {disasm(inst)}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}")
|
||||
pos += inst.size()
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from extra.assembly.amd.dsl import s, v, Inst, NULL
|
||||
|
||||
def assemble_kernel(insts:list[Inst], name:str="test") -> str:
|
||||
kd = {"next_free_vgpr": 8, "next_free_sgpr": 8, "wavefront_size32": 1, "user_sgpr_kernarg_segment_ptr": 1, "kernarg_size": 8}
|
||||
disasm = "\n".join(inst.disasm() for inst in insts)
|
||||
from extra.assembly.amd.test.disasm import disasm as _disasm
|
||||
disasm = "\n".join(_disasm(inst) for inst in insts)
|
||||
hsasrc = f".text\n.globl {name}\n.p2align 8\n.type {name},@function\n{name}:\n{disasm}\n"
|
||||
return hsasrc + f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n" + "\n".join(f".amdhsa_{k} {v}" for k, v in kd.items()) + "\n.end_amdhsa_kernel"
|
||||
|
||||
|
||||
4
test/external/external_test_gpu_crash.py
vendored
4
test/external/external_test_gpu_crash.py
vendored
@@ -36,7 +36,9 @@ class TestGPUCrash(unittest.TestCase):
|
||||
prg = AMDProgram(self.dev, "test", self.compiler.compile(assemble(code)))
|
||||
prg(self.dev.allocator.alloc(64), global_size=(1,1,1), local_size=(1,1,1), wait=True)
|
||||
|
||||
def _run_insts(self, insts: list[Inst]): self._run("\n".join(i.disasm() for i in insts))
|
||||
def _run_insts(self, insts: list[Inst]):
|
||||
from extra.assembly.amd.test.disasm import disasm
|
||||
self._run("\n".join(disasm(i) for i in insts))
|
||||
|
||||
def _assert_gpu_fault(self, func):
|
||||
"""Assert that func raises a RuntimeError indicating a GPU fault (not a setup error)."""
|
||||
|
||||
@@ -316,7 +316,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:int) -> list[ProfileEvent]:
|
||||
def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None, info:InstructionInfo|None=None) -> None:
|
||||
if hasattr(p, "wave"): wave = p.wave
|
||||
rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}"))
|
||||
key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=info.inst.disasm() if info is not None else None)
|
||||
key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=str(info.inst) if info is not None else None)
|
||||
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
|
||||
for p, info in map_insts(data, lib, target):
|
||||
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
|
||||
@@ -344,7 +344,7 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
|
||||
from extra.sqtt.roc import decode
|
||||
base = unwrap(p.base)
|
||||
addr_table = amd_decode(unwrap(p.lib), amdgpu_targets[p.device])
|
||||
disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()}
|
||||
disasm:dict[int, tuple[str, int]] = {addr+base:(str(inst), inst.size()) for addr, inst in addr_table.items()}
|
||||
rctx = decode(data, {p.tag:disasm})
|
||||
cu_events:dict[str, list[ProfileEvent]] = {}
|
||||
# * INST waves
|
||||
@@ -470,7 +470,7 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict:
|
||||
blocks:dict[int, list[int]] = {}
|
||||
paths:dict[int, dict[int, int]] = {}
|
||||
lines:list[str] = []
|
||||
disasm = {pc:inst.disasm() for pc,inst in pc_table.items()}
|
||||
disasm = {pc:str(inst) for pc,inst in pc_table.items()}
|
||||
asm_width = max(len(asm) for asm in disasm.values())
|
||||
for pc, inst in pc_table.items():
|
||||
# skip instructions only used for padding
|
||||
|
||||
Reference in New Issue
Block a user