assembly/amd: rdna4 passing test_roundtrip (#14300)

* test_roundtrip on different archs

* failing tests

* take RDNA4 xml changes from the emu branch

* work

* min diff to disasm flat

* test_add passes, rdna4 first

* correct vgpr field for the multi dword store stuff

* amdllvm

* recompile in roundtrip, get sources from emulator

* amdllvm, 2

* clean clean

* note: don't rely on that os.environ

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
qazal
2026-01-23 07:33:53 -05:00
committed by GitHub
parent f3b0e42863
commit b913c910c5
8 changed files with 58 additions and 40 deletions

View File

@@ -247,7 +247,7 @@ def write_enum(enums, path):
with open(path, "w") as f: f.write("\n".join(lines))
def write_ins(encodings, enums, lit_only_ops, types, arch, path):
_VGPR_FIELDS = {"vdst", "vdstx", "vsrc0", "vsrc1", "vsrc2", "vsrc3", "vsrcx1", "vsrcy1", "vaddr", "vdata", "data", "data0", "data1", "addr"}
_VGPR_FIELDS = {"vdst", "vdstx", "vsrc0", "vsrc1", "vsrc2", "vsrc3", "vsrcx1", "vsrcy1", "vaddr", "vdata", "data", "data0", "data1", "addr", "vsrc"}
_VARIANT_SUFFIXES = ("_LIT", "_DPP16", "_DPP8", "_SDWA_SDST", "_SDWA", "_MFMA")
def get_base_fmt(fmt):
for sfx in _VARIANT_SUFFIXES: fmt = fmt.replace(sfx, "")
@@ -310,7 +310,10 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path):
all_ops = set(enums.get(enc_name, {}).keys())
# Exclude SDST ops from base class (they need VOP1_SDST/VOP3_SDST/VOP3B)
base_allowed = all_ops - base_lit_ops - sdst_opcodes.get(enc_name, set())
if enc_name in ("FLAT", "VFLAT"):
# RDNA3 FLAT/GLOBAL/SCRATCH share encoding bits, differentiated by seg field
# RDNA4 VFLAT/VGLOBAL/VSCRATCH have distinct encoding bits, no seg field needed
has_seg_field = any(fn == "seg" for fn, _, _ in fields)
if enc_name in ("FLAT", "VFLAT") and has_seg_field:
prefix = "V" if enc_name == "VFLAT" else ""
for cls, seg, op_enum in [(f"{prefix}FLAT", 0, f"{prefix}FLATOp"), (f"{prefix}GLOBAL", 2, f"{prefix}GLOBALOp"), (f"{prefix}SCRATCH", 1, f"{prefix}SCRATCHOp")]:
cls_ops = set(enums.get(cls, {}).keys())
@@ -320,7 +323,7 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path):
elif fn == "op": lines.append(f" op = EnumBitField({hi}, {lo}, {op_enum}, {fmt_allowed(op_enum, cls_ops)})")
else: lines.append(f" {fn} = {field_def(fn, hi, lo, cls, enc_bits)}")
lines.append("")
elif enc_name not in ("FLAT_GLOBAL", "FLAT_SCRATCH", "FLAT_GLBL", "VGLOBAL", "VSCRATCH", "DPP", "SDWA"):
elif enc_name not in ("FLAT_GLOBAL", "FLAT_SCRATCH", "FLAT_GLBL", "DPP", "SDWA"):
lines.append(f"class {enc_name}(Inst):")
for fn, hi, lo in sort_fields(fields):
if fn == "op":

View File

