assembly/amd: add CDNA support to asm (#13982)

* add CDNA support

* more cdna tests

* something

* fix more stuff

* more work

* simpler

* simplier

* cdna

* disasm

* less skip

* fixes

* simpler
This commit is contained in:
George Hotz
2026-01-04 08:53:56 -08:00
committed by GitHub
parent ad041416ca
commit 7ebda28692
8 changed files with 370 additions and 84 deletions

View File

@@ -1,4 +1,4 @@
# RDNA3 assembler and disassembler
# RDNA3/CDNA assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
@@ -8,6 +8,8 @@ from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
"""Check if word matches the encoding pattern of an instruction class."""
if cls._encoding is None: return False
@@ -81,10 +83,11 @@ def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v)
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str:
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str:
"""Format VOP3 source operand with modifiers."""
if n > 1: s = _fmt_src(v, n)
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v))
if v == 255: s = inst.lit(v) # literal constant takes priority
elif n > 1: s = _fmt_src(v, n)
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l"
else: s = inst.lit(v)
if abs_: s = f"|{s}|"
return f"-{s}" if neg else s
@@ -92,9 +95,11 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
"""Format op_sel modifier string."""
if not need: return ""
if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]"
if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]"
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]"
# For VOP1 (n=1): op_sel:[src0_hi, dst_hi], for VOP2 (n=2): op_sel:[src0_hi, src1_hi, dst_hi], for VOP3 (n=3): op_sel:[src0_hi, src1_hi, src2_hi, dst_hi]
dst_hi = (opsel >> 3) & 1
if n == 1: return f" op_sel:[{opsel & 1},{dst_hi}]"
if n == 2: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{dst_hi}]"
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{dst_hi}]"
# ═══════════════════════════════════════════════════════════════════════════════
# DISASSEMBLER
@@ -108,30 +113,41 @@ def _disasm_vop1(inst: VOP1) -> str:
parts = name.split('_')
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
return f"{name}_e32 {dst}, {src}"
def _disasm_vop2(inst: VOP2) -> str:
name = inst.op_name.lower()
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
name, cdna = inst.op_name.lower(), _is_cdna(inst)
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
lit = getattr(inst, '_literal', None)
is16 = not cdna and inst.is_16bit()
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
vcc = "vcc" if cdna else "vcc_lo"
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
def _disasm_vopc(inst: VOPC) -> str:
name = inst.op_name.lower()
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if cdna:
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
def _disasm_sopp(inst: SOPP) -> str:
name = inst.op_name.lower()
if inst.op in NO_ARG_SOPP: return name
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
if inst.op == SOPPOp.S_WAITCNT:
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
@@ -154,12 +170,13 @@ def _disasm_smem(inst: SMEM) -> str:
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
def _disasm_flat(inst: FLAT) -> str:
name = inst.op_name.lower()
name, cdna = inst.op_name.lower(), _is_cdna(inst)
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
# saddr
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
elif inst.saddr == 124: saddr_s = ", off"
@@ -172,8 +189,9 @@ def _disasm_flat(inst: FLAT) -> str:
# addr width
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
glc_or_sc0 = inst.sc0 if cdna else inst.glc
if 'atomic' in name:
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
@@ -211,7 +229,9 @@ def _disasm_vop3(inst: VOP3) -> str:
# VOP3SD (shared encoding)
if isinstance(op, VOP3SDOp):
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
def src(v, neg, n):
s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
@@ -225,18 +245,18 @@ def _disasm_vop3(inst: VOP3) -> str:
is16_s2 = is16_s
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
elif 'pack_b32' in name: is16_s = is16_s2 = True
elif 'sat_pk' in name: is16_d = True # v_sat_pk_* writes to 16-bit dest but takes 32-bit src
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
any_hi = inst.opsel != 0
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s)
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s)
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2)
# Destination
dn = inst.dst_regs()
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
elif dn > 1: dst = _vreg(inst.vdst, dn)
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l"
else: dst = f"v{inst.vdst}"
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
@@ -244,7 +264,7 @@ def _disasm_vop3(inst: VOP3) -> str:
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
if inst.op < 256: # VOPC
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}"
if inst.op < 384: # VOP2
n = inst.num_srcs()
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
@@ -258,7 +278,9 @@ def _disasm_vop3(inst: VOP3) -> str:
def _disasm_vop3sd(inst: VOP3SD) -> str:
name = inst.op_name.lower()
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
def src(v, neg, n):
s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
@@ -268,16 +290,22 @@ def _disasm_vop3sd(inst: VOP3SD) -> str:
def _disasm_vopd(inst: VOPD) -> str:
lit = inst._literal or inst.literal
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower()
def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}"
def half(n, vd, s0, vs1):
if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}"
# fmamk: dst = src0 * K + vsrc1, fmaak: dst = src0 * vsrc1 + K
if 'fmamk' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, 0x{lit:x}, v{vs1}"
if 'fmaak' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}, 0x{lit:x}"
return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}"
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
def _disasm_vop3p(inst: VOP3P) -> str:
name = inst.op_name.lower()
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc)
if is_wmma:
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8)
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
if is_fma_mix:
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
@@ -289,15 +317,17 @@ def _disasm_vop3p(inst: VOP3P) -> str:
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
name = inst.op_name.lower()
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
name, cdna = inst.op_name.lower(), _is_cdna(inst)
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
if inst.tfe: w += 1
if hasattr(inst, 'tfe') and inst.tfe: w += 1
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
@@ -348,25 +378,36 @@ def _disasm_mimg(inst: MIMG) -> str:
def _disasm_sop1(inst: SOP1) -> str:
op, name = inst.op, inst.op_name.lower()
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
if not _is_cdna(inst):
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
def _disasm_sop2(inst: SOP2) -> str:
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
def _disasm_sopc(inst: SOPC) -> str:
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
return f"{inst.op_name.lower()} {s0}, {s1}"
def _disasm_sopk(inst: SOPK) -> str:
op, name = inst.op, inst.op_name.lower()
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
# s_setreg_imm32_b32 has a 32-bit literal value
if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32):
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
return f"{name} {hs}, 0x{inst._literal:x}"
if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
def _disasm_vinterp(inst: VINTERP) -> str:
@@ -388,7 +429,8 @@ SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), '
FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc.
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'}
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512',
's_atc_probe', 's_atc_probe_buffer'}
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'}
@@ -579,3 +621,69 @@ def asm(text: str) -> Inst:
except NameError:
if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns)
raise
# ═══════════════════════════════════════════════════════════════════════════════
# CDNA DISASSEMBLER SUPPORT
# ═══════════════════════════════════════════════════════════════════════════════
try:
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
def _cdna_src(inst, v, neg, abs_=0, n=1):
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
if abs_: s = f"|{s}|"
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
def _disasm_vop3a(inst) -> str:
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
suf = "_e64" if inst.op.value < 512 else ""
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
def _disasm_vop3b(inst) -> str:
name, n = inst.op_name.lower(), inst.num_srcs()
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
def _disasm_cdna_vop3p(inst) -> str:
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
_SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'}
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
def _disasm_sdwa(inst) -> str:
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0)
mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6]
return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "")
def _disasm_dpp(inst) -> str:
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
except ValueError: name = f"vop1_op_{inst.vop_op}"
src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl
dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl]
return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods)
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf,
VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp})
except ImportError:
pass

