mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
assembly/amd: add CDNA support to asm (#13982)
* add CDNA support * more cdna tests * something * fix more stuff * more work * simpler * simplier * cdna * disasm * less skip * fixes * simpler
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# RDNA3 assembler and disassembler
|
||||
# RDNA3/CDNA assembler and disassembler
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
@@ -8,6 +8,8 @@ from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
|
||||
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
|
||||
|
||||
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
"""Check if word matches the encoding pattern of an instruction class."""
|
||||
if cls._encoding is None: return False
|
||||
@@ -81,10 +83,11 @@ def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v)
|
||||
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
|
||||
def _fmt_bits(label: str, val: int, count: int) -> str: return f"{label}:[{','.join(str((val >> i) & 1) for i in range(count))}]"
|
||||
|
||||
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any_hi: bool) -> str:
|
||||
def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool) -> str:
|
||||
"""Format VOP3 source operand with modifiers."""
|
||||
if n > 1: s = _fmt_src(v, n)
|
||||
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else (f"v{v - 256}.l" if any_hi else inst.lit(v))
|
||||
if v == 255: s = inst.lit(v) # literal constant takes priority
|
||||
elif n > 1: s = _fmt_src(v, n)
|
||||
elif f16 and v >= 256: s = f"v{v - 256}.h" if hi else f"v{v - 256}.l"
|
||||
else: s = inst.lit(v)
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"-{s}" if neg else s
|
||||
@@ -92,9 +95,11 @@ def _vop3_src(inst, v: int, neg: int, abs_: int, hi: int, n: int, f16: bool, any
|
||||
def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
||||
"""Format op_sel modifier string."""
|
||||
if not need: return ""
|
||||
if is16_d and (opsel & 8): return f" op_sel:[1,1,1{',1' if n == 3 else ''}]"
|
||||
if n == 3: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{(opsel >> 3) & 1}]"
|
||||
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]"
|
||||
# For VOP1 (n=1): op_sel:[src0_hi, dst_hi], for VOP2 (n=2): op_sel:[src0_hi, src1_hi, dst_hi], for VOP3 (n=3): op_sel:[src0_hi, src1_hi, src2_hi, dst_hi]
|
||||
dst_hi = (opsel >> 3) & 1
|
||||
if n == 1: return f" op_sel:[{opsel & 1},{dst_hi}]"
|
||||
if n == 2: return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{dst_hi}]"
|
||||
return f" op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1},{dst_hi}]"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DISASSEMBLER
|
||||
@@ -108,30 +113,41 @@ def _disasm_vop1(inst: VOP1) -> str:
|
||||
parts = name.split('_')
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
src = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}_e32 {dst}, {src}"
|
||||
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
name = inst.op_name.lower()
|
||||
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
suf = "" if not cdna and inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
lit = getattr(inst, '_literal', None)
|
||||
is16 = not cdna and inst.is_16bit()
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
|
||||
if 'fmaak' in name or (not cdna and inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}, 0x{lit:x}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{lit:x}"
|
||||
if 'fmamk' in name or (not cdna and inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16)):
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, 0x{lit:x}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{lit:x}, v{inst.vsrc1}"
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
vcc = "vcc" if cdna else "vcc_lo"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (f", {vcc}" if name == 'v_cndmask_b32' else "")
|
||||
|
||||
def _disasm_vopc(inst: VOPC) -> str:
|
||||
name = inst.op_name.lower()
|
||||
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna:
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0))
|
||||
return f"{name}_e32 {s0}, v{inst.vsrc1}" if inst.op.value >= 128 else f"{name}_e32 vcc, {s0}, v{inst.vsrc1}"
|
||||
s0 = inst.lit(inst.src0) if inst.src0 == 255 else _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
|
||||
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
|
||||
NO_ARG_SOPP = {SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE, SOPPOp.S_TTRACEDATA}
|
||||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_ENDPGM: return name if inst.simm16 == 0 else f"{name} {inst.simm16}"
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
|
||||
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
@@ -154,12 +170,13 @@ def _disasm_smem(inst: SMEM) -> str:
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name = inst.op_name.lower()
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
if cdna: mods = f"{f' offset:{off_val}' if off_val else ''}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if inst.sc1 else ''}"
|
||||
else: mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
elif inst.saddr == 124: saddr_s = ", off"
|
||||
@@ -172,8 +189,9 @@ def _disasm_flat(inst: FLAT) -> str:
|
||||
# addr width
|
||||
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, 1 if seg == 'scratch' or (inst.saddr not in (0x7F, 124)) else 2)
|
||||
data_s, vdst_s = _vreg(inst.data, w), _vreg(inst.vdst, w // 2 if 'cmpswap' in name else w)
|
||||
glc_or_sc0 = inst.sc0 if cdna else inst.glc
|
||||
if 'atomic' in name:
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if inst.glc else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
|
||||
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
|
||||
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
@@ -211,7 +229,9 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
# VOP3SD (shared encoding)
|
||||
if isinstance(op, VOP3SDOp):
|
||||
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
def src(v, neg, n):
|
||||
s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
@@ -225,18 +245,18 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
is16_s2 = is16_s
|
||||
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
|
||||
elif 'pack_b32' in name: is16_s = is16_s2 = True
|
||||
elif 'sat_pk' in name: is16_d = True # v_sat_pk_* writes to 16-bit dest but takes 32-bit src
|
||||
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
|
||||
|
||||
any_hi = inst.opsel != 0
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2)
|
||||
|
||||
# Destination
|
||||
dn = inst.dst_regs()
|
||||
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
|
||||
elif dn > 1: dst = _vreg(inst.vdst, dn)
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l"
|
||||
else: dst = f"v{inst.vdst}"
|
||||
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
@@ -244,7 +264,7 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
need_opsel = nonvgpr_opsel or (inst.opsel and not is16_s)
|
||||
|
||||
if inst.op < 256: # VOPC
|
||||
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
return f"{name}_e64 {s0}, {s1}{cl}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}{cl}"
|
||||
if inst.op < 384: # VOP2
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
@@ -258,7 +278,9 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
|
||||
def _disasm_vop3sd(inst: VOP3SD) -> str:
|
||||
name = inst.op_name.lower()
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
def src(v, neg, n):
|
||||
s = inst.lit(v) if v == 255 else (_fmt_src(v, n) if n > 1 else inst.lit(v))
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
@@ -268,16 +290,22 @@ def _disasm_vop3sd(inst: VOP3SD) -> str:
|
||||
def _disasm_vopd(inst: VOPD) -> str:
|
||||
lit = inst._literal or inst.literal
|
||||
vdst_y, nx, ny = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), VOPDOp(inst.opx).name.lower(), VOPDOp(inst.opy).name.lower()
|
||||
def half(n, vd, s0, vs1): return f"{n} v{vd}, {inst.lit(s0)}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}" if 'mov' in n else f"{n} v{vd}, {inst.lit(s0)}, v{vs1}{f', 0x{lit:x}' if lit and _has(n, 'fmaak', 'fmamk') else ''}"
|
||||
def half(n, vd, s0, vs1):
|
||||
if 'mov' in n: return f"{n} v{vd}, {inst.lit(s0)}"
|
||||
# fmamk: dst = src0 * K + vsrc1, fmaak: dst = src0 * vsrc1 + K
|
||||
if 'fmamk' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, 0x{lit:x}, v{vs1}"
|
||||
if 'fmaak' in n and lit: return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}, 0x{lit:x}"
|
||||
return f"{n} v{vd}, {inst.lit(s0)}, v{vs1}"
|
||||
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
|
||||
|
||||
def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
name = inst.op_name.lower()
|
||||
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
|
||||
def get_src(v, sc): return inst.lit(v) if v == 255 else _fmt_src(v, sc)
|
||||
if is_wmma:
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
if is_fma_mix:
|
||||
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
|
||||
@@ -289,15 +317,17 @@ def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
name, cdna = inst.op_name.lower(), _is_cdna(inst)
|
||||
if cdna and name in ('buffer_wbl2', 'buffer_inv'): return name
|
||||
if not cdna and inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
|
||||
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
|
||||
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
|
||||
if inst.tfe: w += 1
|
||||
if hasattr(inst, 'tfe') and inst.tfe: w += 1
|
||||
vaddr = _vreg(inst.vaddr, 2) if inst.offen and inst.idxen else f"v{inst.vaddr}" if inst.offen or inst.idxen else "off"
|
||||
srsrc = _sreg_or_ttmp(inst.srsrc*4, 4)
|
||||
mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
if cdna: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.sc0,"sc0"),(inst.nt,"nt"),(inst.sc1,"sc1")] if c]
|
||||
else: mods = ([f"format:{inst.format}"] if isinstance(inst, MTBUF) else []) + [m for c, m in [(inst.idxen,"idxen"),(inst.offen,"offen"),(inst.offset,f"offset:{inst.offset}"),(inst.glc,"glc"),(inst.dlc,"dlc"),(inst.slc,"slc"),(inst.tfe,"tfe")] if c]
|
||||
return f"{name} {_vreg(inst.vdata, w)}, {vaddr}, {srsrc}, {decode_src(inst.soffset)}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
@@ -348,25 +378,36 @@ def _disasm_mimg(inst: MIMG) -> str:
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
|
||||
src = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
if not _is_cdna(inst):
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {src}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {src}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {src}"
|
||||
|
||||
def _disasm_sop2(inst: SOP2) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
|
||||
def _disasm_sopc(inst: SOPC) -> str:
|
||||
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
s0 = inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))
|
||||
s1 = inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))
|
||||
return f"{inst.op_name.lower()} {s0}, {s1}"
|
||||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
|
||||
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
|
||||
op, name, cdna = inst.op, inst.op_name.lower(), _is_cdna(inst)
|
||||
# s_setreg_imm32_b32 has a 32-bit literal value
|
||||
if name == 's_setreg_imm32_b32' or (not cdna and op == SOPKOp.S_SETREG_IMM32_B32):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
return f"{name} {hs}, 0x{inst._literal:x}"
|
||||
if not cdna and op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
|
||||
if (not cdna and op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32)) or (cdna and name in ('s_setreg_b32', 's_getreg_b32')):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if 'setreg' in name else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
if not cdna and op in (SOPKOp.S_SUBVECTOR_LOOP_BEGIN, SOPKOp.S_SUBVECTOR_LOOP_END):
|
||||
return f"{name} {_fmt_sdst(inst.sdst, 1)}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
@@ -388,7 +429,8 @@ SPEC_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), '
|
||||
FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '-0.5', '1.0', etc.
|
||||
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
|
||||
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
|
||||
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512'}
|
||||
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512',
|
||||
's_atc_probe', 's_atc_probe_buffer'}
|
||||
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
|
||||
'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'}
|
||||
|
||||
@@ -579,3 +621,69 @@ def asm(text: str) -> Inst:
|
||||
except NameError:
|
||||
if m := re.match(r'^(v_\w+)(\(.*\))$', dsl): return eval(f"{m.group(1)}_e32{m.group(2)}", ns)
|
||||
raise
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# CDNA DISASSEMBLER SUPPORT
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
try:
|
||||
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP2 as CDNA_VOP2, VOPC as CDNA_VOPC, VOP3A, VOP3B, VOP3P as CDNA_VOP3P,
|
||||
SOP1 as CDNA_SOP1, SOP2 as CDNA_SOP2, SOPC as CDNA_SOPC, SOPK as CDNA_SOPK, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS,
|
||||
FLAT as CDNA_FLAT, MUBUF as CDNA_MUBUF, MTBUF as CDNA_MTBUF, SDWA, DPP, VOP1Op as CDNA_VOP1Op)
|
||||
|
||||
def _cdna_src(inst, v, neg, abs_=0, n=1):
|
||||
s = inst.lit(v) if v == 255 else _fmt_src(v, n)
|
||||
if abs_: s = f"|{s}|"
|
||||
return f"neg({s})" if neg and v == 255 else (f"-{s}" if neg else s)
|
||||
|
||||
def _disasm_vop3a(inst) -> str:
|
||||
name, n, cl, om = inst.op_name.lower(), inst.num_srcs(), " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.src_regs(0)), _cdna_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.src_regs(1)), _cdna_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
if inst.op.value < 256: return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
suf = "_e64" if inst.op.value < 512 else ""
|
||||
return f"{name}{suf} {dst}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else (f"{name}{suf}" if name == 'v_nop' else f"{name}{suf} {dst}, {s0}, {s1}{cl}{om}" if n == 2 else f"{name}{suf} {dst}, {s0}{cl}{om}")
|
||||
|
||||
def _disasm_vop3b(inst) -> str:
|
||||
name, n = inst.op_name.lower(), inst.num_srcs()
|
||||
s0, s1, s2 = _cdna_src(inst, inst.src0, inst.neg&1), _cdna_src(inst, inst.src1, inst.neg&2), _cdna_src(inst, inst.src2, inst.neg&4)
|
||||
dst, suf = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}", "_e64" if 'co_' in name else ""
|
||||
cl, om = " clamp" if inst.clmp else "", _omod(inst.omod)
|
||||
return f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}, {s2}{cl}{om}" if n == 3 else f"{name}{suf} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{cl}{om}"
|
||||
|
||||
def _disasm_cdna_vop3p(inst) -> str:
|
||||
name, n, is_mfma = inst.op_name.lower(), inst.num_srcs(), 'mfma' in inst.op_name.lower() or 'smfmac' in inst.op_name.lower()
|
||||
get_src = lambda v, sc: inst.lit(v) if v == 255 else _fmt_src(v, sc)
|
||||
if is_mfma: sc = 2 if 'iu4' in name else 4 if 'iu8' in name or 'i4' in name else 8 if 'f16' in name or 'bf16' in name else 4; src0, src1, src2, dst = get_src(inst.src0, sc), get_src(inst.src1, sc), get_src(inst.src2, 16), _vreg(inst.vdst, 16)
|
||||
else: src0, src1, src2, dst = get_src(inst.src0, 1), get_src(inst.src1, 1), get_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
_SEL = {0: 'BYTE_0', 1: 'BYTE_1', 2: 'BYTE_2', 3: 'BYTE_3', 4: 'WORD_0', 5: 'WORD_1', 6: 'DWORD'}
|
||||
_UNUSED = {0: 'UNUSED_PAD', 1: 'UNUSED_SEXT', 2: 'UNUSED_PRESERVE'}
|
||||
_DPP = {0x130: "wave_shl:1", 0x134: "wave_rol:1", 0x138: "wave_shr:1", 0x13c: "wave_ror:1", 0x140: "row_mirror", 0x141: "row_half_mirror", 0x142: "row_bcast:15", 0x143: "row_bcast:31"}
|
||||
|
||||
def _disasm_sdwa(inst) -> str:
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
src = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0)
|
||||
mods = [f"dst_sel:{_SEL[inst.dst_sel]}" for _ in [1] if inst.dst_sel != 6] + [f"dst_unused:{_UNUSED[inst.dst_u]}" for _ in [1] if inst.dst_u] + [f"src0_sel:{_SEL[inst.src0_sel]}" for _ in [1] if inst.src0_sel != 6]
|
||||
return f"{name}_sdwa v{inst.vdst}, {src}" + (" " + " ".join(mods) if mods else "")
|
||||
|
||||
def _disasm_dpp(inst) -> str:
|
||||
try: name = CDNA_VOP1Op(inst.vop_op).name.lower()
|
||||
except ValueError: name = f"vop1_op_{inst.vop_op}"
|
||||
src, ctrl = f"v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}" if isinstance(inst.src0, int) else str(inst.src0), inst.dpp_ctrl
|
||||
dpp = f"quad_perm:[{ctrl&3},{(ctrl>>2)&3},{(ctrl>>4)&3},{(ctrl>>6)&3}]" if ctrl < 0x100 else f"row_shl:{ctrl&0xf}" if ctrl < 0x110 else f"row_shr:{ctrl&0xf}" if ctrl < 0x120 else f"row_ror:{ctrl&0xf}" if ctrl < 0x130 else _DPP.get(ctrl, f"dpp_ctrl:0x{ctrl:x}")
|
||||
mods = [dpp] + [f"row_mask:0x{inst.row_mask:x}" for _ in [1] if inst.row_mask != 0xf] + [f"bank_mask:0x{inst.bank_mask:x}" for _ in [1] if inst.bank_mask != 0xf] + ["bound_ctrl:1" for _ in [1] if inst.bound_ctrl]
|
||||
return f"{name}_dpp v{inst.vdst}, {src} " + " ".join(mods)
|
||||
|
||||
# Register CDNA handlers - shared formats use merged disassemblers, CDNA-only formats use dedicated ones
|
||||
DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP2: _disasm_vop2, CDNA_VOPC: _disasm_vopc,
|
||||
CDNA_SOP1: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPP: _disasm_sopp,
|
||||
CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_MUBUF: _disasm_buf, CDNA_MTBUF: _disasm_buf,
|
||||
VOP3A: _disasm_vop3a, VOP3B: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, SDWA: _disasm_sdwa, DPP: _disasm_dpp})
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -7,20 +7,18 @@ import functools
|
||||
|
||||
# instruction formats
|
||||
class DPP(Inst64):
|
||||
encoding = bits[31:26] == 0b110110
|
||||
src1_sel = bits[58:56]
|
||||
src1_sext = bits[59]
|
||||
src1_neg = bits[60]
|
||||
src1_abs = bits[61]
|
||||
s1 = bits[63]
|
||||
offset0 = bits[7:0]
|
||||
offset1 = bits[15:8]
|
||||
op = bits[24:17]
|
||||
acc = bits[25]
|
||||
addr:VGPRField = bits[39:32]
|
||||
data0:VGPRField = bits[47:40]
|
||||
data1:VGPRField = bits[55:48]
|
||||
vdst:VGPRField = bits[63:56]
|
||||
encoding = bits[8:0] == 0b11111010
|
||||
vop_op = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
vop2_op = bits[31:25]
|
||||
src0:Src = bits[39:32]
|
||||
dpp_ctrl = bits[48:40]
|
||||
bound_ctrl = bits[51]
|
||||
src0_neg = bits[52]
|
||||
src0_abs = bits[53]
|
||||
src1_neg = bits[54]
|
||||
src1_abs = bits[55]
|
||||
bank_mask = bits[59:56]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class DS(Inst64):
|
||||
@@ -82,6 +80,10 @@ class MUBUF(Inst64):
|
||||
acc = bits[55]
|
||||
|
||||
class SDWA(Inst64):
|
||||
encoding = bits[8:0] == 0b11111001
|
||||
vop_op = bits[16:9]
|
||||
vdst:VGPRField = bits[24:17]
|
||||
vop2_op = bits[31:25]
|
||||
src0:Src = bits[39:32]
|
||||
dst_sel = bits[42:40]
|
||||
dst_u = bits[44:43]
|
||||
@@ -97,9 +99,6 @@ class SDWA(Inst64):
|
||||
src1_neg = bits[60]
|
||||
src1_abs = bits[61]
|
||||
s1 = bits[63]
|
||||
sdst:SGPRField = bits[46:40]
|
||||
sd = bits[47]
|
||||
row_mask = bits[63:60]
|
||||
|
||||
class SDWAB(Inst64):
|
||||
src0:Src = bits[39:32]
|
||||
|
||||
@@ -488,6 +488,8 @@ class SMEMOp(IntEnum):
|
||||
S_BUFFER_LOAD_B512 = 12
|
||||
S_GL1_INV = 32
|
||||
S_DCACHE_INV = 33
|
||||
S_ATC_PROBE = 34
|
||||
S_ATC_PROBE_BUFFER = 35
|
||||
|
||||
class SOP1Op(IntEnum):
|
||||
S_MOV_B32 = 0
|
||||
@@ -710,6 +712,8 @@ class SOPKOp(IntEnum):
|
||||
S_SETREG_B32 = 18
|
||||
S_SETREG_IMM32_B32 = 19
|
||||
S_CALL_B64 = 20
|
||||
S_SUBVECTOR_LOOP_BEGIN = 22
|
||||
S_SUBVECTOR_LOOP_END = 23
|
||||
S_WAITCNT_VSCNT = 24
|
||||
S_WAITCNT_VMCNT = 25
|
||||
S_WAITCNT_EXPCNT = 26
|
||||
@@ -751,6 +755,8 @@ class SOPPOp(IntEnum):
|
||||
S_SENDMSGHALT = 55
|
||||
S_INCPERFLEVEL = 56
|
||||
S_DECPERFLEVEL = 57
|
||||
S_TTRACEDATA = 58
|
||||
S_TTRACEDATA_IMM = 59
|
||||
S_ICACHE_INV = 60
|
||||
S_BARRIER = 61
|
||||
|
||||
|
||||
@@ -692,6 +692,8 @@ s_buffer_load_b256 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B256)
|
||||
s_buffer_load_b512 = functools.partial(SMEM, SMEMOp.S_BUFFER_LOAD_B512)
|
||||
s_gl1_inv = functools.partial(SMEM, SMEMOp.S_GL1_INV)
|
||||
s_dcache_inv = functools.partial(SMEM, SMEMOp.S_DCACHE_INV)
|
||||
s_atc_probe = functools.partial(SMEM, SMEMOp.S_ATC_PROBE)
|
||||
s_atc_probe_buffer = functools.partial(SMEM, SMEMOp.S_ATC_PROBE_BUFFER)
|
||||
s_mov_b32 = functools.partial(SOP1, SOP1Op.S_MOV_B32)
|
||||
s_mov_b64 = functools.partial(SOP1, SOP1Op.S_MOV_B64)
|
||||
s_cmov_b32 = functools.partial(SOP1, SOP1Op.S_CMOV_B32)
|
||||
@@ -906,6 +908,8 @@ s_getreg_b32 = functools.partial(SOPK, SOPKOp.S_GETREG_B32)
|
||||
s_setreg_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_B32)
|
||||
s_setreg_imm32_b32 = functools.partial(SOPK, SOPKOp.S_SETREG_IMM32_B32)
|
||||
s_call_b64 = functools.partial(SOPK, SOPKOp.S_CALL_B64)
|
||||
s_subvector_loop_begin = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_BEGIN)
|
||||
s_subvector_loop_end = functools.partial(SOPK, SOPKOp.S_SUBVECTOR_LOOP_END)
|
||||
s_waitcnt_vscnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VSCNT)
|
||||
s_waitcnt_vmcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_VMCNT)
|
||||
s_waitcnt_expcnt = functools.partial(SOPK, SOPKOp.S_WAITCNT_EXPCNT)
|
||||
@@ -945,6 +949,8 @@ s_sendmsg = functools.partial(SOPP, SOPPOp.S_SENDMSG)
|
||||
s_sendmsghalt = functools.partial(SOPP, SOPPOp.S_SENDMSGHALT)
|
||||
s_incperflevel = functools.partial(SOPP, SOPPOp.S_INCPERFLEVEL)
|
||||
s_decperflevel = functools.partial(SOPP, SOPPOp.S_DECPERFLEVEL)
|
||||
s_ttracedata = functools.partial(SOPP, SOPPOp.S_TTRACEDATA)
|
||||
s_ttracedata_imm = functools.partial(SOPP, SOPPOp.S_TTRACEDATA_IMM)
|
||||
s_icache_inv = functools.partial(SOPP, SOPPOp.S_ICACHE_INV)
|
||||
s_barrier = functools.partial(SOPP, SOPPOp.S_BARRIER)
|
||||
v_interp_p10_f32 = functools.partial(VINTERP, VINTERPOp.V_INTERP_P10_F32)
|
||||
|
||||
@@ -458,10 +458,16 @@ class Inst:
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
import typing
|
||||
inst = cls.from_int(int.from_bytes(data[:cls._size()], 'little'))
|
||||
op_val = inst._values.get('op', 0)
|
||||
has_literal = cls.__name__ == 'VOP2' and op_val in (44, 45, 55, 56)
|
||||
has_literal = has_literal or (cls.__name__ == 'SOP2' and op_val in (69, 70))
|
||||
# Check for instructions that always have a literal constant (FMAMK/FMAAK/MADMK/MADAK, SETREG_IMM32)
|
||||
op_name = ''
|
||||
if cls.__name__ in ('VOP2', 'SOP2', 'SOPK') and 'op' in (hints := typing.get_type_hints(cls, include_extras=True)):
|
||||
if typing.get_origin(hints['op']) is typing.Annotated:
|
||||
try: op_name = typing.get_args(hints['op'])[1](op_val).name
|
||||
except (ValueError, TypeError): pass
|
||||
has_literal = any(x in op_name for x in ('FMAMK', 'FMAAK', 'MADMK', 'MADAK', 'SETREG_IMM32'))
|
||||
# VOPD fmaak/fmamk always have a literal (opx/opy value 1 or 2)
|
||||
opx, opy = inst._values.get('opx', 0), inst._values.get('opy', 0)
|
||||
has_literal = has_literal or (cls.__name__ == 'VOPD' and (opx in (1, 2) or opy in (1, 2)))
|
||||
@@ -475,7 +481,7 @@ class Inst:
|
||||
lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
|
||||
# Find which source has literal (255) and check its register count
|
||||
lit_src_is_64 = False
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]:
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
|
||||
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255:
|
||||
lit_src_is_64 = inst.src_regs(idx) == 2
|
||||
break
|
||||
@@ -495,7 +501,12 @@ class Inst:
|
||||
return unwrap(self._values.get(name, 0))
|
||||
|
||||
def lit(self, v: int, neg: bool = False) -> str:
|
||||
s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
|
||||
if v == 255 and self._literal is not None:
|
||||
# For 64-bit sources, literal is stored shifted - extract the 32-bit value
|
||||
lit32 = (self._literal >> 32) if self._literal > 0xffffffff else self._literal
|
||||
s = f"0x{lit32:x}"
|
||||
else:
|
||||
s = decode_src(v)
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def __eq__(self, other):
|
||||
@@ -528,6 +539,10 @@ class Inst:
|
||||
elif val in self._VOP3SD_OPS: self.op = VOP3SDOp(val)
|
||||
else: self.op = VOP3Op(val)
|
||||
except ValueError: self.op = val
|
||||
# Prefer BitField marker (class-specific enum) over _enum_map (generic RDNA3 enums)
|
||||
elif 'op' in self._fields and (marker := self._fields['op'].marker) and issubclass(marker, IntEnum):
|
||||
try: self.op = marker(val)
|
||||
except ValueError: self.op = val
|
||||
elif cls_name in self._enum_map:
|
||||
try: self.op = self._enum_map[cls_name](val)
|
||||
except ValueError: self.op = val
|
||||
|
||||
@@ -194,13 +194,30 @@ def _parse_single_pdf(url: str):
|
||||
if fmt_name in formats:
|
||||
formats[fmt_name] = [(n, h, 14 if n == 'OP' else l, e, t) for n, h, l, e, t in formats[fmt_name]]
|
||||
if doc_name in ('RDNA3', 'RDNA3.5'):
|
||||
if 'SOPPOp' in enums: assert 8 not in enums['SOPPOp']; enums['SOPPOp'][8] = 'S_WAITCNT_DEPCTR'
|
||||
if 'SOPPOp' in enums:
|
||||
for k, v in {8: 'S_WAITCNT_DEPCTR', 58: 'S_TTRACEDATA', 59: 'S_TTRACEDATA_IMM'}.items():
|
||||
assert k not in enums['SOPPOp']; enums['SOPPOp'][k] = v
|
||||
if 'SOPKOp' in enums:
|
||||
for k, v in {22: 'S_SUBVECTOR_LOOP_BEGIN', 23: 'S_SUBVECTOR_LOOP_END'}.items():
|
||||
assert k not in enums['SOPKOp']; enums['SOPKOp'][k] = v
|
||||
if 'SMEMOp' in enums:
|
||||
for k, v in {34: 'S_ATC_PROBE', 35: 'S_ATC_PROBE_BUFFER'}.items():
|
||||
assert k not in enums['SMEMOp']; enums['SMEMOp'][k] = v
|
||||
if 'DSOp' in enums:
|
||||
for k, v in {24: 'DS_GWS_SEMA_RELEASE_ALL', 25: 'DS_GWS_INIT', 26: 'DS_GWS_SEMA_V', 27: 'DS_GWS_SEMA_BR', 28: 'DS_GWS_SEMA_P', 29: 'DS_GWS_BARRIER'}.items():
|
||||
assert k not in enums['DSOp']; enums['DSOp'][k] = v
|
||||
if 'FLATOp' in enums:
|
||||
for k, v in {40: 'GLOBAL_LOAD_ADDTID_B32', 41: 'GLOBAL_STORE_ADDTID_B32', 55: 'FLAT_ATOMIC_CSUB_U32'}.items():
|
||||
assert k not in enums['FLATOp']; enums['FLATOp'][k] = v
|
||||
# CDNA SDWA/DPP: PDF only has modifier fields, need VOP1/VOP2 overlay for correct encoding
|
||||
if is_cdna:
|
||||
if 'SDWA' in formats:
|
||||
formats['SDWA'] = [('ENCODING', 8, 0, 0xf9, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None)] + \
|
||||
[f for f in formats['SDWA'] if f[0] not in ('ENCODING', 'SDST', 'SD', 'ROW_MASK')]
|
||||
if 'DPP' in formats:
|
||||
formats['DPP'] = [('ENCODING', 8, 0, 0xfa, None), ('VOP_OP', 16, 9, None, None), ('VDST', 24, 17, None, 'VGPRField'), ('VOP2_OP', 31, 25, None, None),
|
||||
('SRC0', 39, 32, None, 'Src'), ('DPP_CTRL', 48, 40, None, None), ('BOUND_CTRL', 51, 51, None, None), ('SRC0_NEG', 52, 52, None, None), ('SRC0_ABS', 53, 53, None, None),
|
||||
('SRC1_NEG', 54, 54, None, None), ('SRC1_ABS', 55, 55, None, None), ('BANK_MASK', 59, 56, None, None), ('ROW_MASK', 63, 60, None, None)]
|
||||
|
||||
# Extract pseudocode for instructions
|
||||
all_text = '\n'.join(pdf.text(i) for i in range(instr_start, instr_end))
|
||||
|
||||
@@ -131,28 +131,19 @@ def _make_asm_test(name):
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum = LLVM_TEST_FILES[name]
|
||||
_, base_fmt_cls, base_op_enum = LLVM_TEST_FILES[name]
|
||||
# VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
|
||||
vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
|
||||
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
|
||||
skipped = 0
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
if len(data) > fmt_cls._size(): continue
|
||||
temp_inst = fmt_cls.from_bytes(data)
|
||||
temp_op = temp_inst._values.get('op', 0)
|
||||
temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
|
||||
if temp_op in undocumented.get(name, set()): skipped += 1; continue
|
||||
if name == 'sopp':
|
||||
simm16 = temp_inst._values.get('simm16', 0)
|
||||
simm16 = simm16.val if hasattr(simm16, 'val') else simm16
|
||||
sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
|
||||
if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
|
||||
# Detect VOP3 promotions in VOP1/VOP2/VOPC tests: VOP3 has bits [31:26]=0b110101 in first dword
|
||||
is_vop3_enc = name in ('vop1', 'vop2', 'vopc', 'vopcx') and len(data) >= 4 and (data[3] >> 2) == 0x35
|
||||
fmt_cls, op_enum = (VOP3, VOP3Op) if is_vop3_enc else (base_fmt_cls, base_op_enum)
|
||||
try:
|
||||
if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
if base_fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
|
||||
temp = VOP3.from_bytes(data)
|
||||
op_val = temp._values.get('op', 0)
|
||||
op_val = op_val.val if hasattr(op_val, 'val') else op_val
|
||||
@@ -188,7 +179,7 @@ def _make_disasm_test(name):
|
||||
if llvm_bytes is not None and llvm_bytes == data: passed += 1
|
||||
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
144
extra/assembly/amd/test/test_llvm_cdna.py
Normal file
144
extra/assembly/amd/test/test_llvm_cdna.py
Normal file
@@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test CDNA assembler/disassembler against LLVM test vectors."""
|
||||
import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.amd.autogen.cdna.ins import *
|
||||
from extra.assembly.amd.asm import disasm
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
def parse_llvm_tests(text: str, mnemonic_filter: str = None, size_filter: int = None) -> list[tuple[str, bytes]]:
|
||||
"""Parse LLVM test format into (asm, expected_bytes) pairs."""
|
||||
tests, lines = [], text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
line = line.strip()
|
||||
if not line or line.startswith(('//', '.', ';')): continue
|
||||
asm_text = line.split('//')[0].strip()
|
||||
if not asm_text or (mnemonic_filter and not asm_text.startswith(mnemonic_filter)): continue
|
||||
for j in list(range(max(0, i - 3), i)) + list(range(i, min(i + 3, len(lines)))):
|
||||
if m := re.search(r'(?:VI9|GFX9|CHECK)[^:]*:.*?encoding:\s*\[(.*?)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
elif m := re.search(r'CHECK[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]', lines[j]):
|
||||
hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
|
||||
else: continue
|
||||
try:
|
||||
data = bytes.fromhex(hex_bytes)
|
||||
if size_filter is None or len(data) == size_filter: tests.append((asm_text, data))
|
||||
except ValueError: pass
|
||||
break
|
||||
return tests
|
||||
|
||||
# Use gfx9 tests for compatible scalar/vector formats and gfx90a/gfx942 tests for CDNA-specific instructions
|
||||
# Format: (filename, format_class, op_enum, mcpu, mnemonic_filter, size_filter)
|
||||
CDNA_TEST_FILES = {
|
||||
# Scalar ALU - encoding is stable across GFX9/CDNA
|
||||
'sop1': ('gfx9_asm_sop1.s', SOP1, SOP1Op, 'gfx940', None, None),
|
||||
'sop2': ('gfx9_asm_sop2.s', SOP2, SOP2Op, 'gfx940', None, None),
|
||||
'sopp': ('gfx9_asm_sopp.s', SOPP, SOPPOp, 'gfx940', None, None),
|
||||
'sopp_gfx9': ('sopp-gfx9.s', SOPP, SOPPOp, 'gfx940', None, None),
|
||||
'sopk': ('gfx9_asm_sopk.s', SOPK, SOPKOp, 'gfx940', None, None),
|
||||
'sopc': ('gfx9_asm_sopc.s', SOPC, SOPCOp, 'gfx940', None, None),
|
||||
# Vector ALU - encoding is mostly stable
|
||||
'vop1': ('gfx9_asm_vop1.s', VOP1, VOP1Op, 'gfx940', None, None),
|
||||
'vop1_gfx9': ('vop1-gfx9.s', VOP1, VOP1Op, 'gfx940', None, None),
|
||||
'vop2': ('gfx9_asm_vop2.s', VOP2, VOP2Op, 'gfx940', None, None),
|
||||
'vopc': ('gfx9_asm_vopc.s', VOPC, VOPCOp, 'gfx940', None, None),
|
||||
'vop3p': ('gfx9_asm_vop3p.s', VOP3P, VOP3POp, 'gfx940', None, None),
|
||||
'vop3_gfx9': ('vop3-gfx9.s', VOP3A, VOP3AOp, 'gfx940', None, 8), # Only 64-bit VOP3 instructions
|
||||
# Memory instructions
|
||||
'ds': ('gfx9_asm_ds.s', DS, DSOp, 'gfx940', None, None),
|
||||
'ds_gfx9': ('ds-gfx9.s', DS, DSOp, 'gfx940', None, None),
|
||||
# CDNA memory instructions (gfx90a has correct FLAT/MUBUF encodings with acc registers)
|
||||
'flat_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'flat_', None),
|
||||
'global_gfx90a': ('gfx90a_ldst_acc.s', FLAT, FLATOp, 'gfx90a', 'global_', None),
|
||||
'mubuf_gfx90a': ('gfx90a_ldst_acc.s', MUBUF, MUBUFOp, 'gfx90a', 'buffer_', None),
|
||||
'mubuf_gfx9': ('mubuf-gfx9.s', MUBUF, MUBUFOp, 'gfx940', None, None),
|
||||
'scratch_gfx942': ('flat-scratch-gfx942.s', FLAT, FLATOp, 'gfx942', 'scratch_', None),
|
||||
# CDNA-specific: MFMA/MAI instructions
|
||||
'mai': ('mai-gfx942.s', VOP3P, VOP3POp, 'gfx942', None, None),
|
||||
# SDWA and DPP format tests for VOP1 (VOP2 has different bit layout, tested separately)
|
||||
'sdwa_vop1': ('gfx9_asm_vop1.s', SDWA, VOP1Op, 'gfx940', None, None),
|
||||
'dpp_vop1': ('gfx9_asm_vop1.s', DPP, VOP1Op, 'gfx940', None, None),
|
||||
}
|
||||
|
||||
class TestLLVMCDNA(unittest.TestCase):
|
||||
"""Test CDNA instruction format decode/encode roundtrip and disassembly."""
|
||||
tests: dict[str, list[tuple[str, bytes]]] = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
for name, (filename, _, _, _, mnemonic_filter, size_filter) in CDNA_TEST_FILES.items():
|
||||
try:
|
||||
data = fetch(f"{LLVM_BASE}/{filename}").read_bytes()
|
||||
cls.tests[name] = parse_llvm_tests(data.decode('utf-8', errors='ignore'), mnemonic_filter, size_filter)
|
||||
except Exception as e:
|
||||
print(f"Warning: couldn't fetch {filename}: {e}")
|
||||
cls.tests[name] = []
|
||||
|
||||
def _get_val(v): return v.val if hasattr(v, 'val') else v
|
||||
|
||||
def _filter_and_decode(tests, fmt_cls, op_enum):
|
||||
"""Filter tests and decode instructions, yielding (asm_text, data, decoded, error)."""
|
||||
fn, is_sdwa, is_dpp = fmt_cls.__name__, fmt_cls.__name__ == 'SDWA', fmt_cls.__name__ == 'DPP'
|
||||
for asm_text, data in tests:
|
||||
has_lit = False
|
||||
# SDWA/DPP format tests: only accept matching 8-byte instructions
|
||||
if is_sdwa:
|
||||
if len(data) != 8 or data[0] != 0xf9: continue
|
||||
elif is_dpp:
|
||||
if len(data) != 8 or data[0] != 0xfa: continue
|
||||
elif fmt_cls._size() == 4 and len(data) == 8:
|
||||
if data[0] in (0xf9, 0xfa): continue # Skip SDWA/DPP (tested separately)
|
||||
has_lit = data[0] == 255 or (len(data) >= 2 and data[1] == 255 and fn in ('SOP2', 'SOPC'))
|
||||
if fn == 'SOPK': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 23) & 0x1f) == 20
|
||||
if fn == 'VOP2': has_lit = has_lit or ((int.from_bytes(data[:4], 'little') >> 25) & 0x3f) in (23, 24, 36, 37)
|
||||
if not has_lit: continue
|
||||
if len(data) > fmt_cls._size() + (4 if has_lit else 0): continue
|
||||
try:
|
||||
decoded = fmt_cls.from_bytes(data)
|
||||
# For SDWA/DPP, opcode location depends on VOP1 vs VOP2
|
||||
if is_sdwa or is_dpp:
|
||||
vop2_op = _get_val(decoded._values.get('vop2_op', 0))
|
||||
op_val = _get_val(decoded._values.get('vop_op', 0)) if vop2_op == 0x3f else vop2_op
|
||||
else:
|
||||
op_val = _get_val(decoded._values.get('op', 0))
|
||||
try: op_enum(op_val)
|
||||
except ValueError: continue
|
||||
yield asm_text, data, decoded, None
|
||||
except Exception as e:
|
||||
yield asm_text, data, None, str(e)
|
||||
|
||||
def _make_roundtrip_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
|
||||
passed, failed, failures = 0, 0, []
|
||||
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
|
||||
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
|
||||
if decoded.to_bytes()[:len(data)] == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{asm_text}': orig={data.hex()} reenc={decoded.to_bytes()[:len(data)].hex()}")
|
||||
print(f"CDNA {name.upper()} roundtrip: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
def _make_disasm_test(name):
|
||||
def test(self):
|
||||
_, fmt_cls, op_enum, _, _, _ = CDNA_TEST_FILES[name]
|
||||
passed, failed, failures = 0, 0, []
|
||||
for asm_text, data, decoded, error in _filter_and_decode(self.tests.get(name, []), fmt_cls, op_enum):
|
||||
if error: failed += 1; failures.append(f"'{asm_text}': {error}"); continue
|
||||
if decoded.to_bytes()[:len(data)] != data: failed += 1; failures.append(f"'{asm_text}': roundtrip failed"); continue
|
||||
if not (disasm_text := disasm(decoded)) or not disasm_text.strip(): failed += 1; failures.append(f"'{asm_text}': empty disassembly"); continue
|
||||
passed += 1
|
||||
print(f"CDNA {name.upper()} disasm: {passed} passed, {failed} failed")
|
||||
if failures[:5]: print(" " + "\n ".join(failures[:5]))
|
||||
self.assertEqual(failed, 0)
|
||||
return test
|
||||
|
||||
for name in CDNA_TEST_FILES:
|
||||
setattr(TestLLVMCDNA, f'test_{name}_roundtrip', _make_roundtrip_test(name))
|
||||
setattr(TestLLVMCDNA, f'test_{name}_disasm', _make_disasm_test(name))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user