From bf2d9d138ff47cd8965f84652e7bfa0a086c4dfd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 25 Jan 2026 01:21:45 -0500 Subject: [PATCH] viz: simplify amdgpu cfg (#14326) * viz: replace llvm disasm with our disasm * it starts with more code * then it becomes less * simpler, cdna disassembles with decimal simm16 * s_branch is upper case, add test * simm16s and others --- extra/assembly/amd/disasm.py | 2 +- test/testextra/test_cfg_viz.py | 2 + tinygrad/viz/serve.py | 73 +++++++++++++--------------------- 3 files changed, 30 insertions(+), 47 deletions(-) diff --git a/extra/assembly/amd/disasm.py b/extra/assembly/amd/disasm.py index 8284299b7e..df0278e8d4 100644 --- a/extra/assembly/amd/disasm.py +++ b/extra/assembly/amd/disasm.py @@ -258,7 +258,7 @@ def _disasm_sopp(inst: SOPP) -> str: dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v) p = [f"instid0({dep(id0)})" if id0 else "", f"instskip({skips[skip]})" if skip else "", f"instid1({dep(id1)})" if id1 else ""] return f"s_delay_alu {' | '.join(x for x in p if x) or '0'}" - if name.startswith(('s_cbranch', 's_branch')): return f"{name} 0x{inst.simm16:x}" + if name.startswith(('s_cbranch', 's_branch')): return f"{name} {inst.simm16}" return f"{name} 0x{inst.simm16:x}" def _disasm_smem(inst: SMEM) -> str: diff --git a/test/testextra/test_cfg_viz.py b/test/testextra/test_cfg_viz.py index e339163b7b..6d44ee8329 100644 --- a/test/testextra/test_cfg_viz.py +++ b/test/testextra/test_cfg_viz.py @@ -111,6 +111,8 @@ class TestCfg(unittest.TestCase): _, lib = assemble("diamond", insts, Device[Device.DEFAULT].compiler) cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"] self.assertEqual(len(cfg["blocks"]), 5) + edge_count = sum(len(v) for v in cfg["paths"].values()) + self.assertEqual(edge_count, 5) references:dict[str, list[str]] = {} for pc, tokens in cfg["pc_tokens"].items(): for t in tokens: diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 54664c11ee..5d3d0d4e27 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -345,7 +345,8 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[ # * init decoder from extra.sqtt.roc import decode base = unwrap(p.base) - disasm = {addr+base:inst_disasm for addr,inst_disasm in amd_disasm(device_props[p.device]["gfx_target_version"], unwrap(p.lib)).items()} + addr_table = amd_decode(device_props[p.device]["gfx_target_version"], unwrap(p.lib)) + disasm:dict[int, tuple[str, int]] = {addr+base:(inst.disasm(), inst.size()) for addr, inst in addr_table.items()} rctx = decode(data, {p.name:disasm}) cu_events:dict[str, list[ProfileEvent]] = {} # * INST waves @@ -431,74 +432,49 @@ def amd_readelf(lib:bytes) -> list[dict]: ".group_segment_fixed_size":"LDS size", ".private_segment_fixed_size":"Scratch size"} return [{"label":label, "value":v} for k,label in keys.items() if (v:=notes["amdhsa.kernels"][0][k]) > 0] -def amd_disasm(target:int, lib:bytes) -> dict[int, tuple[str, int]]: +def amd_decode(target:int, lib:bytes) -> dict[int, Any]: # Any is the Inst class from extra.assembly.amd.dsl from tinygrad.runtime.support.elf import elf_loader from extra.assembly.amd.decode import detect_format + from extra.assembly.amd.dsl import Inst image, sections, _ = elf_loader(lib) text = next((sh for sh in sections if sh.name == ".text"), None) assert text is not None, "no .text section found in ELF" off, buf = text.header.sh_addr, text.content arch = {11:"rdna3", 12:"rdna4"}.get(target//10000, "cdna") - addr_table:dict[int, tuple[str, int]] = {} + addr_table:dict[int, Inst] = {} offset = 0 while offset < len(buf): remaining = buf[offset:] fmt = detect_format(remaining, arch) decoded = fmt.from_bytes(remaining) - disasm = decoded.disasm() - # note: rocprof trace decoder assumes simm16 is a decimal integer, our disasm uses hex - # keep the decimal int for backwards compatibility, remove once there's no rocprof decoder - if "branch" in disasm: disasm = f"{decoded.op_name.lower()} {decoded.simm16}" - addr_table[off+offset] = (disasm, decoded.size()) + addr_table[off+offset] = decoded offset += decoded.size() return addr_table -SOPP_INSTS = {"s_branch", "s_cbranch_scc0", "s_cbranch_scc1", "s_cbranch_vccz", "s_cbranch_vccnz", "s_cbranch_execz", "s_cbranch_execnz"} -def parse_branch(asm:str) -> int|None: - inst, *operands = asm.split(" ") - if inst in SOPP_INSTS: - x = int(operands[0]) & 0xffff +def parse_branch(inst) -> int|None: + if "branch" in getattr(inst, "op_name", "").lower(): + x = inst.simm16 & 0xffff return (x - 0x10000 if x & 0x8000 else x)*4 return None -def _op2dsl(op: str) -> str: - """Convert LLVM asm operand (s0, s[0:1], v0) to DSL format (s[0], s[0:1], v[0]).""" - import re - op = op.strip() - lo = op.lower() - SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC', - 'scc': 'SCC', 'm0': 'M0', 'null': 'NULL', 'off': 'OFF'} - if lo in SPEC_DSL: return SPEC_DSL[lo] - rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'} - if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]" - if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return f"{rp[m.group(1)]}[{m.group(2)}]" - return op - -def amdgpu_tokenize(st:str) -> list[str]: - try: - from extra.assembly.amd.dsl import s, v, Reg, VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF - dsl = eval(_op2dsl(st), {'s':s, 'v':v, 'VCC_LO':VCC_LO, 'VCC_HI':VCC_HI, 'VCC':VCC, 'EXEC_LO':EXEC_LO, 'EXEC_HI':EXEC_HI, 'EXEC':EXEC, - 'SCC':SCC, 'M0':M0, 'NULL':NULL, 'OFF':OFF}) - return [f"{type(dsl).__name__[0].lower()}{dsl.offset + i}" for i in range(dsl.sz)] if isinstance(dsl, Reg) else [st] - except (ImportError, NameError, SyntaxError, TypeError): return [] - COND_TAKEN, COND_NOT_TAKEN, UNCOND = range(3) def amdgpu_cfg(lib:bytes, target:int) -> dict: - # disassemble - pc_table = amd_disasm(target, lib) + # decode + pc_table = amd_decode(target, lib) # get leaders leaders:set[int] = {next(iter(pc_table))} - for pc, (asm, sz) in pc_table.items(): - if (offset:=parse_branch(asm)) is not None: leaders.update((pc+sz+offset, pc+sz)) + for pc, inst in pc_table.items(): + if (offset:=parse_branch(inst)) is not None: leaders.update((pc+inst.size()+offset, pc+inst.size())) # build the cfg curr:int|None = None blocks:dict[int, list[int]] = {} paths:dict[int, dict[int, int]] = {} lines:list[str] = [] - asm_width = max(len(asm) for asm, _ in pc_table.values()) - for pc, (asm, sz) in pc_table.items(): + disasm = {pc:inst.disasm() for pc,inst in pc_table.items()} + asm_width = max(len(asm) for asm in disasm.values()) + for pc, inst in pc_table.items(): # skip instructions only used for padding - if asm == "s_code_end": continue + if (asm:=disasm[pc]) == "s_code_end": continue lines.append(f" {asm:<{asm_width}} // {pc:012X}") if pc in leaders: paths[curr:=pc] = {} @@ -506,14 +482,19 @@ def amdgpu_cfg(lib:bytes, target:int) -> dict: else: assert curr is not None, f"no basic block found for {pc}" blocks[curr].append(pc) # otherwise a basic block can have exactly one or two paths - nx = pc+sz - if (offset:=parse_branch(asm)) is not None: - if asm.startswith("s_branch"): paths[curr][nx+offset] = UNCOND + nx = pc+inst.size() + if (offset:=parse_branch(inst)) is not None: + if inst.op_name == "S_BRANCH": paths[curr][nx+offset] = UNCOND else: paths[curr].update([(nx+offset, COND_TAKEN), (nx, COND_NOT_TAKEN)]) elif nx in leaders: paths[curr][nx] = UNCOND pc_tokens:dict[int, list[dict]] = {} - for pc, (text, _) in pc_table.items(): - pc_tokens[pc] = [{"st":s, "keys":amdgpu_tokenize(s) if i>0 else [s], "kind":int(i>0)} for i,s in enumerate(text.replace(",", " , ").split(" "))] + from extra.assembly.amd.dsl import Reg + for pc, inst in pc_table.items(): + pc_tokens[pc] = tokens = [] + for name, field in inst._fields: + if isinstance(val:=getattr(inst, name), Reg): tokens.append({"st":val.fmt(), "keys":[f"r{val.offset+i}" for i in range(val.sz)], "kind":1}) + elif name in {"op","opx","opy"}: tokens.append({"st":(op_name:=val.name.lower()), "keys":[op_name], "kind":0}) + elif name != "encoding" and val != field.default: tokens.append({"st":(s:=repr(val)), "keys":[s], "kind":1}) return {"data":{"blocks":blocks, "paths":paths, "pc_tokens":pc_tokens}, "src":"\n".join(lines)} # ** Main render function to get the complete details about a trace event