View File

@@ -7,20 +7,18 @@ import functools
# instruction formats
class DPP(Inst64):
encoding = bits[31:26] == 0b110110
src1_sel = bits[58:56]
src1_sext = bits[59]
src1_neg = bits[60]
src1_abs = bits[61]
s1 = bits[63]
offset0 = bits[7:0]
offset1 = bits[15:8]
op = bits[24:17]
acc = bits[25]
addr:VGPRField = bits[39:32]
data0:VGPRField = bits[47:40]
data1:VGPRField = bits[55:48]
vdst:VGPRField = bits[63:56]
encoding = bits[8:0] == 0b11111010
vop_op = bits[16:9]
vdst:VGPRField = bits[24:17]
vop2_op = bits[31:25]
src0:Src = bits[39:32]
dpp_ctrl = bits[48:40]
bound_ctrl = bits[51]
src0_neg = bits[52]
src0_abs = bits[53]
src1_neg = bits[54]
src1_abs = bits[55]
bank_mask = bits[59:56]
row_mask = bits[63:60]
class DS(Inst64):
@@ -82,6 +80,10 @@ class MUBUF(Inst64):
acc = bits[55]
class SDWA(Inst64):
encoding = bits[8:0] == 0b11111001
vop_op = bits[16:9]
vdst:VGPRField = bits[24:17]
vop2_op = bits[31:25]
src0:Src = bits[39:32]
dst_sel = bits[42:40]
dst_u = bits[44:43]
@@ -97,9 +99,6 @@ class SDWA(Inst64):
src1_neg = bits[60]
src1_abs = bits[61]
s1 = bits[63]
sdst:SGPRField = bits[46:40]
sd = bits[47]
row_mask = bits[63:60]
class SDWAB(Inst64):
src0:Src = bits[39:32]

View File

