From eaa5a05f3da090561cce1d48bbcd8945cbabbffb Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 5 Jan 2026 07:45:31 -0800 Subject: [PATCH] 100% asm --- extra/assembly/amd/asm.py | 119 +++++++++++++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 2 deletions(-) diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index a88f239284..6ff4e6b46a 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -60,7 +60,17 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO', 19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'} +# GFX942-specific HWREG values +_HWREG_GFX942 = {'HW_REG_XCC_ID': 20, 'HW_REG_SQ_PERF_SNAPSHOT_DATA': 21, 'HW_REG_SQ_PERF_SNAPSHOT_DATA1': 22, + 'HW_REG_SQ_PERF_SNAPSHOT_PC_LO': 23, 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI': 24} HWREG_IDS = {v.lower(): k for k, v in HWREG.items()} +HWREG_IDS.update({k.lower(): v for k, v in _HWREG_GFX942.items()}) +def hwreg(name, offset=0, size=32): + """Encode hwreg(name[, offset[, size]]) -> simm16 value. id[5:0], offset[10:6], size-1[15:11]""" + if isinstance(name, int): hid = name + else: hid = HWREG_IDS.get(name.lower(), HWREG_IDS.get(name.lower().replace('hw_reg_', ''), None)) + if hid is None: raise ValueError(f"unknown hwreg: {name}") + return hid | (offset << 6) | ((size - 1) << 11) # RDNA unified buffer format - extracted from PDF, use enum for name->value lookup BUF_FMT = {e.name: e.value for e in BufFmt} def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y] @@ -571,6 +581,13 @@ def _op2dsl(op: str, arch: str = "rdna3") -> str: if m := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}:{m.group(3)}]") if m := re.match(r'^([asvt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}]") if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op + # Floating-point literal: convert to IEEE 754 32-bit integer representation + import struct + try: + f = float(op) + as_int = struct.unpack(' list[str]: @@ -722,6 +739,7 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str: m, text = _extract(text, r'\s+row_half_mirror(?:\s|$)'); dpp_ctrl = 0x141 if m else dpp_ctrl m, text = _extract(text, r'\s+row_bcast:15(?:\s|$)'); dpp_ctrl = 0x142 if m else dpp_ctrl m, text = _extract(text, r'\s+row_bcast:31(?:\s|$)'); dpp_ctrl = 0x143 if m else dpp_ctrl + m, text = _extract(text, r'\s+row_newbcast:(\d+)'); dpp_ctrl = 0x150 + int(m.group(1)) if m else dpp_ctrl m, text = _extract(text, r'\s+row_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_row_mask = int(m.group(1), 0) if m else None; dpp_row_mask_specified = m is not None m, text = _extract(text, r'\s+bank_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_bank_mask = int(m.group(1), 0) if m else None; dpp_bank_mask_specified = m is not None m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None # bound_ctrl:0 or bound_ctrl:1 both set bit to 1 @@ -1311,7 +1329,9 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str: # ACC register support for CDNA: detect a[N] registers and set acc=1 acc_mod = ', acc=1' if arch == 'cdna' and _has_acc(args) else '' args = [_acc_to_vgpr(a) for a in args] # convert a[N] to v[N] for encoding - if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})" + # For atomics with return value: vdst, addr, data, [saddr] - triggered by glc (or sc0 for GFX942) + has_return = glc or sc0 + if has_return and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})" if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})" # DS instructions @@ -1453,6 +1473,95 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str: a_str = ', '.join(vop3_args + all_kw) return f"{fn[:-4]}({a_str})" + # CDNA VOP1 with modifiers: auto-promote to VOP3A/SDWA/DPP + # Check if this is a VOP1 instruction needing extended encoding (not already _e64/_sdwa/_dpp) + has_vop3_mods = any(k.startswith(('omod=', 'clmp=')) for k in all_kw) + has_sdwa_mods = sdwa_src0_sel is not None or sdwa_src1_sel is not None or sdwa_dst_sel is not None + has_dpp_mods = dpp_ctrl is not None + if arch == "cdna" and fn.startswith('v_') and not fn.endswith(('_e64', '_sdwa', '_dpp')) and (has_vop3_mods or has_sdwa_mods or has_dpp_mods): + from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA, DPP + fn_upper = fn.upper() + vop1_op = getattr(VOP1Op, fn_upper, None) + vop2_op = getattr(VOP2Op, fn_upper, None) + if vop1_op is not None or vop2_op is not None: + if has_sdwa_mods: + # SDWA encoding for VOP1/VOP2 with src0_sel/src1_sel/dst_sel + sdwa_kw = [] + src0_orig = ops[1].strip().lower() if len(ops) > 1 else '' + src0_is_sgpr = src0_orig.startswith('s') and not src0_orig.startswith('src') + src0_is_literal = src0_orig.isdigit() or (len(src0_orig) > 2 and src0_orig[:2] == '0x') + if vop1_op is not None: + sdwa_kw.append(f'vop_op={vop1_op.value}') + sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode + sdwa_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]') + sdwa_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]') + else: + sdwa_kw.append(f'vop_op={args[1] if len(args) > 1 else "v[0]"}') + sdwa_kw.append(f'vop2_op={vop2_op.value}') + sdwa_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]') + sdwa_kw.append(f'src0={args[2] if len(args) > 2 else "v[0]"}') + sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}') + sdwa_kw.append('dst_u=0') + sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}') + sdwa_kw.append('src0_sext=0') + sdwa_kw.append('src0_neg=0') + sdwa_kw.append('src0_abs=0') + sdwa_kw.append(f's0={1 if src0_is_sgpr or src0_is_literal else 0}') # s0=1 for SGPR/literal + sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 0}') # 0 for VOP1 + sdwa_kw.append('src1_sext=0') + sdwa_kw.append('src1_neg=0') + sdwa_kw.append('src1_abs=0') + sdwa_kw.append('s1=0') + # Add clamp and omod if present + if any(k == 'clmp=1' for k in all_kw): sdwa_kw.append('clmp=1') + for k in all_kw: + if k.startswith('omod='): sdwa_kw.append(k); break + return f"SDWA({', '.join(sdwa_kw)})" + elif has_dpp_mods: + # DPP encoding for VOP1/VOP2 with quad_perm/row_shl/etc. + dpp_kw = [] + if vop1_op is not None: + dpp_kw.append(f'vop_op={vop1_op.value}') + dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode + dpp_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]') + dpp_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]') + else: + dpp_kw.append(f'vop_op={args[1] if len(args) > 1 else "v[0]"}') + dpp_kw.append(f'vop2_op={vop2_op.value}') + dpp_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]') + dpp_kw.append(f'src0={args[2] if len(args) > 2 else "v[0]"}') + dpp_kw.append(f'dpp_ctrl={dpp_ctrl}') + dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 15}') + dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 15}') + dpp_kw.append(f'bound_ctrl={dpp_bound_ctrl if dpp_bound_ctrl is not None else 0}') + return f"DPP({', '.join(dpp_kw)})" + elif has_vop3_mods and vop1_op is not None: + # VOP3A encoding for VOP1 with clamp/omod + from extra.assembly.amd.autogen.cdna.ins import VOP3AOp + # Calculate promoted opcode: VOP3 op = 320 + VOP1_op + promoted_op = 320 + vop1_op.value + vop3_kw = [f'op={promoted_op}'] + vop3_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]') + vop3_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]') + vop3_kw.append('src1=RawImm(0)') + vop3_kw.append('src2=RawImm(0)') + vop3_kw.extend(all_kw) + return f"VOP3A({', '.join(vop3_kw)})" + + # GFX942-specific VOP3A opcode adjustments: some instructions need +64 offset + _GFX942_VOP3A_OFFSET64 = {'V_CVT_PK_BF8_F32', 'V_CVT_PK_FP8_F32', 'V_CVT_SR_BF8_F32', 'V_CVT_SR_FP8_F32', 'V_LSHL_ADD_U64'} + if gfx942 and fn.upper() in _GFX942_VOP3A_OFFSET64: + from extra.assembly.amd.autogen.cdna.ins import VOP3AOp + base_op = getattr(VOP3AOp, fn.upper(), None) + if base_op is not None: + vop3_kw = [f'op={base_op + 64}'] + vop3_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]') + vop3_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]') + vop3_kw.append(f'src1={args[2]}' if len(args) > 2 else 'src1=RawImm(0)') + vop3_kw.append(f'src2={args[3]}' if len(args) > 3 else 'src2=RawImm(0)') + vop3_kw.extend(all_kw) + return f"VOP3A({', '.join(vop3_kw)})" + a_str, kw_str = ', '.join(args), ', '.join(all_kw) return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})" @@ -1465,9 +1574,15 @@ def asm(text: str, arch: str = "rdna3") -> Inst: from extra.assembly.amd.autogen.cdna import ins as cdna_ins ns = {n: getattr(cdna_ins, n) for n in dir(cdna_ins) if not n.startswith('_')} # CDNA special registers: m0=124, flat_scratch=102-103, xnack_mask=104-105, no NULL (use m0 for off) + # HWREG symbolic names for s_getreg_b32/s_setreg_b32 + _hwreg_names = {k: v for k, v in _HWREG_GFX942.items()} + _hwreg_names.update({v: k for k, v in HWREG.items()}) # standard names: id -> name + _hwreg_ids = {v: k for k, v in _hwreg_names.items()} # reverse: name -> id ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP, 'VCC_LO': RawImm(106), 'VCC_HI': RawImm(107), 'VCC': RawImm(106), 'EXEC_LO': RawImm(126), 'EXEC_HI': RawImm(127), 'EXEC': RawImm(126), - 'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124), + 'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124), 'hwreg': hwreg, + 'HW_REG_XCC_ID': 20, 'HW_REG_SQ_PERF_SNAPSHOT_DATA': 21, 'HW_REG_SQ_PERF_SNAPSHOT_DATA1': 22, + 'HW_REG_SQ_PERF_SNAPSHOT_PC_LO': 23, 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI': 24, 'FLAT_SCRATCH_LO': RawImm(102), 'FLAT_SCRATCH_HI': RawImm(103), 'FLAT_SCRATCH': RawImm(102), 'XNACK_MASK_LO': RawImm(104), 'XNACK_MASK_HI': RawImm(105), 'XNACK_MASK': RawImm(104), 'SRC_VCCZ': RawImm(251), 'SRC_EXECZ': RawImm(252), 'SRC_SCC': RawImm(253), 'SRC_LDS_DIRECT': RawImm(254)})