@@ -101,11 +101,11 @@ class VFLAT(Inst):
sve = BitField(49, 49)
scope = BitField(51, 50)
th = BitField(54, 52)
vsrc = BitField(62, 55)
vsrc = VGPRField(62, 55)
ioffset = BitField(95, 72)
class VGLOBAL(Inst):
encoding = FixedBitField(31, 24, 0b11101100)
encoding = FixedBitField(31, 24, 0b11101110)
op = EnumBitField(21, 14, VGLOBALOp, {VGLOBALOp.GLOBAL_LOAD_U8, VGLOBALOp.GLOBAL_LOAD_I8, VGLOBALOp.GLOBAL_LOAD_U16, VGLOBALOp.GLOBAL_LOAD_I16, VGLOBALOp.GLOBAL_LOAD_B32, VGLOBALOp.GLOBAL_LOAD_B64, VGLOBALOp.GLOBAL_LOAD_B96, VGLOBALOp.GLOBAL_LOAD_B128, VGLOBALOp.GLOBAL_STORE_B8, VGLOBALOp.GLOBAL_STORE_B16, VGLOBALOp.GLOBAL_STORE_B32, VGLOBALOp.GLOBAL_STORE_B64, VGLOBALOp.GLOBAL_STORE_B96, VGLOBALOp.GLOBAL_STORE_B128, VGLOBALOp.GLOBAL_LOAD_D16_U8, VGLOBALOp.GLOBAL_LOAD_D16_I8, VGLOBALOp.GLOBAL_LOAD_D16_B16, VGLOBALOp.GLOBAL_LOAD_D16_HI_U8, VGLOBALOp.GLOBAL_LOAD_D16_HI_I8, VGLOBALOp.GLOBAL_LOAD_D16_HI_B16, VGLOBALOp.GLOBAL_STORE_D16_HI_B8, VGLOBALOp.GLOBAL_STORE_D16_HI_B16, VGLOBALOp.GLOBAL_LOAD_ADDTID_B32, VGLOBALOp.GLOBAL_STORE_ADDTID_B32, VGLOBALOp.GLOBAL_INV, VGLOBALOp.GLOBAL_WB, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B32, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B32, VGLOBALOp.GLOBAL_ATOMIC_ADD_U32, VGLOBALOp.GLOBAL_ATOMIC_SUB_U32, VGLOBALOp.GLOBAL_ATOMIC_SUB_CLAMP_U32, VGLOBALOp.GLOBAL_ATOMIC_MIN_I32, VGLOBALOp.GLOBAL_ATOMIC_MIN_U32, VGLOBALOp.GLOBAL_ATOMIC_MAX_I32, VGLOBALOp.GLOBAL_ATOMIC_MAX_U32, VGLOBALOp.GLOBAL_ATOMIC_AND_B32, VGLOBALOp.GLOBAL_ATOMIC_OR_B32, VGLOBALOp.GLOBAL_ATOMIC_XOR_B32, VGLOBALOp.GLOBAL_ATOMIC_INC_U32, VGLOBALOp.GLOBAL_ATOMIC_DEC_U32, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B64, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B64, VGLOBALOp.GLOBAL_ATOMIC_ADD_U64, VGLOBALOp.GLOBAL_ATOMIC_SUB_U64, VGLOBALOp.GLOBAL_ATOMIC_MIN_I64, VGLOBALOp.GLOBAL_ATOMIC_MIN_U64, VGLOBALOp.GLOBAL_ATOMIC_MAX_I64, VGLOBALOp.GLOBAL_ATOMIC_MAX_U64, VGLOBALOp.GLOBAL_ATOMIC_AND_B64, VGLOBALOp.GLOBAL_ATOMIC_OR_B64, VGLOBALOp.GLOBAL_ATOMIC_XOR_B64, VGLOBALOp.GLOBAL_ATOMIC_INC_U64, VGLOBALOp.GLOBAL_ATOMIC_DEC_U64, VGLOBALOp.GLOBAL_WBINV, VGLOBALOp.GLOBAL_ATOMIC_COND_SUB_U32, VGLOBALOp.GLOBAL_ATOMIC_MIN_NUM_F32, VGLOBALOp.GLOBAL_ATOMIC_MAX_NUM_F32, VGLOBALOp.GLOBAL_LOAD_BLOCK, VGLOBALOp.GLOBAL_STORE_BLOCK, VGLOBALOp.GLOBAL_ATOMIC_ADD_F32, VGLOBALOp.GLOBAL_LOAD_TR_B128, VGLOBALOp.GLOBAL_LOAD_TR_B64, 
VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_F16, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_BF16, VGLOBALOp.GLOBAL_ATOMIC_ORDERED_ADD_B64})
vdst = VGPRField(39, 32)
vaddr = VGPRField(71, 64)
@@ -114,20 +114,7 @@ class VGLOBAL(Inst):
sve = BitField(49, 49)
scope = BitField(51, 50)
th = BitField(54, 52)
vsrc = BitField(62, 55)
ioffset = BitField(95, 72)
class VSCRATCH(Inst):
encoding = FixedBitField(31, 24, 0b11101100)
op = EnumBitField(21, 14, VSCRATCHOp, {VSCRATCHOp.SCRATCH_LOAD_U8, VSCRATCHOp.SCRATCH_LOAD_I8, VSCRATCHOp.SCRATCH_LOAD_U16, VSCRATCHOp.SCRATCH_LOAD_I16, VSCRATCHOp.SCRATCH_LOAD_B32, VSCRATCHOp.SCRATCH_LOAD_B64, VSCRATCHOp.SCRATCH_LOAD_B96, VSCRATCHOp.SCRATCH_LOAD_B128, VSCRATCHOp.SCRATCH_STORE_B8, VSCRATCHOp.SCRATCH_STORE_B16, VSCRATCHOp.SCRATCH_STORE_B32, VSCRATCHOp.SCRATCH_STORE_B64, VSCRATCHOp.SCRATCH_STORE_B96, VSCRATCHOp.SCRATCH_STORE_B128, VSCRATCHOp.SCRATCH_LOAD_D16_U8, VSCRATCHOp.SCRATCH_LOAD_D16_I8, VSCRATCHOp.SCRATCH_LOAD_D16_B16, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16, VSCRATCHOp.SCRATCH_LOAD_BLOCK, VSCRATCHOp.SCRATCH_STORE_BLOCK})
vdst = VGPRField(39, 32)
vaddr = VGPRField(71, 64)
saddr = SGPRField(6, 0, default=NULL)
nv = BitField(7, 7)
sve = BitField(49, 49)
scope = BitField(51, 50)
th = BitField(54, 52)
vsrc = BitField(62, 55)
vsrc = VGPRField(62, 55)
ioffset = BitField(95, 72)
class VIMAGE(Inst):
@@ -253,6 +240,19 @@ class VSAMPLE(Inst):
vaddr2 = BitField(87, 80)
vaddr3 = BitField(95, 88)
class VSCRATCH(Inst):
encoding = FixedBitField(31, 24, 0b11101101)
op = EnumBitField(21, 14, VSCRATCHOp, {VSCRATCHOp.SCRATCH_LOAD_U8, VSCRATCHOp.SCRATCH_LOAD_I8, VSCRATCHOp.SCRATCH_LOAD_U16, VSCRATCHOp.SCRATCH_LOAD_I16, VSCRATCHOp.SCRATCH_LOAD_B32, VSCRATCHOp.SCRATCH_LOAD_B64, VSCRATCHOp.SCRATCH_LOAD_B96, VSCRATCHOp.SCRATCH_LOAD_B128, VSCRATCHOp.SCRATCH_STORE_B8, VSCRATCHOp.SCRATCH_STORE_B16, VSCRATCHOp.SCRATCH_STORE_B32, VSCRATCHOp.SCRATCH_STORE_B64, VSCRATCHOp.SCRATCH_STORE_B96, VSCRATCHOp.SCRATCH_STORE_B128, VSCRATCHOp.SCRATCH_LOAD_D16_U8, VSCRATCHOp.SCRATCH_LOAD_D16_I8, VSCRATCHOp.SCRATCH_LOAD_D16_B16, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16, VSCRATCHOp.SCRATCH_LOAD_BLOCK, VSCRATCHOp.SCRATCH_STORE_BLOCK})
vdst = VGPRField(39, 32)
vaddr = VGPRField(71, 64)
saddr = SGPRField(6, 0, default=NULL)
nv = BitField(7, 7)
sve = BitField(49, 49)
scope = BitField(51, 50)
th = BitField(54, 52)
vsrc = VGPRField(62, 55)
ioffset = BitField(95, 72)
class SOP1_LIT(SOP1):
op = EnumBitField(15, 8, SOP1Op, {SOP1Op.S_MOV_B32, SOP1Op.S_MOV_B64, SOP1Op.S_CMOV_B32, SOP1Op.S_CMOV_B64, SOP1Op.S_BREV_B32, SOP1Op.S_BREV_B64, SOP1Op.S_CTZ_I32_B32, SOP1Op.S_CTZ_I32_B64, SOP1Op.S_CLZ_I32_U32, SOP1Op.S_CLZ_I32_U64, SOP1Op.S_CLS_I32, SOP1Op.S_CLS_I32_I64, SOP1Op.S_SEXT_I32_I8, SOP1Op.S_SEXT_I32_I16, SOP1Op.S_BITSET0_B32, SOP1Op.S_BITSET0_B64, SOP1Op.S_BITSET1_B32, SOP1Op.S_BITSET1_B64, SOP1Op.S_BITREPLICATE_B64_B32, SOP1Op.S_ABS_I32, SOP1Op.S_BCNT0_I32_B32, SOP1Op.S_BCNT0_I32_B64, SOP1Op.S_BCNT1_I32_B32, SOP1Op.S_BCNT1_I32_B64, SOP1Op.S_QUADMASK_B32, SOP1Op.S_QUADMASK_B64, SOP1Op.S_WQM_B32, SOP1Op.S_WQM_B64, SOP1Op.S_NOT_B32, SOP1Op.S_NOT_B64, SOP1Op.S_AND_SAVEEXEC_B32, SOP1Op.S_AND_SAVEEXEC_B64, SOP1Op.S_OR_SAVEEXEC_B32, SOP1Op.S_OR_SAVEEXEC_B64, SOP1Op.S_XOR_SAVEEXEC_B32, SOP1Op.S_XOR_SAVEEXEC_B64, SOP1Op.S_NAND_SAVEEXEC_B32, SOP1Op.S_NAND_SAVEEXEC_B64, SOP1Op.S_NOR_SAVEEXEC_B32, SOP1Op.S_NOR_SAVEEXEC_B64, SOP1Op.S_XNOR_SAVEEXEC_B32, SOP1Op.S_XNOR_SAVEEXEC_B64, SOP1Op.S_AND_NOT0_SAVEEXEC_B32, SOP1Op.S_AND_NOT0_SAVEEXEC_B64, SOP1Op.S_OR_NOT0_SAVEEXEC_B32, SOP1Op.S_OR_NOT0_SAVEEXEC_B64, SOP1Op.S_AND_NOT1_SAVEEXEC_B32, SOP1Op.S_AND_NOT1_SAVEEXEC_B64, SOP1Op.S_OR_NOT1_SAVEEXEC_B32, SOP1Op.S_OR_NOT1_SAVEEXEC_B64, SOP1Op.S_AND_NOT0_WREXEC_B32, SOP1Op.S_AND_NOT0_WREXEC_B64, SOP1Op.S_AND_NOT1_WREXEC_B32, SOP1Op.S_AND_NOT1_WREXEC_B64, SOP1Op.S_MOVRELS_B32, SOP1Op.S_MOVRELS_B64, SOP1Op.S_MOVRELD_B32, SOP1Op.S_MOVRELD_B64, SOP1Op.S_MOVRELSD_2_B32, SOP1Op.S_GETPC_B64, SOP1Op.S_SETPC_B64, SOP1Op.S_SWAPPC_B64, SOP1Op.S_RFE_B64, SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64, SOP1Op.S_BARRIER_SIGNAL, SOP1Op.S_BARRIER_SIGNAL_ISFIRST, SOP1Op.S_GET_BARRIER_STATE, SOP1Op.S_BARRIER_INIT, SOP1Op.S_BARRIER_JOIN, SOP1Op.S_ALLOC_VGPR, SOP1Op.S_SLEEP_VAR, SOP1Op.S_CEIL_F32, SOP1Op.S_FLOOR_F32, SOP1Op.S_TRUNC_F32, SOP1Op.S_RNDNE_F32, SOP1Op.S_CVT_F32_I32, SOP1Op.S_CVT_F32_U32, SOP1Op.S_CVT_I32_F32, SOP1Op.S_CVT_U32_F32, SOP1Op.S_CVT_F16_F32, SOP1Op.S_CVT_F32_F16, 
SOP1Op.S_CVT_HI_F32_F16, SOP1Op.S_CEIL_F16, SOP1Op.S_FLOOR_F16, SOP1Op.S_TRUNC_F16, SOP1Op.S_RNDNE_F16})
literal = BitField(63, 32)