@@ -488,6 +488,8 @@ class SMEMOp(IntEnum):
S_BUFFER_LOAD_B512 = 12
S_GL1_INV = 32
S_DCACHE_INV = 33
S_ATC_PROBE = 34
S_ATC_PROBE_BUFFER = 35
class SOP1Op(IntEnum):
S_MOV_B32 = 0
@@ -710,6 +712,8 @@ class SOPKOp(IntEnum):
S_SETREG_B32 = 18
S_SETREG_IMM32_B32 = 19
S_CALL_B64 = 20
S_SUBVECTOR_LOOP_BEGIN = 22
S_SUBVECTOR_LOOP_END = 23
S_WAITCNT_VSCNT = 24
S_WAITCNT_VMCNT = 25
S_WAITCNT_EXPCNT = 26
@@ -751,6 +755,8 @@ class SOPPOp(IntEnum):
S_SENDMSGHALT = 55
S_INCPERFLEVEL = 56
S_DECPERFLEVEL = 57
S_TTRACEDATA = 58
S_TTRACEDATA_IMM = 59
S_ICACHE_INV = 60
S_BARRIER = 61

View File

@@ -692,6 +692,8 @@ s_buffer_load_b256 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B256)
s_buffer_load_b512 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B512)
s_gl1_inv = functools.partial(SMEM, SMEMOp.S_GL1_INV)
s_dcache_inv = functools.partial(SMEM, SMEMOp.S_DCACHE_INV)
s_atc_probe = functools.partial(SMEM, SMEMOp.S_ATC_PROBE)
s_atc_probe_buffer = functools.partial(SMEM, SMEMOp.S_ATC_PROBE_BUFFER)
s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32)
s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64)
s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32)
@@ -906,6 +908,8 @@ s_getreg_b32 = functools.partial(SOPK, SOPKOp.S_GETREG_B32)
s_setreg_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_B32)
s_setreg_imm32_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_IMM32_B32)
s_call_b64 = functools.partial(SOPK, SOPKOp.S_CALL_B64)
s_subvector_loop_begin = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_BEGIN)
s_subvector_loop_end = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_END)
s_waitcnt_vscnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VSCNT)
s_waitcnt_vmcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VMCNT)
s_waitcnt_expcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_EXPCNT)
@@ -945,6 +949,8 @@ s_sendmsg = functools.partial(SOPP, SOPPOp.S_SENDMSG)
s_sendmsghalt = functools.partial(SOPP, SOPPOp.S_SENDMSGHALT)
s_incperflevel = functools.partial(SOPP, SOPPOp.S_INCPERFLEVEL)
s_decperflevel = functools.partial(SOPP, SOPPOp.S_DECPERFLEVEL)
s_ttracedata = functools.partial(SOPP, SOPPOp.S_TTRACEDATA)
s_ttracedata_imm = functools.partial(SOPP, SOPPOp.S_TTRACEDATA_IMM)
s_icache_inv = functools.partial(SOPP, SOPPOp.S_ICACHE_INV)
s_barrier = functools.partial(SOPP, SOPPOp.S_BARRIER)
v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32)

View File

@@ -458,10 +458,16 @@ class Inst:
@classmethod
def from_bytes(cls, data: bytes):
import typing
inst = cls.from_int(int.from_bytes(data[:cls._size()], 'little'))
op_val = inst._values.get('op', 0)
has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56)
has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70))
# Check for instructions that always have a literal constant (FMAMK/FMAAK/MADMK/MADAK, SETREG_IMM32)
op_name = ''
if cls.__name__ in ('VOP2', 'SOP2', 'SOPK') and 'op' in (hints := typing.get_type_hints(cls, include_extras=True)):
if typing.get_origin(hints['op']) is typing.Annotated:
try: op_name = typing.get_args(hints['op'])[1](op_val).name
except (ValueError, TypeError): pass
has_literal = any(x in op_name for x in ('FMAMK', 'FMAAK', 'MADMK', 'MADAK', 'SETREG_IMM32'))
# VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2)
opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0)
has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2)))
@@ -475,7 +481,7 @@ class Inst:
lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
# Find which source has literal (255) and check its register count
lit_src_is_64 = False
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]:
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255:
lit_src_is_64 = inst.src_regs(idx) == 2
break
@@ -495,7 +501,12 @@ class Inst:
return unwrap(self._values.get(name, 0))
def lit(self, v: int, neg: bool = False) -> str:
s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
if v == 255 and self._literal is not None:
# For 64-bit sources, literal is stored shifted - extract the 32-bit value
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
s = f"0x{lit32:x}"
else:
s = decode_src(v)
return f"-{s}" if neg else s
def __eq__(self, other):
@@ -528,6 +539,10 @@ class Inst:
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
else: self.op = VOP3Op(val)
except ValueError: self.op = val
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
try: self.op = marker(val)
except ValueError: self.op = val
elif cls_name in self._enum_map:
try: self.op = self._enum_map[cls_name](val)
except ValueError: self.op = val

View File

