From 92bfe92138f9e315c1aa300e0508bdd69fe54904 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 26 Jan 2026 03:51:26 -0500 Subject: [PATCH] assembly/amd: fix cdna mfma xml (#14329) * handwritten failing test * new amdxml * more mfma from fixes * ci * move arch of test integration * alt * amdxml human cleanup * _TestIntegration rename to IntegrationTestBase * it's the same problem as _LIT * better comment * better variable name --- extra/assembly/amd/amdxml.py | 42 +++--- extra/assembly/amd/autogen/cdna/ins.py | 135 ++++++++++---------- extra/assembly/amd/decode.py | 4 +- extra/assembly/amd/disasm.py | 4 +- extra/assembly/amd/test/test_handwritten.py | 19 ++- 5 files changed, 112 insertions(+), 92 deletions(-) diff --git a/extra/assembly/amd/amdxml.py b/extra/assembly/amd/amdxml.py index f87b29b1b3..1f0298486e 100644 --- a/extra/assembly/amd/amdxml.py +++ b/extra/assembly/amd/amdxml.py @@ -126,13 +126,13 @@ def parse_xml(filename: str): if fmt and fmt not in fmts: fmts[fmt] = 0 if otype: op_types_set.add(otype) if op_info: types[(name, base_enum)] = op_info - # Find opcodes that only exist in _LIT encoding (no base format version) - lit_only_ops: dict[str, set[int]] = {} + # Find opcodes that only exist in a specific variant encoding (no base format version) + suffix_only_ops: dict[str, dict[str, set[int]]] = {} # {suffix: {base_fmt: {opcodes}}} for base_fmt, opcodes in opcode_encs.items(): for opcode, encs in opcodes.items(): - if all("_LIT" in e for e in encs): - lit_only_ops.setdefault(base_fmt, set()).add(opcode) - return encodings, enums, types, fmts, op_types_set, lit_only_ops + suffix = next((s for s in _ENC_SUFFIX_MAP.values() if all(s in e for e in encs)), None) + if suffix is not None: suffix_only_ops.setdefault(suffix, {}).setdefault(base_fmt, set()).add(opcode) + return encodings, enums, types, fmts, op_types_set, suffix_only_ops # ═══════════════════════════════════════════════════════════════════════════════ # PDF parsing @@ -252,7 +252,7 @@ def write_enum(enums, path): lines.append("") with open(path, "w") as f: f.write("\n".join(lines)) -def write_ins(encodings, enums, lit_only_ops, types, arch, path): +def write_ins(encodings, enums, suffix_only_ops, types, arch, path): _VGPR_FIELDS = {"vdst", "vdstx", "vsrc0", "vsrc1", "vsrc2", "vsrc3", "vsrcx1", "vsrcy1", "vaddr", "vdata", "data", "data0", "data1", "addr", "vsrc"} _VARIANT_SUFFIXES = ("_LIT", "_DPP16", "_DPP8", "_SDWA_SDST", "_SDWA", "_MFMA") def get_base_fmt(fmt): @@ -313,11 +313,11 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path): # Generate base classes first for enc_name, (fields, enc_bits) in sorted(base_encodings.items()): - # Get lit-only ops for this format (these can't be used in base class) - base_lit_ops = lit_only_ops.get(enc_name, set()) all_ops = set(enums.get(enc_name, {}).keys()) + # Get suffix-only ops for this format (these can't be used in base class) + base_suffix_ops = set().union(*(d.get(enc_name, set()) for d in suffix_only_ops.values())) # Exclude SDST ops from base class (they need VOP1_SDST/VOP3_SDST/VOP3B) - base_allowed = all_ops - base_lit_ops - sdst_opcodes.get(enc_name, set()) + base_allowed = all_ops - base_suffix_ops - sdst_opcodes.get(enc_name, set()) # RDNA3 FLAT/GLOBAL/SCRATCH share encoding bits, differentiated by seg field # RDNA4 VFLAT/VGLOBAL/VSCRATCH have distinct encoding bits, no seg field needed has_seg_field = any(fn == "seg" for fn, _, _ in fields) @@ -347,15 +347,19 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path): if base not in base_encodings: continue # skip if no base class base_fields = {f[0] for f in base_encodings[base][0]} extra_fields = [(fn, hi, lo) for fn, hi, lo in fields if fn not in base_fields] - is_lit = enc_name.endswith("_LIT") + # Check if this is a suffix-only variant + variant_suffix = next((sfx for sfx in _VARIANT_SUFFIXES if enc_name.endswith(sfx)), None) + is_suffix_variant = variant_suffix in suffix_only_ops all_ops = set(enums.get(base, {}).keys()) - if extra_fields or is_lit: + if extra_fields or is_suffix_variant: lines.append(f"class {enc_name}({base}):") op_field = next((f for f in base_encodings[base][0] if f[0] == "op"), None) # _LIT classes: override op to allow all opcodes (base excludes lit-only ops) - if op_field and is_lit: + # other classes override op to only suffix-only opcodes + if op_field and is_suffix_variant: _, hi, lo = op_field - lines.append(f" op = EnumBitField({hi}, {lo}, {base}Op, {fmt_allowed(f'{base}Op', all_ops)})") + allowed_ops = all_ops if variant_suffix == "_LIT" else suffix_only_ops[variant_suffix][base] + lines.append(f" op = EnumBitField({hi}, {lo}, {base}Op, {fmt_allowed(f'{base}Op', allowed_ops)})") for fn, hi, lo in sort_fields(extra_fields): lines.append(f" {fn} = {field_def(fn, hi, lo, enc_name)}") lines.append("") @@ -389,14 +393,14 @@ def write_ins(encodings, enums, lit_only_ops, types, arch, path): for fmt, ops in sorted(enums.items()): if fmt not in base_encodings and fmt not in ("GLOBAL", "SCRATCH", "VGLOBAL", "VSCRATCH"): continue suffix = "_E32" if fmt in ("VOP1", "VOP2", "VOPC") else "_E64" if fmt == "VOP3" else "" - lit_ops = lit_only_ops.get(fmt, set()) + op_to_suffix = {op:suffix for suffix,ops in suffix_only_ops.items() for op in ops.get(fmt, set())} fmt_sdst_ops = sdst_opcodes.get(fmt, set()) for op, name in sorted(ops.items()): msuf = suffix if fmt != "VOP3" or op < 512 else "" - # Determine class: SDST variants, LIT-only instructions, or base + # Determine class: SDST variants, suffix-specific variants (e.g., _MFMA, _LIT), or base if fmt == "VOP1" and op in fmt_sdst_ops: cls = "VOP1_SDST" elif fmt == "VOP3" and (op in fmt_sdst_ops or op < 256): cls = "VOP3_SDST" - elif op in lit_ops: cls = f"{fmt}_LIT" + elif op_to_suffix.get(op): cls = f"{fmt}{op_to_suffix[op]}" else: cls = fmt lines.append(f"{name.lower()}{msuf.lower()} = functools.partial({cls}, {fmt}Op.{name}{msuf})") with open(path, "w") as f: f.write("\n".join(lines)) @@ -445,11 +449,11 @@ if __name__ == "__main__": # First pass: parse XML for all architectures for arch, cfg in ARCHS.items(): print(f"Parsing XML: {cfg['xml']} -> {arch}") - encodings, enums, types, fmts, op_types_set, lit_only_ops = parse_xml(cfg["xml"]) + encodings, enums, types, fmts, op_types_set, suffix_only_ops = parse_xml(cfg["xml"]) for fmt, ops in FIXES.get(arch, {}).items(): enums.setdefault(fmt, {}).update(ops) for fmt, fields in FIELD_FIXES.get(arch, {}).items(): if fmt in encodings: encodings[fmt] = (encodings[fmt][0] + fields, encodings[fmt][1]) - arch_data[arch] = {"encodings": encodings, "enums": enums, "types": types, "lit_only_ops": lit_only_ops} + arch_data[arch] = {"encodings": encodings, "enums": enums, "types": types, "suffix_only_ops": suffix_only_ops} for fmt, bits in fmts.items(): assert fmt not in all_fmts or all_fmts[fmt] == bits, f"FMT_BITS mismatch for {fmt}: {all_fmts[fmt]} vs {bits}" all_fmts[fmt] = bits @@ -462,7 +466,7 @@ if __name__ == "__main__": for arch, data in arch_data.items(): base = pathlib.Path(__file__).parent / "autogen" / arch write_enum(data["enums"], base / "enum.py") - write_ins(data["encodings"], data["enums"], data["lit_only_ops"], data["types"], arch, base / "ins.py") + write_ins(data["encodings"], data["enums"], data["suffix_only_ops"], data["types"], arch, base / "ins.py") write_operands(data["types"], data["enums"], arch, base / "operands.py") print(f" {arch}: {len(data['encodings'])} encodings, {sum(len(v) for v in data['enums'].values())} instructions") # Second pass: parse PDFs and write pcode diff --git a/extra/assembly/amd/autogen/cdna/ins.py b/extra/assembly/amd/autogen/cdna/ins.py index 171b9e270c..b2eda3e52a 100644 --- a/extra/assembly/amd/autogen/cdna/ins.py +++ b/extra/assembly/amd/autogen/cdna/ins.py @@ -164,7 +164,7 @@ class VOP3(Inst): class VOP3P(Inst): encoding = FixedBitField(31, 23, 0b110100111) - op = EnumBitField(22, 16, VOP3POp, {VOP3POp.V_PK_MAD_I16, VOP3POp.V_PK_MUL_LO_U16, VOP3POp.V_PK_ADD_I16, VOP3POp.V_PK_SUB_I16, VOP3POp.V_PK_LSHLREV_B16, VOP3POp.V_PK_LSHRREV_B16, VOP3POp.V_PK_ASHRREV_I16, VOP3POp.V_PK_MAX_I16, VOP3POp.V_PK_MIN_I16, VOP3POp.V_PK_MAD_U16, VOP3POp.V_PK_ADD_U16, VOP3POp.V_PK_SUB_U16, VOP3POp.V_PK_MAX_U16, VOP3POp.V_PK_MIN_U16, VOP3POp.V_PK_FMA_F16, VOP3POp.V_PK_ADD_F16, VOP3POp.V_PK_MUL_F16, VOP3POp.V_PK_MIN_F16, VOP3POp.V_PK_MAX_F16, VOP3POp.V_DOT2_F32_BF16, VOP3POp.V_PK_MINIMUM3_F16, VOP3POp.V_PK_MAXIMUM3_F16, VOP3POp.V_MAD_MIX_F32, VOP3POp.V_MAD_MIXLO_F16, VOP3POp.V_MAD_MIXHI_F16, VOP3POp.V_DOT2_F32_F16, VOP3POp.V_DOT2_I32_I16, VOP3POp.V_DOT2_U32_U16, VOP3POp.V_DOT4_I32_I8, VOP3POp.V_DOT4_U32_U8, VOP3POp.V_DOT8_I32_I4, VOP3POp.V_DOT8_U32_U4, VOP3POp.V_MFMA_LD_SCALE_B32, VOP3POp.V_MFMA_F32_16X16X128_F8F6F4, VOP3POp.V_MFMA_F32_32X32X64_F8F6F4, VOP3POp.V_PK_FMA_F32, VOP3POp.V_PK_MUL_F32, VOP3POp.V_PK_ADD_F32, VOP3POp.V_PK_MOV_B32, VOP3POp.V_MFMA_F32_16X16X32_BF16, VOP3POp.V_MFMA_I32_16X16X64_I8, VOP3POp.V_MFMA_F32_32X32X16_BF16, VOP3POp.V_MFMA_I32_32X32X32_I8, VOP3POp.V_SMFMAC_F32_16X16X64_BF16, VOP3POp.V_SMFMAC_I32_16X16X128_I8, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_BF8, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_FP8, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_BF8, VOP3POp.V_MFMA_F32_16X16X8_XF32, VOP3POp.V_MFMA_F32_32X32X4_XF32, VOP3POp.V_MFMA_F32_32X32X1_2B_F32, VOP3POp.V_MFMA_F32_16X16X1_4B_F32, VOP3POp.V_MFMA_F32_4X4X1_16B_F32, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_FP8, VOP3POp.V_MFMA_F32_32X32X2_F32, VOP3POp.V_MFMA_F32_16X16X4_F32, VOP3POp.V_SMFMAC_F32_32X32X32_BF16, VOP3POp.V_SMFMAC_I32_32X32X64_I8, VOP3POp.V_MFMA_F32_32X32X4_2B_F16, VOP3POp.V_MFMA_F32_16X16X4_4B_F16, VOP3POp.V_MFMA_F32_4X4X4_16B_F16, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_BF8, VOP3POp.V_MFMA_F32_32X32X8_F16, VOP3POp.V_MFMA_F32_16X16X16_F16, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_FP8, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_BF8, VOP3POp.V_MFMA_I32_32X32X4_2B_I8, VOP3POp.V_MFMA_I32_16X16X4_4B_I8, VOP3POp.V_MFMA_I32_4X4X4_16B_I8, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_FP8, VOP3POp.V_MFMA_F32_16X16X32_F16, VOP3POp.V_MFMA_F32_32X32X16_F16, VOP3POp.V_MFMA_I32_32X32X16_I8, VOP3POp.V_MFMA_I32_16X16X32_I8, VOP3POp.V_ACCVGPR_READ, VOP3POp.V_ACCVGPR_WRITE, VOP3POp.V_SMFMAC_F32_16X16X64_F16, VOP3POp.V_SMFMAC_F32_32X32X32_F16, VOP3POp.V_MFMA_F32_32X32X4_2B_BF16, VOP3POp.V_MFMA_F32_16X16X4_4B_BF16, VOP3POp.V_MFMA_F32_4X4X4_16B_BF16, VOP3POp.V_MFMA_F32_32X32X8_BF16, VOP3POp.V_MFMA_F32_16X16X16_BF16, VOP3POp.V_SMFMAC_F32_16X16X32_F16, VOP3POp.V_SMFMAC_F32_32X32X16_F16, VOP3POp.V_SMFMAC_F32_16X16X32_BF16, VOP3POp.V_SMFMAC_F32_32X32X16_BF16, VOP3POp.V_SMFMAC_I32_16X16X64_I8, VOP3POp.V_SMFMAC_I32_32X32X32_I8, VOP3POp.V_MFMA_F64_16X16X4_F64, VOP3POp.V_MFMA_F64_4X4X4_4B_F64, VOP3POp.V_MFMA_F32_16X16X32_BF8_BF8, VOP3POp.V_MFMA_F32_16X16X32_BF8_FP8, VOP3POp.V_MFMA_F32_16X16X32_FP8_BF8, VOP3POp.V_MFMA_F32_16X16X32_FP8_FP8, VOP3POp.V_MFMA_F32_32X32X16_BF8_BF8, VOP3POp.V_MFMA_F32_32X32X16_BF8_FP8, VOP3POp.V_MFMA_F32_32X32X16_FP8_BF8, VOP3POp.V_MFMA_F32_32X32X16_FP8_FP8, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_BF8, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_FP8, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_BF8, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_FP8, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_BF8, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_FP8, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_BF8, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_FP8}) + op = EnumBitField(22, 16, VOP3POp, {VOP3POp.V_PK_MAD_I16, VOP3POp.V_PK_MUL_LO_U16, VOP3POp.V_PK_ADD_I16, VOP3POp.V_PK_SUB_I16, VOP3POp.V_PK_LSHLREV_B16, VOP3POp.V_PK_LSHRREV_B16, VOP3POp.V_PK_ASHRREV_I16, VOP3POp.V_PK_MAX_I16, VOP3POp.V_PK_MIN_I16, VOP3POp.V_PK_MAD_U16, VOP3POp.V_PK_ADD_U16, VOP3POp.V_PK_SUB_U16, VOP3POp.V_PK_MAX_U16, VOP3POp.V_PK_MIN_U16, VOP3POp.V_PK_FMA_F16, VOP3POp.V_PK_ADD_F16, VOP3POp.V_PK_MUL_F16, VOP3POp.V_PK_MIN_F16, VOP3POp.V_PK_MAX_F16, VOP3POp.V_DOT2_F32_BF16, VOP3POp.V_PK_MINIMUM3_F16, VOP3POp.V_PK_MAXIMUM3_F16, VOP3POp.V_MAD_MIX_F32, VOP3POp.V_MAD_MIXLO_F16, VOP3POp.V_MAD_MIXHI_F16, VOP3POp.V_DOT2_F32_F16, VOP3POp.V_DOT2_I32_I16, VOP3POp.V_DOT2_U32_U16, VOP3POp.V_DOT4_I32_I8, VOP3POp.V_DOT4_U32_U8, VOP3POp.V_DOT8_I32_I4, VOP3POp.V_DOT8_U32_U4, VOP3POp.V_MFMA_LD_SCALE_B32, VOP3POp.V_PK_FMA_F32, VOP3POp.V_PK_MUL_F32, VOP3POp.V_PK_ADD_F32, VOP3POp.V_PK_MOV_B32, VOP3POp.V_MFMA_F32_16X16X8_XF32, VOP3POp.V_MFMA_F32_32X32X4_XF32, VOP3POp.V_ACCVGPR_READ, VOP3POp.V_ACCVGPR_WRITE}) vdst = VGPRField(7, 0) src0 = SrcField(40, 32) src1 = SrcField(49, 41) @@ -310,6 +310,7 @@ class VOP2_SDWA_SDST(VOP2): s1 = BitField(63, 63) class VOP3P_MFMA(VOP3P): + op = EnumBitField(22, 16, VOP3POp, {VOP3POp.V_MFMA_F32_16X16X128_F8F6F4, VOP3POp.V_MFMA_F32_32X32X64_F8F6F4, VOP3POp.V_MFMA_F32_16X16X32_BF16, VOP3POp.V_MFMA_I32_16X16X64_I8, VOP3POp.V_MFMA_F32_32X32X16_BF16, VOP3POp.V_MFMA_I32_32X32X32_I8, VOP3POp.V_SMFMAC_F32_16X16X64_BF16, VOP3POp.V_SMFMAC_I32_16X16X128_I8, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_BF8, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_FP8, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_BF8, VOP3POp.V_MFMA_F32_32X32X1_2B_F32, VOP3POp.V_MFMA_F32_16X16X1_4B_F32, VOP3POp.V_MFMA_F32_4X4X1_16B_F32, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_FP8, VOP3POp.V_MFMA_F32_32X32X2_F32, VOP3POp.V_MFMA_F32_16X16X4_F32, VOP3POp.V_SMFMAC_F32_32X32X32_BF16, VOP3POp.V_SMFMAC_I32_32X32X64_I8, VOP3POp.V_MFMA_F32_32X32X4_2B_F16, VOP3POp.V_MFMA_F32_16X16X4_4B_F16, VOP3POp.V_MFMA_F32_4X4X4_16B_F16, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_BF8, VOP3POp.V_MFMA_F32_32X32X8_F16, VOP3POp.V_MFMA_F32_16X16X16_F16, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_FP8, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_BF8, VOP3POp.V_MFMA_I32_32X32X4_2B_I8, VOP3POp.V_MFMA_I32_16X16X4_4B_I8, VOP3POp.V_MFMA_I32_4X4X4_16B_I8, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_FP8, VOP3POp.V_MFMA_F32_16X16X32_F16, VOP3POp.V_MFMA_F32_32X32X16_F16, VOP3POp.V_MFMA_I32_32X32X16_I8, VOP3POp.V_MFMA_I32_16X16X32_I8, VOP3POp.V_SMFMAC_F32_16X16X64_F16, VOP3POp.V_SMFMAC_F32_32X32X32_F16, VOP3POp.V_MFMA_F32_32X32X4_2B_BF16, VOP3POp.V_MFMA_F32_16X16X4_4B_BF16, VOP3POp.V_MFMA_F32_4X4X4_16B_BF16, VOP3POp.V_MFMA_F32_32X32X8_BF16, VOP3POp.V_MFMA_F32_16X16X16_BF16, VOP3POp.V_SMFMAC_F32_16X16X32_F16, VOP3POp.V_SMFMAC_F32_32X32X16_F16, VOP3POp.V_SMFMAC_F32_16X16X32_BF16, VOP3POp.V_SMFMAC_F32_32X32X16_BF16, VOP3POp.V_SMFMAC_I32_16X16X64_I8, VOP3POp.V_SMFMAC_I32_32X32X32_I8, VOP3POp.V_MFMA_F64_16X16X4_F64, VOP3POp.V_MFMA_F64_4X4X4_4B_F64, VOP3POp.V_MFMA_F32_16X16X32_BF8_BF8, VOP3POp.V_MFMA_F32_16X16X32_BF8_FP8, VOP3POp.V_MFMA_F32_16X16X32_FP8_BF8, VOP3POp.V_MFMA_F32_16X16X32_FP8_FP8, VOP3POp.V_MFMA_F32_32X32X16_BF8_BF8, VOP3POp.V_MFMA_F32_32X32X16_BF8_FP8, VOP3POp.V_MFMA_F32_32X32X16_FP8_BF8, VOP3POp.V_MFMA_F32_32X32X16_FP8_FP8, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_BF8, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_FP8, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_BF8, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_FP8, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_BF8, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_FP8, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_BF8, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_FP8}) cbsz = BitField(10, 8) abid = BitField(14, 11) acc_cd = BitField(15, 15) @@ -1648,80 +1649,80 @@ v_dot4_u32_u8 = functools.partial(VOP3P, VOP3POp.V_DOT4_U32_U8) v_dot8_i32_i4 = functools.partial(VOP3P, VOP3POp.V_DOT8_I32_I4) v_dot8_u32_u4 = functools.partial(VOP3P, VOP3POp.V_DOT8_U32_U4) v_mfma_ld_scale_b32 = functools.partial(VOP3P, VOP3POp.V_MFMA_LD_SCALE_B32) -v_mfma_f32_16x16x128_f8f6f4 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X128_F8F6F4) -v_mfma_f32_32x32x64_f8f6f4 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X64_F8F6F4) +v_mfma_f32_16x16x128_f8f6f4 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X128_F8F6F4) +v_mfma_f32_32x32x64_f8f6f4 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X64_F8F6F4) v_pk_fma_f32 = functools.partial(VOP3P, VOP3POp.V_PK_FMA_F32) v_pk_mul_f32 = functools.partial(VOP3P, VOP3POp.V_PK_MUL_F32) v_pk_add_f32 = functools.partial(VOP3P, VOP3POp.V_PK_ADD_F32) v_pk_mov_b32 = functools.partial(VOP3P, VOP3POp.V_PK_MOV_B32) -v_mfma_f32_16x16x32_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X32_BF16) -v_mfma_i32_16x16x64_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_16X16X64_I8) -v_mfma_f32_32x32x16_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X16_BF16) -v_mfma_i32_32x32x32_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_32X32X32_I8) -v_smfmac_f32_16x16x64_bf16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X64_BF16) -v_smfmac_i32_16x16x128_i8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_I32_16X16X128_I8) -v_smfmac_f32_16x16x128_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_BF8) -v_smfmac_f32_16x16x128_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_FP8) -v_smfmac_f32_16x16x128_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_BF8) +v_mfma_f32_16x16x32_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X32_BF16) +v_mfma_i32_16x16x64_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_16X16X64_I8) +v_mfma_f32_32x32x16_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X16_BF16) +v_mfma_i32_32x32x32_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_32X32X32_I8) +v_smfmac_f32_16x16x64_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X64_BF16) +v_smfmac_i32_16x16x128_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_I32_16X16X128_I8) +v_smfmac_f32_16x16x128_bf8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_BF8) +v_smfmac_f32_16x16x128_bf8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X128_BF8_FP8) +v_smfmac_f32_16x16x128_fp8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_BF8) v_mfma_f32_16x16x8_xf32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X8_XF32) v_mfma_f32_32x32x4_xf32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X4_XF32) -v_mfma_f32_32x32x1_2b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X1_2B_F32) -v_mfma_f32_16x16x1_4b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X1_4B_F32) -v_mfma_f32_4x4x1_16b_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_4X4X1_16B_F32) -v_smfmac_f32_16x16x128_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_FP8) -v_mfma_f32_32x32x2_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X2_F32) -v_mfma_f32_16x16x4_f32 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X4_F32) -v_smfmac_f32_32x32x32_bf16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_BF16) -v_smfmac_i32_32x32x64_i8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_I32_32X32X64_I8) -v_mfma_f32_32x32x4_2b_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X4_2B_F16) -v_mfma_f32_16x16x4_4b_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X4_4B_F16) -v_mfma_f32_4x4x4_16b_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_4X4X4_16B_F16) -v_smfmac_f32_32x32x64_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_BF8) -v_mfma_f32_32x32x8_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X8_F16) -v_mfma_f32_16x16x16_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X16_F16) -v_smfmac_f32_32x32x64_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_FP8) -v_smfmac_f32_32x32x64_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_BF8) -v_mfma_i32_32x32x4_2b_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_32X32X4_2B_I8) -v_mfma_i32_16x16x4_4b_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_16X16X4_4B_I8) -v_mfma_i32_4x4x4_16b_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_4X4X4_16B_I8) -v_smfmac_f32_32x32x64_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_FP8) -v_mfma_f32_16x16x32_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X32_F16) -v_mfma_f32_32x32x16_f16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X16_F16) -v_mfma_i32_32x32x16_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_32X32X16_I8) -v_mfma_i32_16x16x32_i8 = functools.partial(VOP3P, VOP3POp.V_MFMA_I32_16X16X32_I8) +v_mfma_f32_32x32x1_2b_f32 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X1_2B_F32) +v_mfma_f32_16x16x1_4b_f32 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X1_4B_F32) +v_mfma_f32_4x4x1_16b_f32 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_4X4X1_16B_F32) +v_smfmac_f32_16x16x128_fp8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X128_FP8_FP8) +v_mfma_f32_32x32x2_f32 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X2_F32) +v_mfma_f32_16x16x4_f32 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X4_F32) +v_smfmac_f32_32x32x32_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X32_BF16) +v_smfmac_i32_32x32x64_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_I32_32X32X64_I8) +v_mfma_f32_32x32x4_2b_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X4_2B_F16) +v_mfma_f32_16x16x4_4b_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X4_4B_F16) +v_mfma_f32_4x4x4_16b_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_4X4X4_16B_F16) +v_smfmac_f32_32x32x64_bf8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_BF8) +v_mfma_f32_32x32x8_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X8_F16) +v_mfma_f32_16x16x16_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X16_F16) +v_smfmac_f32_32x32x64_bf8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X64_BF8_FP8) +v_smfmac_f32_32x32x64_fp8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_BF8) +v_mfma_i32_32x32x4_2b_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_32X32X4_2B_I8) +v_mfma_i32_16x16x4_4b_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_16X16X4_4B_I8) +v_mfma_i32_4x4x4_16b_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_4X4X4_16B_I8) +v_smfmac_f32_32x32x64_fp8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X64_FP8_FP8) +v_mfma_f32_16x16x32_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X32_F16) +v_mfma_f32_32x32x16_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X16_F16) +v_mfma_i32_32x32x16_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_32X32X16_I8) +v_mfma_i32_16x16x32_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_I32_16X16X32_I8) v_accvgpr_read = functools.partial(VOP3P, VOP3POp.V_ACCVGPR_READ) v_accvgpr_write = functools.partial(VOP3P, VOP3POp.V_ACCVGPR_WRITE) -v_smfmac_f32_16x16x64_f16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X64_F16) -v_smfmac_f32_32x32x32_f16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_F16) -v_mfma_f32_32x32x4_2b_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X4_2B_BF16) -v_mfma_f32_16x16x4_4b_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X4_4B_BF16) -v_mfma_f32_4x4x4_16b_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_4X4X4_16B_BF16) -v_mfma_f32_32x32x8_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X8_BF16) -v_mfma_f32_16x16x16_bf16 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X16_BF16) -v_smfmac_f32_16x16x32_f16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X32_F16) -v_smfmac_f32_32x32x16_f16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X16_F16) -v_smfmac_f32_16x16x32_bf16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X32_BF16) -v_smfmac_f32_32x32x16_bf16 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X16_BF16) -v_smfmac_i32_16x16x64_i8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_I32_16X16X64_I8) -v_smfmac_i32_32x32x32_i8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_I32_32X32X32_I8) -v_mfma_f64_16x16x4_f64 = functools.partial(VOP3P, VOP3POp.V_MFMA_F64_16X16X4_F64) -v_mfma_f64_4x4x4_4b_f64 = functools.partial(VOP3P, VOP3POp.V_MFMA_F64_4X4X4_4B_F64) -v_mfma_f32_16x16x32_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X32_BF8_BF8) -v_mfma_f32_16x16x32_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X32_BF8_FP8) -v_mfma_f32_16x16x32_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X32_FP8_BF8) -v_mfma_f32_16x16x32_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_16X16X32_FP8_FP8) -v_mfma_f32_32x32x16_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X16_BF8_BF8) -v_mfma_f32_32x32x16_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X16_BF8_FP8) -v_mfma_f32_32x32x16_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X16_FP8_BF8) -v_mfma_f32_32x32x16_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_MFMA_F32_32X32X16_FP8_FP8) -v_smfmac_f32_16x16x64_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_BF8) -v_smfmac_f32_16x16x64_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_FP8) -v_smfmac_f32_16x16x64_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_BF8) -v_smfmac_f32_16x16x64_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_FP8) -v_smfmac_f32_32x32x32_bf8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_BF8) -v_smfmac_f32_32x32x32_bf8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_FP8) -v_smfmac_f32_32x32x32_fp8_bf8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_BF8) -v_smfmac_f32_32x32x32_fp8_fp8 = functools.partial(VOP3P, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_FP8) +v_smfmac_f32_16x16x64_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X64_F16) +v_smfmac_f32_32x32x32_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X32_F16) +v_mfma_f32_32x32x4_2b_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X4_2B_BF16) +v_mfma_f32_16x16x4_4b_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X4_4B_BF16) +v_mfma_f32_4x4x4_16b_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_4X4X4_16B_BF16) +v_mfma_f32_32x32x8_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X8_BF16) +v_mfma_f32_16x16x16_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X16_BF16) +v_smfmac_f32_16x16x32_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X32_F16) +v_smfmac_f32_32x32x16_f16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X16_F16) +v_smfmac_f32_16x16x32_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X32_BF16) +v_smfmac_f32_32x32x16_bf16 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X16_BF16) +v_smfmac_i32_16x16x64_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_I32_16X16X64_I8) +v_smfmac_i32_32x32x32_i8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_I32_32X32X32_I8) +v_mfma_f64_16x16x4_f64 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F64_16X16X4_F64) +v_mfma_f64_4x4x4_4b_f64 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F64_4X4X4_4B_F64) +v_mfma_f32_16x16x32_bf8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X32_BF8_BF8) +v_mfma_f32_16x16x32_bf8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X32_BF8_FP8) +v_mfma_f32_16x16x32_fp8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X32_FP8_BF8) +v_mfma_f32_16x16x32_fp8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_16X16X32_FP8_FP8) +v_mfma_f32_32x32x16_bf8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X16_BF8_BF8) +v_mfma_f32_32x32x16_bf8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X16_BF8_FP8) +v_mfma_f32_32x32x16_fp8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X16_FP8_BF8) +v_mfma_f32_32x32x16_fp8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_MFMA_F32_32X32X16_FP8_FP8) +v_smfmac_f32_16x16x64_bf8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_BF8) +v_smfmac_f32_16x16x64_bf8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X64_BF8_FP8) +v_smfmac_f32_16x16x64_fp8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_BF8) +v_smfmac_f32_16x16x64_fp8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_16X16X64_FP8_FP8) +v_smfmac_f32_32x32x32_bf8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_BF8) +v_smfmac_f32_32x32x32_bf8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X32_BF8_FP8) +v_smfmac_f32_32x32x32_fp8_bf8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_BF8) +v_smfmac_f32_32x32x32_fp8_fp8 = functools.partial(VOP3P_MFMA, VOP3POp.V_SMFMAC_F32_32X32X32_FP8_FP8) v_mfma_scale_f32_16x16x128_f8f6f4 = functools.partial(VOP3PX2, VOP3PX2Op.V_MFMA_SCALE_F32_16X16X128_F8F6F4) v_mfma_scale_f32_32x32x64_f8f6f4 = functools.partial(VOP3PX2, VOP3PX2Op.V_MFMA_SCALE_F32_32X32X64_F8F6F4) v_add_co_u32 = functools.partial(VOP3SD, VOP3SDOp.V_ADD_CO_U32) diff --git a/extra/assembly/amd/decode.py b/extra/assembly/amd/decode.py index 324e92e098..60e2fe1f70 100644 --- a/extra/assembly/amd/decode.py +++ b/extra/assembly/amd/decode.py @@ -38,7 +38,7 @@ from extra.assembly.amd.autogen.rdna4.ins import (VOP1 as R4_VOP1, VOP1_SDST as from extra.assembly.amd.autogen.cdna.ins import (VOP1 as C_VOP1, VOP1_SDWA as C_VOP1_SDWA, VOP1_DPP16 as C_VOP1_DPP16, VOP2 as C_VOP2, VOP2_LIT as C_VOP2_LIT, VOP2_SDWA as C_VOP2_SDWA, VOP2_DPP16 as C_VOP2_DPP16, VOPC as C_VOPC, VOPC_SDWA_SDST as C_VOPC_SDWA_SDST, - VOP3 as C_VOP3, VOP3_SDST as C_VOP3_SDST, VOP3SD as C_VOP3SD, VOP3P as C_VOP3P, VOP3PX2 as C_VOP3PX2, + VOP3 as C_VOP3, VOP3_SDST as C_VOP3_SDST, VOP3SD as C_VOP3SD, VOP3P as C_VOP3P, VOP3P_MFMA as C_VOP3P_MFMA, VOP3PX2 as C_VOP3PX2, SOP1 as C_SOP1, SOP2 as C_SOP2, SOPC as C_SOPC, SOPK as C_SOPK, SOPK_LIT as C_SOPK_LIT, SOPP as C_SOPP, SMEM as C_SMEM, DS as C_DS, FLAT as C_FLAT, GLOBAL as C_GLOBAL, SCRATCH as C_SCRATCH, MUBUF as C_MUBUF) @@ -50,7 +50,7 @@ _FORMATS = { "rdna4": [R4_VOPD, R4_VOP3P, R4_VINTERP, R4_VOP3SD, R4_VOP3_SDST, R4_VOP3, R4_DS, R4_GLOBAL, R4_SCRATCH, R4_FLAT, R4_SMEM, R4_SOP1, R4_SOP1_LIT, R4_SOPC, R4_SOPC_LIT, R4_SOPP, R4_SOPK, R4_SOPK_LIT, R4_VOPC, R4_VOP1_SDST, R4_VOP1, R4_VOP1_LIT, R4_SOP2, R4_SOP2_LIT, R4_VOP2, R4_VOP2_LIT], - "cdna": [C_VOP3PX2, C_VOP3P, C_VOP3SD, C_VOP3_SDST, C_VOP3, C_DS, C_GLOBAL, C_SCRATCH, C_FLAT, C_MUBUF, C_SMEM, + "cdna": [C_VOP3PX2, C_VOP3P_MFMA, C_VOP3P, C_VOP3SD, C_VOP3_SDST, C_VOP3, C_DS, C_GLOBAL, C_SCRATCH, C_FLAT, C_MUBUF, C_SMEM, C_SOP1, C_SOPC, C_SOPP, C_SOPK, C_SOPK_LIT, C_VOPC_SDWA_SDST, C_VOPC, C_VOP1_DPP16, C_VOP1_SDWA, C_VOP1, C_VOP2_DPP16, C_VOP2_SDWA, C_SOP2, C_VOP2, C_VOP2_LIT], } diff --git a/extra/assembly/amd/disasm.py b/extra/assembly/amd/disasm.py index 0d6ecf9a81..17d0171b1f 100644 --- a/extra/assembly/amd/disasm.py +++ b/extra/assembly/amd/disasm.py @@ -603,7 +603,7 @@ from extra.assembly.amd.autogen.cdna.ins import (VOP1 as CDNA_VOP1, VOP1_LIT as VOP1_SDWA as CDNA_VOP1_SDWA, VOP1_DPP16 as CDNA_VOP1_DPP16, VOP2 as CDNA_VOP2, VOP2_LIT as CDNA_VOP2_LIT, VOP2_SDWA as CDNA_VOP2_SDWA, VOP2_DPP16 as CDNA_VOP2_DPP16, VOPC as CDNA_VOPC, VOPC_LIT as CDNA_VOPC_LIT, VOPC_SDWA_SDST as CDNA_VOPC_SDWA_SDST, - VOP3 as CDNA_VOP3, VOP3_SDST as CDNA_VOP3_SDST, VOP3SD as CDNA_VOP3SD, VOP3P as CDNA_VOP3P, VOP3PX2 as CDNA_VOP3PX2, + VOP3 as CDNA_VOP3, VOP3_SDST as CDNA_VOP3_SDST, VOP3SD as CDNA_VOP3SD, VOP3P as CDNA_VOP3P, VOP3P_MFMA as CDNA_VOP3P_MFMA, VOP3PX2 as CDNA_VOP3PX2, SOP1 as CDNA_SOP1, SOP1_LIT as CDNA_SOP1_LIT, SOP2 as CDNA_SOP2, SOP2_LIT as CDNA_SOP2_LIT, SOPC as CDNA_SOPC, SOPC_LIT as CDNA_SOPC_LIT, SOPK as CDNA_SOPK, SOPK_LIT as CDNA_SOPK_LIT, SOPP as CDNA_SOPP, SMEM as CDNA_SMEM, DS as CDNA_DS, @@ -902,5 +902,5 @@ DISASM_HANDLERS.update({CDNA_VOP1: _disasm_vop1, CDNA_VOP1_LIT: _disasm_vop1, CDNA_SOP1: _disasm_sop1, CDNA_SOP1_LIT: _disasm_sop1, CDNA_SOP2: _disasm_sop2, CDNA_SOP2_LIT: _disasm_sop2, CDNA_SOPC: _disasm_sopc, CDNA_SOPC_LIT: _disasm_sopc, CDNA_SOPK: _disasm_sopk, CDNA_SOPK_LIT: _disasm_sopk, CDNA_SOPP: _disasm_sopp, CDNA_SMEM: _disasm_smem, CDNA_DS: _disasm_ds, CDNA_FLAT: _disasm_flat, CDNA_GLOBAL: _disasm_flat, CDNA_SCRATCH: _disasm_flat, - CDNA_VOP3: _disasm_vop3a, CDNA_VOP3_SDST: _disasm_vop3b, CDNA_VOP3SD: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, + CDNA_VOP3: _disasm_vop3a, CDNA_VOP3_SDST: _disasm_vop3b, CDNA_VOP3SD: _disasm_vop3b, CDNA_VOP3P: _disasm_cdna_vop3p, CDNA_VOP3P_MFMA: _disasm_cdna_vop3p, CDNA_MUBUF: _disasm_mubuf, CDNA_VOP3PX2: _disasm_vop3px2}) diff --git a/extra/assembly/amd/test/test_handwritten.py b/extra/assembly/amd/test/test_handwritten.py index 40916349af..dbb40dcb04 100644 --- a/extra/assembly/amd/test/test_handwritten.py +++ b/extra/assembly/amd/test/test_handwritten.py @@ -6,17 +6,21 @@ from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.dsl import Inst from extra.assembly.amd.test.test_roundtrip import compile_asm -class TestIntegration(unittest.TestCase): +class IntegrationTestBase(unittest.TestCase): inst: Inst + arch: str def tearDown(self): if not hasattr(self, 'inst'): return b = self.inst.to_bytes() st = self.inst.disasm() # Test that the instruction can be compiled by LLVM and produces the same bytes desc = f"{st:25s} {self.inst} {b!r}" - self.assertEqual(b, compile_asm(st), desc) + self.assertEqual(b, compile_asm(st, arch=self.arch), desc) print(desc) +class TestIntegration(IntegrationTestBase): + arch: str = "rdna3" + def test_wmma(self): self.inst = v_wmma_f32_16x16x16_f16(v[0:7], v[184:191], v[136:143], v[0:7]) @@ -124,6 +128,17 @@ class TestIntegration(unittest.TestCase): int_inst = s_mov_b32(s[0], struct.unpack("I", struct.pack("f", 1337.0))[0]) self.assertEqual(self.inst, int_inst) +class TestIntegrationCDNA(IntegrationTestBase): + arch = "cdna" + + def test_mfma(self): + from extra.assembly.amd.autogen.cdna.ins import v_mfma_f32_16x16x16_f16 + self.inst = v_mfma_f32_16x16x16_f16(v[0:3], v[0:1], v[0:1], 0) + + def test_mfma_fp8(self): + from extra.assembly.amd.autogen.cdna.ins import v_mfma_f32_16x16x128_f8f6f4 + self.inst = v_mfma_f32_16x16x128_f8f6f4(v[0:3], v[0:5], v[0:5], 1, cbsz=2, blgp=2) + class TestRegisterSliceSyntax(unittest.TestCase): """ Issue: Register slice syntax should use AMD assembly convention (inclusive end).