View File

@@ -29,10 +29,11 @@ def _matches(data: bytes, cls: type[Inst]) -> bool:
# Import instruction classes for each architecture
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP1_LIT, VOP2, VOP2_LIT, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP,
SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH)
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT,
VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT,
SOPC as R4_SOPC, SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP,
from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP1_LIT as R4_VOP1_LIT,
VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P,
VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT,
SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT, SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT,
SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP,
SMEM as R4_SMEM, DS as R4_DS, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH)
from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP1_SDWA as C_VOP1_SDWA, VOP1_DPP16 as C_VOP1_DPP16,
VOP2 as C_VOP2, VOP2_LIT as C_VOP2_LIT, VOP2_SDWA as C_VOP2_SDWA, VOP2_DPP16 as C_VOP2_DPP16,
@@ -47,7 +48,8 @@ _FORMATS = {
"rdna3": [VOPD, VOP3P, VINTERP, VOP3SD, VOP3_SDST, VOP3, DS, GLOBAL, SCRATCH, FLAT, SMEM,
SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, VOPC, VOP1_SDST, VOP1, VOP1_LIT, VOP2, VOP2_LIT],
"rdna4": [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3SD, R4_VOP3_SDST, R4_VOP3, R4_DS, R4_GLOBAL, R4_SCRATCH, R4_FLAT, R4_SMEM,
R4_SOP1, R4_SOPC, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC, R4_VOP1_SDST, R4_VOP1, R4_SOP2, R4_SOP2_LIT, R4_VOP2, R4_VOP2_LIT],
R4_SOP1, R4_SOP1_LIT, R4_SOPC, R4_SOPC_LIT, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC, R4_VOP1_SDST, R4_VOP1, R4_VOP1_LIT,
R4_SOP2, R4_SOP2_LIT, R4_VOP2, R4_VOP2_LIT],
"cdna": [C_VOP3PX2, C_VOP3P, C_VOP3SD, C_VOP3_SDST, C_VOP3, C_DS, C_GLOBAL, C_SCRATCH, C_FLAT, C_MUBUF, C_SMEM,
C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_SOPK_LIT, C_VOPC_SDWA_SDST, C_VOPC,
C_VOP1_DPP16, C_VOP1_SDWA, C_VOP1, C_VOP2_DPP16, C_VOP2_SDWA, C_SOP2, C_VOP2, C_VOP2_LIT],