@@ -194,13 +194,30 @@ def _parse_single_pdf(url: str):
if fmt_name in formats:
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
if doc_name in ('RDNA3', 'RDNA3.5'):
if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
if 'SOPPOp' in enums:
for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items():
assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v
if 'SOPKOp' in enums:
for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items():
assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v
if 'SMEMOp' in enums:
for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items():
assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v
if 'DSOp' in enums:
for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items():
assert k not in enums['DSOp']; enums['DSOp'][k] = v
if 'FLATOp' in enums:
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
if is_cdna:
if 'SDWA' in formats:
formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \
[f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')]
if 'DPP' in formats:
formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None),
('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None),
('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)]
# Extract pseudocode for instructions
all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end))

View File

@@ -131,28 +131,19 @@ def _make_asm_test(name):
def _make_disasm_test(name):
def test(self):
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
# First pass: decode all instructions and collect disasm strings
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
skipped = 0
for asm_text, data in self.tests.get(name, []):
if len(data) > fmt_cls._size(): continue
temp_inst = fmt_cls.from_bytes(data)
temp_op = temp_inst._values.get('op', 0)
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
if temp_op in undocumented.get(name, set()): skipped += 1; continue
if name == 'sopp':
simm16 = temp_inst._values.get('simm16', 0)
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
try:
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
temp = VOP3.from_bytes(data)
op_val = temp._values.get('op', 0)
op_val = op_val.val if hasattr(op_val, 'val') else op_val
@@ -188,7 +179,7 @@ def _make_disasm_test(name):
if llvm_bytes is not None and llvm_bytes == data: passed += 1
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
if failures[:10]: print(" " + "\n ".join(failures[:10]))
self.assertEqual(failed, 0)
return test

View File

@@ -0,0 +1,144 @@
#!/usr/bin/env python3
"""Test CDNA assembler/disassembler against LLVM test vectors."""
import unittest, re, subprocess
from tinygrad.helpers import fetch
from extra.assembly.amd.autogen.cdna.ins import *
from extra.assembly.amd.asm import disasm
from extra.assembly.amd.test.helpers import get_llvm_mc
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
tests, lines = [], text.split('\n')
for i, line in enumerate(lines):
line = line.strip()
if not line or line.startswith(('//', '.', ';')): continue
asm_text = line.split('//')[0].strip()
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
else: continue
try:
data = bytes.fromhex(hex_bytes)
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
except ValueError: pass
break
return tests
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
CDNA_TEST_FILES = {
# Scalar ALU - encoding is stable across GFX9/CDNA
'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
# Vector ALU - encoding is mostly stable
'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions
# Memory instructions
'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
# CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
# CDNA-specific: MFMA/MAI instructions
'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
# SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
}
class TestLLVMCDNA(unittest.TestCase):
"""Test CDNA instruction format decode/encode roundtrip and disassembly."""
tests: dict[str, list[tuple[str, bytes]]] = {}
@classmethod
def setUpClass(cls):
for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items():
try:
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter)
except Exception as e:
print(f"Warning: couldn't fetch {filename}: {e}")
cls.tests[name] = []
def _get_val(v): return v.val if hasattr(v, 'val') else v
def _filter_and_decode(tests, fmt_cls, op_enum):
"""Filter tests and decode instructions, yielding (asm_text, data, decoded, error)."""
fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP'
for asm_text, data in tests:
has_lit = False
# SDWA/DPP format tests: only accept matching 8-byte instructions
if is_sdwa:
if len(data) != 8 or data[0] != 0xf9: continue
elif is_dpp:
if len(data) != 8 or data[0] != 0xfa: continue
elif fmt_cls._size() == 4 and len(data) == 8:
if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately)
has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20
if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37)
if not has_lit: continue
if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
try:
decoded = fmt_cls.from_bytes(data)
# For SDWA/DPP, opcode location depends on VOP1 vs VOP2
if is_sdwa or is_dpp:
vop2_op = _get_val(decoded._values.get('vop2_op', 0))
op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
else:
op_val = _get_val(decoded._values.get('op', 0))
try: op_enum(op_val)
except ValueError: continue
yield asm_text, data, decoded, None
except Exception as e:
yield asm_text, data, None, str(e)
def _make_roundtrip_test(name):
def test(self):
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
passed, failed, failures = 0, 0, []
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
if decoded.to_bytes()[:len(data)] == data: passed += 1
else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}")
print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
if failures[:5]: print(" " + "\n ".join(failures[:5]))
self.assertEqual(failed, 0)
return test
def _make_disasm_test(name):
def test(self):
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
passed, failed, failures = 0, 0, []
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue
if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue
passed += 1
print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
if failures[:5]: print(" " + "\n ".join(failures[:5]))
self.assertEqual(failed, 0)
return test
for name in CDNA_TEST_FILES:
setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))
if __name__ == "__main__":
unittest.main()