From b913c910c55ed75bcee81a29a43e5ac9af2c6331 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 23 Jan 2026 07:33:53 -0500 Subject: [PATCH] assembly/amd: rdna4 passing test_roundtrip (#14300) * test_roundtrip on different archs * failing tests * take RDNA4 xml changes from the emu branch * work * min diff to disasm flat * test_add passes, rdna4 first * correct vgpr field for the multi dword store stuff * amdllvm * recompile in roundtrip, get sources from emulator * amdllvm, 2 * clean clean * note, don't rely on that os.environ --------- Co-authored-by: George Hotz --- extra/assembly/amd/amdxml.py | 9 ++++-- extra/assembly/amd/autogen/rdna4/ins.py | 32 +++++++++---------- extra/assembly/amd/decode.py | 12 ++++--- extra/assembly/amd/disasm.py | 28 ++++++++++------ extra/assembly/amd/dsl.py | 2 +- extra/assembly/amd/test/helpers.py | 1 + .../amd/test/test_compare_emulators.py | 1 + extra/assembly/amd/test/test_roundtrip.py | 13 +++++--- 8 files changed, 58 insertions(+), 40 deletions(-) diff --git a/extra/assembly/amd/amdxml.py b/extra/assembly/amd/amdxml.py index c82b44bcfd..2c3468e43f 100644 --- a/extra/assembly/amd/amdxml.py +++ b/extra/assembly/amd/amdxml.py @@ -247,7 +247,7 @@ def write_enum(enums, path): with open(path, "w") as f: f.write("\n".join(lines)) def write_ins(encodings, enums, lit_only_ops, types, arch, path): - _VGPR_FIELDS = {"vdst", "vdstx", "vsrc0", "vsrc1", "vsrc2", "vsrc3", "vsrcx1", "vsrcy1", "vaddr", "vdata", "data", "data0", "data1", "addr"} + _VGPR_FIELDS = {"vdst", "vdstx", "vsrc0", "vsrc1", "vsrc2", "vsrc3", "vsrcx1", "vsrcy1", "vaddr", "vdata", "data", "data0", "data1", "addr", "vsrc"} _VARIANT_SUFFIXES = ("_LIT", "_DPP16", "_DPP8", "_SDWA_SDST", "_SDWA", "_MFMA") def get_base_fmt(fmt): for sfx in _VARIANT_SUFFIXES: fmt = fmt.replace(sfx, "") @@ -310,7 +310,10 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path): all_ops = set(enums.get(enc_name, {}).keys()) # Exclude SDST ops from base class (they need VOP1_SDST/VOP3_SDST/VOP3B) base_allowed = all_ops - base_lit_ops - sdst_opcodes.get(enc_name, set()) - if enc_name in ("FLAT", "VFLAT"): + # RDNA3 FLAT/GLOBAL/SCRATCH share encoding bits, differentiated by seg field + # RDNA4 VFLAT/VGLOBAL/VSCRATCH have distinct encoding bits, no seg field needed + has_seg_field = any(fn == "seg" for fn, _, _ in fields) + if enc_name in ("FLAT", "VFLAT") and has_seg_field: prefix = "V" if enc_name == "VFLAT" else "" for cls, seg, op_enum in [(f"{prefix}FLAT", 0, f"{prefix}FLATOp"), (f"{prefix}GLOBAL", 2, f"{prefix}GLOBALOp"), (f"{prefix}SCRATCH", 1, f"{prefix}SCRATCHOp")]: cls_ops = set(enums.get(cls, {}).keys()) @@ -320,7 +323,7 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path): elif fn == "op": lines.append(f" op = EnumBitField({hi}, {lo}, {op_enum}, {fmt_allowed(op_enum, cls_ops)})") else: lines.append(f" {fn} = {field_def(fn, hi, lo, cls, enc_bits)}") lines.append("") - elif enc_name not in ("FLAT_GLOBAL", "FLAT_SCRATCH", "FLAT_GLBL", "VGLOBAL", "VSCRATCH", "DPP", "SDWA"): + elif enc_name not in ("FLAT_GLOBAL", "FLAT_SCRATCH", "FLAT_GLBL", "DPP", "SDWA"): lines.append(f"class {enc_name}(Inst):") for fn, hi, lo in sort_fields(fields): if fn == "op": diff --git a/extra/assembly/amd/autogen/rdna4/ins.py b/extra/assembly/amd/autogen/rdna4/ins.py index ae86b6efe2..207a10ba65 100644 --- a/extra/assembly/amd/autogen/rdna4/ins.py +++ b/extra/assembly/amd/autogen/rdna4/ins.py @@ -101,11 +101,11 @@ class VFLAT(Inst): sve = BitField(49, 49) scope = BitField(51, 50) th = BitField(54, 52) - vsrc = BitField(62, 55) + vsrc = VGPRField(62, 55) ioffset = BitField(95, 72) class VGLOBAL(Inst): - encoding = FixedBitField(31, 24, 0b11101100) + encoding = FixedBitField(31, 24, 0b11101110) op = EnumBitField(21, 14, VGLOBALOp, {VGLOBALOp.GLOBAL_LOAD_U8, VGLOBALOp.GLOBAL_LOAD_I8, VGLOBALOp.GLOBAL_LOAD_U16, VGLOBALOp.GLOBAL_LOAD_I16, VGLOBALOp.GLOBAL_LOAD_B32, VGLOBALOp.GLOBAL_LOAD_B64, VGLOBALOp.GLOBAL_LOAD_B96, VGLOBALOp.GLOBAL_LOAD_B128, VGLOBALOp.GLOBAL_STORE_B8, VGLOBALOp.GLOBAL_STORE_B16, VGLOBALOp.GLOBAL_STORE_B32, VGLOBALOp.GLOBAL_STORE_B64, VGLOBALOp.GLOBAL_STORE_B96, VGLOBALOp.GLOBAL_STORE_B128, VGLOBALOp.GLOBAL_LOAD_D16_U8, VGLOBALOp.GLOBAL_LOAD_D16_I8, VGLOBALOp.GLOBAL_LOAD_D16_B16, VGLOBALOp.GLOBAL_LOAD_D16_HI_U8, VGLOBALOp.GLOBAL_LOAD_D16_HI_I8, VGLOBALOp.GLOBAL_LOAD_D16_HI_B16, VGLOBALOp.GLOBAL_STORE_D16_HI_B8, VGLOBALOp.GLOBAL_STORE_D16_HI_B16, VGLOBALOp.GLOBAL_LOAD_ADDTID_B32, VGLOBALOp.GLOBAL_STORE_ADDTID_B32, VGLOBALOp.GLOBAL_INV, VGLOBALOp.GLOBAL_WB, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B32, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B32, VGLOBALOp.GLOBAL_ATOMIC_ADD_U32, VGLOBALOp.GLOBAL_ATOMIC_SUB_U32, VGLOBALOp.GLOBAL_ATOMIC_SUB_CLAMP_U32, VGLOBALOp.GLOBAL_ATOMIC_MIN_I32, VGLOBALOp.GLOBAL_ATOMIC_MIN_U32, VGLOBALOp.GLOBAL_ATOMIC_MAX_I32, VGLOBALOp.GLOBAL_ATOMIC_MAX_U32, VGLOBALOp.GLOBAL_ATOMIC_AND_B32, VGLOBALOp.GLOBAL_ATOMIC_OR_B32, VGLOBALOp.GLOBAL_ATOMIC_XOR_B32, VGLOBALOp.GLOBAL_ATOMIC_INC_U32, VGLOBALOp.GLOBAL_ATOMIC_DEC_U32, VGLOBALOp.GLOBAL_ATOMIC_SWAP_B64, VGLOBALOp.GLOBAL_ATOMIC_CMPSWAP_B64, VGLOBALOp.GLOBAL_ATOMIC_ADD_U64, VGLOBALOp.GLOBAL_ATOMIC_SUB_U64, VGLOBALOp.GLOBAL_ATOMIC_MIN_I64, VGLOBALOp.GLOBAL_ATOMIC_MIN_U64, VGLOBALOp.GLOBAL_ATOMIC_MAX_I64, VGLOBALOp.GLOBAL_ATOMIC_MAX_U64, VGLOBALOp.GLOBAL_ATOMIC_AND_B64, VGLOBALOp.GLOBAL_ATOMIC_OR_B64, VGLOBALOp.GLOBAL_ATOMIC_XOR_B64, VGLOBALOp.GLOBAL_ATOMIC_INC_U64, VGLOBALOp.GLOBAL_ATOMIC_DEC_U64, VGLOBALOp.GLOBAL_WBINV, VGLOBALOp.GLOBAL_ATOMIC_COND_SUB_U32, VGLOBALOp.GLOBAL_ATOMIC_MIN_NUM_F32, VGLOBALOp.GLOBAL_ATOMIC_MAX_NUM_F32, VGLOBALOp.GLOBAL_LOAD_BLOCK, VGLOBALOp.GLOBAL_STORE_BLOCK, VGLOBALOp.GLOBAL_ATOMIC_ADD_F32, VGLOBALOp.GLOBAL_LOAD_TR_B128, VGLOBALOp.GLOBAL_LOAD_TR_B64, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_F16, VGLOBALOp.GLOBAL_ATOMIC_PK_ADD_BF16, VGLOBALOp.GLOBAL_ATOMIC_ORDERED_ADD_B64}) vdst = VGPRField(39, 32) vaddr = VGPRField(71, 64) @@ -114,20 +114,7 @@ class VGLOBAL(Inst): sve = BitField(49, 49) scope = BitField(51, 50) th = BitField(54, 52) - vsrc = BitField(62, 55) - ioffset = BitField(95, 72) - -class VSCRATCH(Inst): - encoding = FixedBitField(31, 24, 0b11101100) - op = EnumBitField(21, 14, VSCRATCHOp, {VSCRATCHOp.SCRATCH_LOAD_U8, VSCRATCHOp.SCRATCH_LOAD_I8, VSCRATCHOp.SCRATCH_LOAD_U16, VSCRATCHOp.SCRATCH_LOAD_I16, VSCRATCHOp.SCRATCH_LOAD_B32, VSCRATCHOp.SCRATCH_LOAD_B64, VSCRATCHOp.SCRATCH_LOAD_B96, VSCRATCHOp.SCRATCH_LOAD_B128, VSCRATCHOp.SCRATCH_STORE_B8, VSCRATCHOp.SCRATCH_STORE_B16, VSCRATCHOp.SCRATCH_STORE_B32, VSCRATCHOp.SCRATCH_STORE_B64, VSCRATCHOp.SCRATCH_STORE_B96, VSCRATCHOp.SCRATCH_STORE_B128, VSCRATCHOp.SCRATCH_LOAD_D16_U8, VSCRATCHOp.SCRATCH_LOAD_D16_I8, VSCRATCHOp.SCRATCH_LOAD_D16_B16, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16, VSCRATCHOp.SCRATCH_LOAD_BLOCK, VSCRATCHOp.SCRATCH_STORE_BLOCK}) - vdst = VGPRField(39, 32) - vaddr = VGPRField(71, 64) - saddr = SGPRField(6, 0, default=NULL) - nv = BitField(7, 7) - sve = BitField(49, 49) - scope = BitField(51, 50) - th = BitField(54, 52) - vsrc = BitField(62, 55) + vsrc = VGPRField(62, 55) ioffset = BitField(95, 72) class VIMAGE(Inst): @@ -253,6 +240,19 @@ class VSAMPLE(Inst): vaddr2 = BitField(87, 80) vaddr3 = BitField(95, 88) +class VSCRATCH(Inst): + encoding = FixedBitField(31, 24, 0b11101101) + op = EnumBitField(21, 14, VSCRATCHOp, {VSCRATCHOp.SCRATCH_LOAD_U8, VSCRATCHOp.SCRATCH_LOAD_I8, VSCRATCHOp.SCRATCH_LOAD_U16, VSCRATCHOp.SCRATCH_LOAD_I16, VSCRATCHOp.SCRATCH_LOAD_B32, VSCRATCHOp.SCRATCH_LOAD_B64, VSCRATCHOp.SCRATCH_LOAD_B96, VSCRATCHOp.SCRATCH_LOAD_B128, VSCRATCHOp.SCRATCH_STORE_B8, VSCRATCHOp.SCRATCH_STORE_B16, VSCRATCHOp.SCRATCH_STORE_B32, VSCRATCHOp.SCRATCH_STORE_B64, VSCRATCHOp.SCRATCH_STORE_B96, VSCRATCHOp.SCRATCH_STORE_B128, VSCRATCHOp.SCRATCH_LOAD_D16_U8, VSCRATCHOp.SCRATCH_LOAD_D16_I8, VSCRATCHOp.SCRATCH_LOAD_D16_B16, VSCRATCHOp.SCRATCH_LOAD_D16_HI_U8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_I8, VSCRATCHOp.SCRATCH_LOAD_D16_HI_B16, VSCRATCHOp.SCRATCH_STORE_D16_HI_B8, VSCRATCHOp.SCRATCH_STORE_D16_HI_B16, VSCRATCHOp.SCRATCH_LOAD_BLOCK, VSCRATCHOp.SCRATCH_STORE_BLOCK}) + vdst = VGPRField(39, 32) + vaddr = VGPRField(71, 64) + saddr = SGPRField(6, 0, default=NULL) + nv = BitField(7, 7) + sve = BitField(49, 49) + scope = BitField(51, 50) + th = BitField(54, 52) + vsrc = VGPRField(62, 55) + ioffset = BitField(95, 72) + class SOP1_LIT(SOP1): op = EnumBitField(15, 8, SOP1Op, {SOP1Op.S_MOV_B32, SOP1Op.S_MOV_B64, SOP1Op.S_CMOV_B32, SOP1Op.S_CMOV_B64, SOP1Op.S_BREV_B32, SOP1Op.S_BREV_B64, SOP1Op.S_CTZ_I32_B32, SOP1Op.S_CTZ_I32_B64, SOP1Op.S_CLZ_I32_U32, SOP1Op.S_CLZ_I32_U64, SOP1Op.S_CLS_I32, SOP1Op.S_CLS_I32_I64, SOP1Op.S_SEXT_I32_I8, SOP1Op.S_SEXT_I32_I16, SOP1Op.S_BITSET0_B32, SOP1Op.S_BITSET0_B64, SOP1Op.S_BITSET1_B32, SOP1Op.S_BITSET1_B64, SOP1Op.S_BITREPLICATE_B64_B32, SOP1Op.S_ABS_I32, SOP1Op.S_BCNT0_I32_B32, SOP1Op.S_BCNT0_I32_B64, SOP1Op.S_BCNT1_I32_B32, SOP1Op.S_BCNT1_I32_B64, SOP1Op.S_QUADMASK_B32, SOP1Op.S_QUADMASK_B64, SOP1Op.S_WQM_B32, SOP1Op.S_WQM_B64, SOP1Op.S_NOT_B32, SOP1Op.S_NOT_B64, SOP1Op.S_AND_SAVEEXEC_B32, SOP1Op.S_AND_SAVEEXEC_B64, SOP1Op.S_OR_SAVEEXEC_B32, SOP1Op.S_OR_SAVEEXEC_B64, SOP1Op.S_XOR_SAVEEXEC_B32, SOP1Op.S_XOR_SAVEEXEC_B64, SOP1Op.S_NAND_SAVEEXEC_B32, SOP1Op.S_NAND_SAVEEXEC_B64, SOP1Op.S_NOR_SAVEEXEC_B32, SOP1Op.S_NOR_SAVEEXEC_B64, SOP1Op.S_XNOR_SAVEEXEC_B32, SOP1Op.S_XNOR_SAVEEXEC_B64, SOP1Op.S_AND_NOT0_SAVEEXEC_B32, SOP1Op.S_AND_NOT0_SAVEEXEC_B64, SOP1Op.S_OR_NOT0_SAVEEXEC_B32, SOP1Op.S_OR_NOT0_SAVEEXEC_B64, SOP1Op.S_AND_NOT1_SAVEEXEC_B32, SOP1Op.S_AND_NOT1_SAVEEXEC_B64, SOP1Op.S_OR_NOT1_SAVEEXEC_B32, SOP1Op.S_OR_NOT1_SAVEEXEC_B64, SOP1Op.S_AND_NOT0_WREXEC_B32, SOP1Op.S_AND_NOT0_WREXEC_B64, SOP1Op.S_AND_NOT1_WREXEC_B32, SOP1Op.S_AND_NOT1_WREXEC_B64, SOP1Op.S_MOVRELS_B32, SOP1Op.S_MOVRELS_B64, SOP1Op.S_MOVRELD_B32, SOP1Op.S_MOVRELD_B64, SOP1Op.S_MOVRELSD_2_B32, SOP1Op.S_GETPC_B64, SOP1Op.S_SETPC_B64, SOP1Op.S_SWAPPC_B64, SOP1Op.S_RFE_B64, SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64, SOP1Op.S_BARRIER_SIGNAL, SOP1Op.S_BARRIER_SIGNAL_ISFIRST, SOP1Op.S_GET_BARRIER_STATE, SOP1Op.S_BARRIER_INIT, SOP1Op.S_BARRIER_JOIN, SOP1Op.S_ALLOC_VGPR, SOP1Op.S_SLEEP_VAR, SOP1Op.S_CEIL_F32, SOP1Op.S_FLOOR_F32, SOP1Op.S_TRUNC_F32, SOP1Op.S_RNDNE_F32, SOP1Op.S_CVT_F32_I32, SOP1Op.S_CVT_F32_U32, SOP1Op.S_CVT_I32_F32, SOP1Op.S_CVT_U32_F32, SOP1Op.S_CVT_F16_F32, SOP1Op.S_CVT_F32_F16, SOP1Op.S_CVT_HI_F32_F16, SOP1Op.S_CEIL_F16, SOP1Op.S_FLOOR_F16, SOP1Op.S_TRUNC_F16, SOP1Op.S_RNDNE_F16}) literal = BitField(63, 32) diff --git a/extra/assembly/amd/decode.py b/extra/assembly/amd/decode.py index 6fe5803c66..324e92e098 100644 --- a/extra/assembly/amd/decode.py +++ b/extra/assembly/amd/decode.py @@ -29,10 +29,11 @@ def _matches(data: bytes, cls: type[Inst]) -> bool: # Import instruction classes for each architecture from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP1_SDST, VOP1_LIT, VOP2, VOP2_LIT, VOP3, VOP3_SDST, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, SMEM, DS, FLAT, GLOBAL, SCRATCH) -from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT, - VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P, - VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT, - SOPC as R4_SOPC, SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP, +from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as R4_VOP1_SDST, VOP1_LIT as R4_VOP1_LIT, + VOP2 as R4_VOP2, VOP2_LIT as R4_VOP2_LIT, VOP3 as R4_VOP3, VOP3_SDST as R4_VOP3_SDST, VOP3SD as R4_VOP3SD, VOP3P as R4_VOP3P, + VOPC as R4_VOPC, VOPD as R4_VOPD, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT, + SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT, SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT, + SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP, SMEM as R4_SMEM, DS as R4_DS, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH) from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP1_SDWA as C_VOP1_SDWA, VOP1_DPP16 as C_VOP1_DPP16, VOP2 as C_VOP2, VOP2_LIT as C_VOP2_LIT, VOP2_SDWA as C_VOP2_SDWA, VOP2_DPP16 as C_VOP2_DPP16, @@ -47,7 +48,8 @@ _FORMATS = { "rdna3": [VOPD, VOP3P, VINTERP, VOP3SD, VOP3_SDST, VOP3, DS, GLOBAL, SCRATCH, FLAT, SMEM, SOP1, SOP1_LIT, SOP2, SOP2_LIT, SOPC, SOPK, SOPK_LIT, SOPP, VOPC, VOP1_SDST, VOP1, VOP1_LIT, VOP2, VOP2_LIT], "rdna4": [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3SD, R4_VOP3_SDST, R4_VOP3, R4_DS, R4_GLOBAL, R4_SCRATCH, R4_FLAT, R4_SMEM, - R4_SOP1, R4_SOPC, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC, R4_VOP1_SDST, R4_VOP1, R4_SOP2, R4_SOP2_LIT, R4_VOP2, R4_VOP2_LIT], + R4_SOP1, R4_SOP1_LIT, R4_SOPC, R4_SOPC_LIT, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC, R4_VOP1_SDST, R4_VOP1, R4_VOP1_LIT, + R4_SOP2, R4_SOP2_LIT, R4_VOP2, R4_VOP2_LIT], "cdna": [C_VOP3PX2, C_VOP3P, C_VOP3SD, C_VOP3_SDST, C_VOP3, C_DS, C_GLOBAL, C_SCRATCH, C_FLAT, C_MUBUF, C_SMEM, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_SOPK_LIT, C_VOPC_SDWA_SDST, C_VOPC, C_VOP1_DPP16, C_VOP1_SDWA, C_VOP1, C_VOP2_DPP16, C_VOP2_SDWA, C_SOP2, C_VOP2, C_VOP2_LIT], diff --git a/extra/assembly/amd/disasm.py b/extra/assembly/amd/disasm.py index 4628fb25bc..86bf449d3e 100644 --- a/extra/assembly/amd/disasm.py +++ b/extra/assembly/amd/disasm.py @@ -84,10 +84,11 @@ from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as VOP3SD as R4_VOP3SD, VOP3SD_LIT as R4_VOP3SD_LIT, VOP3P as R4_VOP3P, VOP3P_LIT as R4_VOP3P_LIT, VOPC as R4_VOPC, VOPC_LIT as R4_VOPC_LIT, VOPD as R4_VOPD, VOPD_LIT as R4_VOPD_LIT, VINTERP as R4_VINTERP, SOP1 as R4_SOP1, SOP1_LIT as R4_SOP1_LIT, SOP2 as R4_SOP2, SOP2_LIT as R4_SOP2_LIT, SOPC as R4_SOPC, SOPC_LIT as R4_SOPC_LIT, SOPK as R4_SOPK, SOPK_LIT as R4_SOPK_LIT, SOPP as R4_SOPP, SMEM as R4_SMEM, DS as R4_DS, - VOPDOp as R4_VOPDOp, HWREG as HWREG_RDNA4) + VOPDOp as R4_VOPDOp, HWREG as HWREG_RDNA4, VFLAT as R4_FLAT, VGLOBAL as R4_GLOBAL, VSCRATCH as R4_SCRATCH) from extra.assembly.amd.autogen.cdna.ins import FLAT as C_FLAT, HWREG as HWREG_CDNA def _is_cdna(inst: Inst) -> bool: return 'cdna' in inst.__class__.__module__ +def _is_r4(inst: Inst) -> bool: return 'rdna4' in inst.__class__.__module__ # CDNA opcode name aliases for disasm (new name -> old name expected by tests) _CDNA_DISASM_ALIASES = {'v_fmac_f64': 'v_mul_legacy_f32', 'v_dot2c_f32_bf16': 'v_mac_f32', 'v_fmamk_f32': 'v_madmk_f32', 'v_fmaak_f32': 'v_madak_f32'} @@ -303,27 +304,30 @@ def _disasm_smem(inst: SMEM) -> str: return f"{name} {_fmt_sdst(inst.sdata, dst_n, cdna)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (getattr(inst, 'dlc', 0), " dlc")) def _disasm_flat(inst: FLAT) -> str: - name, cdna = inst.op_name.lower(), _is_cdna(inst) + name, cdna, r4 = inst.op_name.lower(), _is_cdna(inst), _is_r4(inst) acc = getattr(inst, 'acc', 0) reg_fn = _areg if acc else _vreg - seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' + if r4: seg = 'flat' if (cls_name:=inst.__class__.__name__) == 'VFLAT' else ('global' if cls_name == 'VGLOBAL' else 'scratch') + else: seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat' instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}" # Global/scratch uses 13-bit signed offset + offset = inst.ioffset if r4 else inst.offset if seg != 'flat': if cdna: # CDNA: bit 12 is sign bit but not in offset field raw = int.from_bytes(inst.to_bytes(), 'little') - off_val = inst.offset | ((raw >> 12) & 1) << 12 # get bit 12 + off_val = offset | ((raw >> 12) & 1) << 12 # get bit 12 else: - off_val = inst.offset + off_val = offset off_val = off_val if off_val < 4096 else off_val - 8192 # sign extend 13-bit else: - off_val = inst.offset + off_val = offset # Use get_field_bits: data for stores/atomics, d for loads regs = inst.canonical_op_regs w = regs.get('data', regs.get('d', 1)) if 'store' in name or 'atomic' in name else regs.get('d', 1) off_s = f" offset:{off_val}" if off_val else "" if cdna: mods = f"{off_s}{' sc0' if inst.sc0 else ''}{' nt' if inst.nt else ''}{' sc1' if getattr(inst, 'sc1', 0) else ''}" + elif r4: mods = f"{off_s}{' scope' if inst.scope else ''}{' th' if inst.th else ''}" else: mods = f"{off_s}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}" if seg == 'flat': saddr_s = "" elif _unwrap(inst.saddr) in (0x7F, 124): saddr_s = ", off" @@ -332,18 +336,21 @@ def _disasm_flat(inst: FLAT) -> str: elif t := _ttmp(inst.saddr, 2): saddr_s = f", {t}" else: saddr_s = f", {_sreg(inst.saddr, 2) if _unwrap(inst.saddr) < 106 else decode_src(_unwrap(inst.saddr), cdna)}" if 'addtid' in name: return f"{instr} {reg_fn(inst.data if 'store' in name else inst.vdst)}{saddr_s}{mods}" + # RDNA4: vaddr instead of addr, vsrc instead of data + addr = inst.vaddr if r4 else inst.addr + data = inst.vsrc if r4 else inst.data # load_lds_* instructions: vaddr, saddr (no vdst, data goes to LDS) if 'load_lds' in name: addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2 - addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w) + addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w) return f"{instr} {addr_s}{saddr_s}{mods}" if seg == 'flat': addr_w = 2 # flat always uses 64-bit vaddr elif cdna: addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2 else: addr_w = 1 if seg == 'scratch' or (_unwrap(inst.saddr) not in (0x7F, 124)) else 2 - addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(inst.addr, addr_w) - data_s, vdst_s = reg_fn(inst.data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w) - glc_or_sc0 = inst.sc0 if cdna else inst.glc + addr_s = "off" if not inst.sve and seg == 'scratch' else _vreg(addr, addr_w) + data_s, vdst_s = reg_fn(data, w), reg_fn(inst.vdst, w // 2 if 'cmpswap' in name else w) if 'atomic' in name: + glc_or_sc0 = inst.sc0 if cdna else inst.glc return f"{instr} {vdst_s}, {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if glc_or_sc0 else f"{instr} {addr_s}, {data_s}{saddr_s if seg != 'flat' else ''}{mods}" if 'store' in name: return f"{instr} {addr_s}, {data_s}{saddr_s}{mods}" return f"{instr} {reg_fn(inst.vdst, w)}, {addr_s}{saddr_s}{mods}" @@ -582,6 +589,7 @@ DISASM_HANDLERS: dict[type, callable] = { R4_VOP2: _disasm_vop2, R4_VOP2_LIT: _disasm_vop2, R4_VOPC: _disasm_vopc, R4_VOPC_LIT: _disasm_vopc, R4_VOP3: _disasm_vop3, R4_VOP3_SDST: _disasm_vop3, R4_VOP3_SDST_LIT: _disasm_vop3, R4_VOP3_LIT: _disasm_vop3, R4_VOP3SD: _disasm_vop3sd, R4_VOP3SD_LIT: _disasm_vop3sd, R4_VOP3P: _disasm_vop3p, R4_VOP3P_LIT: _disasm_vop3p, + R4_FLAT: _disasm_flat, R4_GLOBAL: _disasm_flat, R4_SCRATCH: _disasm_flat, R4_VOPD: _disasm_vopd, R4_VOPD_LIT: _disasm_vopd, R4_VINTERP: _disasm_vinterp, R4_SOPP: _disasm_sopp, R4_SMEM: _disasm_smem, R4_DS: _disasm_ds, R4_SOP1: _disasm_sop1, R4_SOP1_LIT: _disasm_sop1, R4_SOP2: _disasm_sop2, R4_SOP2_LIT: _disasm_sop2, R4_SOPC: _disasm_sopc, R4_SOPC_LIT: _disasm_sopc, R4_SOPK: _disasm_sopk, R4_SOPK_LIT: _disasm_sopk} diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 82203a9e97..0594bd403d 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -360,7 +360,7 @@ class Inst: elif name in ('src1', 'vsrc1', 'ssrc1'): bits['s1'] = val elif name == 'src2': bits['s2'] = val elif name in ('vdst', 'sdst', 'sdata'): bits['d'] = val - elif name in ('data', 'vdata', 'data0'): bits['data'] = val + elif name in ('data', 'vdata', 'data0', 'vsrc'): bits['data'] = val return bits @property def canonical_op_regs(self) -> dict[str, int]: diff --git a/extra/assembly/amd/test/helpers.py b/extra/assembly/amd/test/helpers.py index 0f46a3f702..fa6a1d6c95 100644 --- a/extra/assembly/amd/test/helpers.py +++ b/extra/assembly/amd/test/helpers.py @@ -5,6 +5,7 @@ from dataclasses import dataclass @dataclass class KernelInfo: code: bytes + src: str global_size: tuple[int, int, int] local_size: tuple[int, int, int] buf_idxs: list[int] # indices into shared buffer pool diff --git a/extra/assembly/amd/test/test_compare_emulators.py b/extra/assembly/amd/test/test_compare_emulators.py index 8520946a9e..88f0df913f 100644 --- a/extra/assembly/amd/test/test_compare_emulators.py +++ b/extra/assembly/amd/test/test_compare_emulators.py @@ -324,6 +324,7 @@ def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], buf_sizes.append(b.nbytes) kernels.append(KernelInfo( code=bytes(sec.content), + src=lowered.prg.p.src, global_size=tuple(lowered.prg.p.global_size), local_size=tuple(lowered.prg.p.local_size), buf_idxs=buf_idxs, diff --git a/extra/assembly/amd/test/test_roundtrip.py b/extra/assembly/amd/test/test_roundtrip.py index 14063e49c2..ad20624404 100644 --- a/extra/assembly/amd/test/test_roundtrip.py +++ b/extra/assembly/amd/test/test_roundtrip.py @@ -83,17 +83,21 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): arch = self.arch from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad - from tinygrad.runtime.support.compiler_amd import HIPCompiler + from tinygrad.runtime.support.elf import elf_loader + from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler + from tinygrad.helpers import AMD_LLVM kernels, _, _ = get_kernels_from_tinygrad(op_fn) - compiler = HIPCompiler(get_target(arch)) + # rendered source can be C or llvmir + compiler = (AMDLLVMCompiler if AMD_LLVM else HIPCompiler)(get_target(arch)) # First pass: decode all instructions and collect info decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) for ki, kernel in enumerate(kernels): offset = 0 - while offset < len(kernel.code): - remaining = kernel.code[offset:] + code = next((s.content for s in elf_loader(compiler.compile(kernel.src))[1] if s.name == ".text")) + while offset < len(code): + remaining = code[offset:] fmt = detect_format(remaining, arch) if fmt is None: decoded_instrs.append((ki, offset, None, None, None, False, "no format")) @@ -234,7 +238,6 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): # Fused ops def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0]))) -@unittest.skip("RDNA4 decode roundtrip not yet supported") class TestTinygradKernelRoundtripRDNA4(TestTinygradKernelRoundtrip): arch = 'rdna4' @unittest.skip("CDNA decode roundtrip not yet supported")