View File

@@ -84,10 +84,11 @@ from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as
VOP3SD as R4_VOP3SD, VOP3SD_LIT as R4_VOP3SD_LIT, VOP3P as R4_VOP3P, VOP3P_LIT as R4_VOP3P_LIT, VOPC as R4_VOPC, VOPC_LIT as R4_VOPC_LIT,
VOPD as R4_VOPD, VOPD_LIT as R4_VOPD_LIT, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT, SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT,
SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT, SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP, SMEM as R4_SMEM, DS as R4_DS,
VOPDOp as R4_VOPDOp, HWREG as HWREG_RDNA4)
VOPDOp as R4_VOPDOp, HWREG as HWREG_RDNA4, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH)
from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, HWREG as HWREG_CDNA
def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__
def _is_r4(inst: Inst) -> bool: return 'rdna4' in inst.__class__.__module__
# CDNA opcode name aliases for disasm (new name -> old name expected by tests)
_CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'}
@@ -303,27 +304,30 @@ def _disasm_smem(inst: SMEM) -> str:
return f"{name} {_fmt_sdst(inst.sdata, dst_n, cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc"))
def _disasm_flat(inst: FLAT) -> str:
name, cdna = inst.op_name.lower(), _is_cdna(inst)
name, cdna, r4 = inst.op_name.lower(), _is_cdna(inst), _is_r4(inst)
acc = getattr(inst, 'acc', 0)
reg_fn = _areg if acc else _vreg
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
if r4: seg = 'flat' if (cls_name:=inst.__class__.__name__) == 'VFLAT' else ('global' if cls_name == 'VGLOBAL' else 'scratch')
else: seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
# Global/scratch uses 13-bit signed offset
offset = inst.ioffset if r4 else inst.offset
if seg != 'flat':
if cdna:
# CDNA: bit 12 is sign bit but not in offset field
raw = int.from_bytes(inst.to_bytes(), 'little')
off_val = inst.offset | ((raw >> 12) & 1) << 12 # get bit 12
off_val = offset | ((raw >> 12) & 1) << 12 # get bit 12
else:
off_val = inst.offset
off_val = offset
off_val = off_val if off_val < 4096 else off_val - 8192 # sign extend 13-bit
else:
off_val = inst.offset
off_val = offset
# Use get_field_bits: data for stores/atomics, d for loads
regs = inst.canonical_op_regs
w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1)
off_s = f" offset:{off_val}" if off_val else ""
if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}"
elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}"
else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
if seg == 'flat': saddr_s = ""
elif _unwrap(inst.saddr) in (0x7F, 124): saddr_s = ", off"
@@ -332,18 +336,21 @@ def _disasm_flat(inst: FLAT) -> str:
elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}"
else: saddr_s = f", {_sreg(inst.saddr, 2) if _unwrap(inst.saddr) < 106 else decode_src(_unwrap(inst.saddr), cdna)}"
if 'addtid' in name: return f"{instr} {reg_fn(inst.data if 'store' in name else inst.vdst)}{saddr_s}{mods}"
# RDNA4: vaddr instead of addr, vsrc instead of data
addr = inst.vaddr if r4 else inst.addr
data = inst.vsrc if r4 else inst.data
# load_lds_* instructions: vaddr, saddr (no vdst, data goes to LDS)
if 'load_lds' in name:
addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w)
return f"{instr} {addr_s}{saddr_s}{mods}"
if seg == 'flat': addr_w = 2 # flat always uses 64-bit vaddr
elif cdna: addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2
else: addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w)
data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
glc_or_sc0 = inst.sc0 if cdna else inst.glc
addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w)
data_s, vdst_s = reg_fn(data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w)
if 'atomic' in name:
glc_or_sc0 = inst.sc0 if cdna else inst.glc
return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}"
if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}"
return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
@@ -582,6 +589,7 @@ DISASM_HANDLERS: dict[type, callable] = {
R4_VOP2: _disasm_vop2, R4_VOP2_LIT: _disasm_vop2, R4_VOPC: _disasm_vopc, R4_VOPC_LIT: _disasm_vopc,
R4_VOP3: _disasm_vop3, R4_VOP3_SDST: _disasm_vop3, R4_VOP3_SDST_LIT: _disasm_vop3, R4_VOP3_LIT: _disasm_vop3,
R4_VOP3SD: _disasm_vop3sd, R4_VOP3SD_LIT: _disasm_vop3sd, R4_VOP3P: _disasm_vop3p, R4_VOP3P_LIT: _disasm_vop3p,
R4_FLAT: _disasm_flat, R4_GLOBAL: _disasm_flat, R4_SCRATCH: _disasm_flat,
R4_VOPD: _disasm_vopd, R4_VOPD_LIT: _disasm_vopd, R4_VINTERP: _disasm_vinterp, R4_SOPP: _disasm_sopp, R4_SMEM: _disasm_smem, R4_DS: _disasm_ds,
R4_SOP1: _disasm_sop1, R4_SOP1_LIT: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOP2_LIT: _disasm_sop2,
R4_SOPC: _disasm_sopc, R4_SOPC_LIT: _disasm_sopc, R4_SOPK: _disasm_sopk, R4_SOPK_LIT: _disasm_sopk}

