This commit is contained in:
George Hotz
2026-01-05 07:45:31 -08:00
parent ea244a4fce
commit eaa5a05f3d

View File

@@ -60,7 +60,17 @@ HWREG = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_H
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
# GFX942-specific HWREG values
_HWREG_GFX942 = {'HW_REG_XCC_ID': 20, 'HW_REG_SQ_PERF_SNAPSHOT_DATA': 21, 'HW_REG_SQ_PERF_SNAPSHOT_DATA1': 22,
'HW_REG_SQ_PERF_SNAPSHOT_PC_LO': 23, 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI': 24}
HWREG_IDS = {v.lower(): k for k, v in HWREG.items()}
HWREG_IDS.update({k.lower(): v for k, v in _HWREG_GFX942.items()})
def hwreg(name, offset=0, size=32):
"""Encode hwreg(name[, offset[, size]]) -> simm16 value. id[5:0], offset[10:6], size-1[15:11]"""
if isinstance(name, int): hid = name
else: hid = HWREG_IDS.get(name.lower(), HWREG_IDS.get(name.lower().replace('hw_reg_', ''), None))
if hid is None: raise ValueError(f"unknown hwreg: {name}")
return hid | (offset << 6) | ((size - 1) << 11)
# RDNA unified buffer format - extracted from PDF, use enum for name->value lookup
BUF_FMT = {e.name: e.value for e in BufFmt}
def _parse_buf_fmt_combo(s: str) -> int: # parse format:[BUF_DATA_FORMAT_X, BUF_NUM_FORMAT_Y]
@@ -571,6 +581,13 @@ def _op2dsl(op: str, arch: str = "rdna3") -> str:
if m := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}:{m.group(3)}]")
if m := re.match(r'^([asvt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}]")
if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op
# Floating-point literal: convert to IEEE 754 32-bit integer representation
import struct
try:
f = float(op)
as_int = struct.unpack('<I', struct.pack('<f', f))[0]
return f"SrcMod({as_int}, neg={neg}, abs_={abs_})" if neg or abs_ else str(as_int)
except ValueError: pass
return wrap(op)
def _parse_ops(s: str) -> list[str]:
@@ -722,6 +739,7 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
m, text = _extract(text, r'\s+row_half_mirror(?:\s|$)'); dpp_ctrl = 0x141 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_bcast:15(?:\s|$)'); dpp_ctrl = 0x142 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_bcast:31(?:\s|$)'); dpp_ctrl = 0x143 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_newbcast:(\d+)'); dpp_ctrl = 0x150 + int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+row_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_row_mask = int(m.group(1), 0) if m else None; dpp_row_mask_specified = m is not None
m, text = _extract(text, r'\s+bank_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_bank_mask = int(m.group(1), 0) if m else None; dpp_bank_mask_specified = m is not None
m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None # bound_ctrl:0 or bound_ctrl:1 both set bit to 1
@@ -1311,7 +1329,9 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
# ACC register support for CDNA: detect a[N] registers and set acc=1
acc_mod = ', acc=1' if arch == 'cdna' and _has_acc(args) else ''
args = [_acc_to_vgpr(a) for a in args] # convert a[N] to v[N] for encoding
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})"
# For atomics with return value: vdst, addr, data, [saddr] - triggered by glc (or sc0 for GFX942)
has_return = glc or sc0
if has_return and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})"
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods}{acc_mod})"
# DS instructions
@@ -1453,6 +1473,95 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
a_str = ', '.join(vop3_args + all_kw)
return f"{fn[:-4]}({a_str})"
# CDNA VOP1 with modifiers: auto-promote to VOP3A/SDWA/DPP
# Check if this is a VOP1 instruction needing extended encoding (not already _e64/_sdwa/_dpp)
has_vop3_mods = any(k.startswith(('omod=', 'clmp=')) for k in all_kw)
has_sdwa_mods = sdwa_src0_sel is not None or sdwa_src1_sel is not None or sdwa_dst_sel is not None
has_dpp_mods = dpp_ctrl is not None
if arch == "cdna" and fn.startswith('v_') and not fn.endswith(('_e64', '_sdwa', '_dpp')) and (has_vop3_mods or has_sdwa_mods or has_dpp_mods):
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA, DPP
fn_upper = fn.upper()
vop1_op = getattr(VOP1Op, fn_upper, None)
vop2_op = getattr(VOP2Op, fn_upper, None)
if vop1_op is not None or vop2_op is not None:
if has_sdwa_mods:
# SDWA encoding for VOP1/VOP2 with src0_sel/src1_sel/dst_sel
sdwa_kw = []
src0_orig = ops[1].strip().lower() if len(ops) > 1 else ''
src0_is_sgpr = src0_orig.startswith('s') and not src0_orig.startswith('src')
src0_is_literal = src0_orig.isdigit() or (len(src0_orig) > 2 and src0_orig[:2] == '0x')
if vop1_op is not None:
sdwa_kw.append(f'vop_op={vop1_op.value}')
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
sdwa_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
sdwa_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
else:
sdwa_kw.append(f'vop_op={args[1] if len(args) > 1 else "v[0]"}')
sdwa_kw.append(f'vop2_op={vop2_op.value}')
sdwa_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
sdwa_kw.append(f'src0={args[2] if len(args) > 2 else "v[0]"}')
sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}')
sdwa_kw.append('dst_u=0')
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
sdwa_kw.append('src0_sext=0')
sdwa_kw.append('src0_neg=0')
sdwa_kw.append('src0_abs=0')
sdwa_kw.append(f's0={1 if src0_is_sgpr or src0_is_literal else 0}') # s0=1 for SGPR/literal
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 0}') # 0 for VOP1
sdwa_kw.append('src1_sext=0')
sdwa_kw.append('src1_neg=0')
sdwa_kw.append('src1_abs=0')
sdwa_kw.append('s1=0')
# Add clamp and omod if present
if any(k == 'clmp=1' for k in all_kw): sdwa_kw.append('clmp=1')
for k in all_kw:
if k.startswith('omod='): sdwa_kw.append(k); break
return f"SDWA({', '.join(sdwa_kw)})"
elif has_dpp_mods:
# DPP encoding for VOP1/VOP2 with quad_perm/row_shl/etc.
dpp_kw = []
if vop1_op is not None:
dpp_kw.append(f'vop_op={vop1_op.value}')
dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
dpp_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
dpp_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
else:
dpp_kw.append(f'vop_op={args[1] if len(args) > 1 else "v[0]"}')
dpp_kw.append(f'vop2_op={vop2_op.value}')
dpp_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
dpp_kw.append(f'src0={args[2] if len(args) > 2 else "v[0]"}')
dpp_kw.append(f'dpp_ctrl={dpp_ctrl}')
dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 15}')
dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 15}')
dpp_kw.append(f'bound_ctrl={dpp_bound_ctrl if dpp_bound_ctrl is not None else 0}')
return f"DPP({', '.join(dpp_kw)})"
elif has_vop3_mods and vop1_op is not None:
# VOP3A encoding for VOP1 with clamp/omod
from extra.assembly.amd.autogen.cdna.ins import VOP3AOp
# Calculate promoted opcode: VOP3 op = 320 + VOP1_op
promoted_op = 320 + vop1_op.value
vop3_kw = [f'op={promoted_op}']
vop3_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
vop3_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
vop3_kw.append('src1=RawImm(0)')
vop3_kw.append('src2=RawImm(0)')
vop3_kw.extend(all_kw)
return f"VOP3A({', '.join(vop3_kw)})"
# GFX942-specific VOP3A opcode adjustments: some instructions need +64 offset
_GFX942_VOP3A_OFFSET64 = {'V_CVT_PK_BF8_F32', 'V_CVT_PK_FP8_F32', 'V_CVT_SR_BF8_F32', 'V_CVT_SR_FP8_F32', 'V_LSHL_ADD_U64'}
if gfx942 and fn.upper() in _GFX942_VOP3A_OFFSET64:
from extra.assembly.amd.autogen.cdna.ins import VOP3AOp
base_op = getattr(VOP3AOp, fn.upper(), None)
if base_op is not None:
vop3_kw = [f'op={base_op + 64}']
vop3_kw.append(f'vdst={args[0]}' if args else 'vdst=v[0]')
vop3_kw.append(f'src0={args[1]}' if len(args) > 1 else 'src0=v[0]')
vop3_kw.append(f'src1={args[2]}' if len(args) > 2 else 'src1=RawImm(0)')
vop3_kw.append(f'src2={args[3]}' if len(args) > 3 else 'src2=RawImm(0)')
vop3_kw.extend(all_kw)
return f"VOP3A({', '.join(vop3_kw)})"
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
@@ -1465,9 +1574,15 @@ def asm(text: str, arch: str = "rdna3") -> Inst:
from extra.assembly.amd.autogen.cdna import ins as cdna_ins
ns = {n: getattr(cdna_ins, n) for n in dir(cdna_ins) if not n.startswith('_')}
# CDNA special registers: m0=124, flat_scratch=102-103, xnack_mask=104-105, no NULL (use m0 for off)
# HWREG symbolic names for s_getreg_b32/s_setreg_b32
_hwreg_names = {k: v for k, v in _HWREG_GFX942.items()}
_hwreg_names.update({v: k for k, v in HWREG.items()}) # standard names: id -> name
_hwreg_ids = {v: k for k, v in _hwreg_names.items()} # reverse: name -> id
ns.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
'VCC_LO': RawImm(106), 'VCC_HI': RawImm(107), 'VCC': RawImm(106), 'EXEC_LO': RawImm(126), 'EXEC_HI': RawImm(127), 'EXEC': RawImm(126),
'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124),
'SCC': RawImm(253), 'M0': RawImm(124), 'NULL': RawImm(124), 'OFF': RawImm(124), 'hwreg': hwreg,
'HW_REG_XCC_ID': 20, 'HW_REG_SQ_PERF_SNAPSHOT_DATA': 21, 'HW_REG_SQ_PERF_SNAPSHOT_DATA1': 22,
'HW_REG_SQ_PERF_SNAPSHOT_PC_LO': 23, 'HW_REG_SQ_PERF_SNAPSHOT_PC_HI': 24,
'FLAT_SCRATCH_LO': RawImm(102), 'FLAT_SCRATCH_HI': RawImm(103), 'FLAT_SCRATCH': RawImm(102),
'XNACK_MASK_LO': RawImm(104), 'XNACK_MASK_HI': RawImm(105), 'XNACK_MASK': RawImm(104),
'SRC_VCCZ': RawImm(251), 'SRC_EXECZ': RawImm(252), 'SRC_SCC': RawImm(253), 'SRC_LDS_DIRECT': RawImm(254)})