View File

@@ -360,7 +360,7 @@ class Inst:
elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val
elif name == 'src2': bits['s2'] = val
elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val
elif name in ('data', 'vdata', 'data0'): bits['data'] = val
elif name in ('data', 'vdata', 'data0', 'vsrc'): bits['data'] = val
return bits
@property
def canonical_op_regs(self) -> dict[str, int]:

View File

@@ -5,6 +5,7 @@ from dataclasses import dataclass
@dataclass
class KernelInfo:
code: bytes
src: str
global_size: tuple[int, int, int]
local_size: tuple[int, int, int]
buf_idxs: list[int] # indices into shared buffer pool

View File

@@ -324,6 +324,7 @@ def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int],
buf_sizes.append(b.nbytes)
kernels.append(KernelInfo(
code=bytes(sec.content),
src=lowered.prg.p.src,
global_size=tuple(lowered.prg.p.global_size),
local_size=tuple(lowered.prg.p.local_size),
buf_idxs=buf_idxs,

View File

@@ -83,17 +83,21 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
arch = self.arch
from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
from tinygrad.helpers import AMD_LLVM
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
compiler = HIPCompiler(get_target(arch))
# rendered source can be C or llvmir
compiler = (AMDLLVMCompiler if AMD_LLVM else HIPCompiler)(get_target(arch))
# First pass: decode all instructions and collect info
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
for ki, kernel in enumerate(kernels):
offset = 0
while offset < len(kernel.code):
remaining = kernel.code[offset:]
code = next((s.content for s in elf_loader(compiler.compile(kernel.src))[1] if s.name == ".text"))
while offset < len(code):
remaining = code[offset:]
fmt = detect_format(remaining, arch)
if fmt is None:
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
@@ -234,7 +238,6 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
# Fused ops
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
@unittest.skip("RDNA4 decode roundtrip not yet supported")
class TestTinygradKernelRoundtripRDNA4(TestTinygradKernelRoundtrip): arch = 'rdna4'
@unittest.skip("CDNA decode roundtrip not yet supported")