This commit is contained in:
George Hotz
2026-01-05 08:20:03 -08:00
parent eaa5a05f3d
commit bb5103fdb0

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import re
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_GPRS_CDNA, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
from extra.assembly.amd.autogen.rdna3 import ins
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
@@ -605,6 +605,40 @@ def _extract(text: str, pat: str, flags=re.I):
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
return None, text
def _parse_src_mods(raw: str) -> tuple[str, bool, bool, bool]:
"""Parse neg/abs/sext modifiers from operand string. Returns (stripped_op, neg, abs_, sext)."""
neg = raw.startswith('-') and not raw[1:2].isdigit() and raw[1:3] != '0.'
if neg: raw = raw[1:]
abs_ = raw.startswith('|') and raw.endswith('|')
if abs_: raw = raw[1:-1]
sext = raw.startswith('sext(') and raw.endswith(')')
if sext: raw = raw[5:-1]
return raw, neg, abs_, sext
def _parse_sdwa_src(raw: str) -> tuple[int, int]:
"""Parse SDWA source operand. Returns (value, s_flag) where s_flag=1 for SGPR/literal."""
# VGPRs: v0, v[0]
if raw.startswith('v') and (raw[1:].isdigit() or raw[1] == '['):
return int(raw[1:].split('[')[0]) if raw[1:].isdigit() else int(raw.split('[')[1].split(']')[0]), 0
# SGPRs: s0, s[0], s[0:1]
if raw.startswith('s') and (raw[1:].isdigit() or raw[1] == '['):
return int(raw[1:].split('[')[0]) if raw[1:].isdigit() else int(raw.split('[')[1].split(':')[0]), 1
# TTMPs: ttmp0, ttmp[0]
if raw.startswith('ttmp') and raw[4:].isdigit(): return 108 + int(raw[4:]), 1
# Special registers from SPECIAL_GPRS_CDNA (reverse lookup)
_SGPR_REV = {v: k for k, v in SPECIAL_GPRS_CDNA.items()}
_SGPR_REV.update({'src_vccz': 251, 'src_execz': 252, 'src_scc': 253, 'vcc': 106}) # extras not in SPECIAL_GPRS_CDNA
if raw in _SGPR_REV: return _SGPR_REV[raw], 1
# Inline constants: integers 0-64 -> 128+N, -1 to -16 -> 192+abs(N), floats use FLOAT_ENC
if raw.lstrip('-').replace('.', '', 1).isdigit():
if '.' in raw:
try: return FLOAT_ENC.get(float(raw), 128), 1
except ValueError: return 128, 1
ival = int(raw)
if 0 <= ival <= 64: return 128 + ival, 1
if -16 <= ival < 0: return 192 + (-ival), 1
return 0, 0
# Instruction aliases: LLVM uses different names for some instructions
_ALIASES = {
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
@@ -716,13 +750,12 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
m, text = _extract(text, r'\s+blgp:(\d+)'); blgp = int(m.group(1)) if m else None
# MFMA neg:[x,y,z] modifier -> sets neg field (same as blgp for MFMA)
m, text = _extract(text, r'\s+neg:\[([^\]]+)\]'); mfma_neg = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
# SDWA modifiers
_SDWA_SEL = {'BYTE_0': 0, 'BYTE_1': 1, 'BYTE_2': 2, 'BYTE_3': 3, 'WORD_0': 4, 'WORD_1': 5, 'DWORD': 6}
_SDWA_DST_UNUSED = {'UNUSED_PAD': 0, 'UNUSED_SEXT': 1, 'UNUSED_PRESERVE': 2}
m, text = _extract(text, r'\s+dst_sel:(\w+)'); sdwa_dst_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
m, text = _extract(text, r'\s+dst_unused:(\w+)'); sdwa_dst_unused = _SDWA_DST_UNUSED.get(m.group(1), 0) if m else None
m, text = _extract(text, r'\s+src0_sel:(\w+)'); sdwa_src0_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
# SDWA modifiers: sel values are BYTE_0-3=0-3, WORD_0-1=4-5, DWORD=6; dst_unused PAD=0, SEXT=1, PRESERVE=2
def _sel(s): return {'BYTE_0': 0, 'BYTE_1': 1, 'BYTE_2': 2, 'BYTE_3': 3, 'WORD_0': 4, 'WORD_1': 5, 'DWORD': 6}.get(s, 6)
m, text = _extract(text, r'\s+dst_sel:(\w+)'); sdwa_dst_sel = _sel(m.group(1)) if m else None
m, text = _extract(text, r'\s+dst_unused:(\w+)'); sdwa_dst_unused = {'UNUSED_PAD': 0, 'UNUSED_SEXT': 1, 'UNUSED_PRESERVE': 2}.get(m.group(1), 0) if m else None
m, text = _extract(text, r'\s+src0_sel:(\w+)'); sdwa_src0_sel = _sel(m.group(1)) if m else None
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _sel(m.group(1)) if m else None
m, text = _extract(text, r'\s+sext\(src0\)'); sdwa_src0_sext = 1 if m else None
m, text = _extract(text, r'\s+sext\(src1\)'); sdwa_src1_sext = 1 if m else None
# DPP modifiers
@@ -765,40 +798,37 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
from extra.assembly.amd.autogen.cdna.ins import VOP3POp, VOP1Op
# Handle aliases: v_accvgpr_read_b32 -> v_accvgpr_read, v_accvgpr_write_b32 -> v_accvgpr_write
fn = mn.replace('_b32', '').upper()
# MFMA instruction name mapping: LLVM names to enum names
# LLVM: v_mfma_f32_32x32x1f32 -> Enum: V_MFMA_F32_32X32X1_2B_F32
# NOTE: gfx90a uses different opcodes than gfx942 for some instructions
_MFMA_ALIASES = {
'V_MFMA_F32_32X32X1F32': 'V_MFMA_F32_32X32X1_2B_F32', 'V_MFMA_F32_16X16X1F32': 'V_MFMA_F32_16X16X1_4B_F32',
'V_MFMA_F32_4X4X1F32': 'V_MFMA_F32_4X4X1_16B_F32', 'V_MFMA_F32_32X32X2F32': 'V_MFMA_F32_32X32X2_F32',
'V_MFMA_F32_16X16X4F32': 'V_MFMA_F32_16X16X4_F32',
'V_MFMA_F32_32X32X4F16': 'V_MFMA_F32_32X32X4_2B_F16', 'V_MFMA_F32_16X16X4F16': 'V_MFMA_F32_16X16X4_4B_F16',
'V_MFMA_F32_4X4X4F16': 'V_MFMA_F32_4X4X4_16B_F16', 'V_MFMA_F32_32X32X8F16': 'V_MFMA_F32_32X32X8_F16',
'V_MFMA_F32_16X16X16F16': 'V_MFMA_F32_16X16X16_F16',
'V_MFMA_I32_32X32X4I8': 'V_MFMA_I32_32X32X4_2B_I8', 'V_MFMA_I32_16X16X4I8': 'V_MFMA_I32_16X16X4_4B_I8',
'V_MFMA_I32_4X4X4I8': 'V_MFMA_I32_4X4X4_16B_I8',
'V_MFMA_F64_16X16X4F64': 'V_MFMA_F64_16X16X4_F64', 'V_MFMA_F64_4X4X4F64': 'V_MFMA_F64_4X4X4_4B_F64',
# GFX942-specific aliases
'V_MFMA_I32_32X32X16I8': 'V_MFMA_I32_32X32X16_I8', 'V_MFMA_I32_16X16X32I8': 'V_MFMA_I32_16X16X32_I8',
'V_MFMA_F32_32X32X4BF16': 'V_MFMA_F32_32X32X4_2B_BF16', 'V_MFMA_F32_16X16X4BF16': 'V_MFMA_F32_16X16X4_4B_BF16',
'V_MFMA_F32_4X4X4BF16': 'V_MFMA_F32_4X4X4_16B_BF16',
'V_MFMA_F32_32X32X8BF16': 'V_MFMA_F32_32X32X8_BF16', 'V_MFMA_F32_16X16X16BF16': 'V_MFMA_F32_16X16X16_BF16',
# _1K variants map to same opcodes as non-1K (the _1K is for compatibility)
'V_MFMA_F32_32X32X4BF16_1K': 'V_MFMA_F32_32X32X4_2B_BF16', 'V_MFMA_F32_16X16X4BF16_1K': 'V_MFMA_F32_16X16X4_4B_BF16',
'V_MFMA_F32_4X4X4BF16_1K': 'V_MFMA_F32_4X4X4_16B_BF16',
'V_MFMA_F32_32X32X8BF16_1K': 'V_MFMA_F32_32X32X8_BF16', 'V_MFMA_F32_16X16X16BF16_1K': 'V_MFMA_F32_16X16X16_BF16',
# XF32 aliases
'V_MFMA_F32_16X16X8XF32': 'V_MFMA_F32_16X16X8_XF32', 'V_MFMA_F32_32X32X4XF32': 'V_MFMA_F32_32X32X4_XF32',
# SMFMAC aliases (LLVM name -> enum name)
'V_SMFMAC_F32_16X16X32F16': 'V_SMFMAC_F32_16X16X32_F16', 'V_SMFMAC_F32_32X32X16F16': 'V_SMFMAC_F32_32X32X16_F16',
'V_SMFMAC_F32_16X16X32BF16': 'V_SMFMAC_F32_16X16X32_BF16', 'V_SMFMAC_F32_32X32X16BF16': 'V_SMFMAC_F32_32X32X16_BF16',
'V_SMFMAC_I32_16X16X64I8': 'V_SMFMAC_I32_16X16X64_I8', 'V_SMFMAC_I32_32X32X32I8': 'V_SMFMAC_I32_32X32X32_I8',
# FP8/BF8 SMFMAC aliases
'V_SMFMAC_F32_16X16X64BF8BF8': 'V_SMFMAC_F32_16X16X64_BF8_BF8', 'V_SMFMAC_F32_16X16X64BF8FP8': 'V_SMFMAC_F32_16X16X64_BF8_FP8',
'V_SMFMAC_F32_16X16X64FP8BF8': 'V_SMFMAC_F32_16X16X64_FP8_BF8', 'V_SMFMAC_F32_16X16X64FP8FP8': 'V_SMFMAC_F32_16X16X64_FP8_FP8',
'V_SMFMAC_F32_32X32X32BF8BF8': 'V_SMFMAC_F32_32X32X32_BF8_BF8', 'V_SMFMAC_F32_32X32X32BF8FP8': 'V_SMFMAC_F32_32X32X32_BF8_FP8',
'V_SMFMAC_F32_32X32X32FP8BF8': 'V_SMFMAC_F32_32X32X32_FP8_BF8', 'V_SMFMAC_F32_32X32X32FP8FP8': 'V_SMFMAC_F32_32X32X32_FP8_FP8',
}
# MFMA/SMFMAC name mapping: LLVM v_mfma_f32_32x32x1f32 -> enum V_MFMA_F32_32X32X1_2B_F32
def _mfma_alias(n):
n = n.replace('_1K', '') # Strip _1K suffix first (same opcodes)
# FP8/BF8 pairs: insert underscore between AND before: 64BF8BF8 -> 64_BF8_BF8
for t in ('BF8BF8', 'BF8FP8', 'FP8BF8', 'FP8FP8'):
if t in n: n = n.replace(t, '_' + t[:3] + '_' + t[3:])
# Insert underscore before dtype suffix
for t in ('F32', 'F16', 'BF16', 'I8', 'F64', 'XF32'):
n = re.sub(rf'(X\d+)({t})$', rf'\1_{t}', n)
# Block sizes for specific shapes
for pat, blk in [('32X32X1_', '32X32X1_2B_'), ('16X16X1_', '16X16X1_4B_'), ('4X4X1_', '4X4X1_16B_'),
('32X32X4_F16', '32X32X4_2B_F16'), ('16X16X4_F16', '16X16X4_4B_F16'), ('4X4X4_F16', '4X4X4_16B_F16'),
('32X32X4_I8', '32X32X4_2B_I8'), ('16X16X4_I8', '16X16X4_4B_I8'), ('4X4X4_I8', '4X4X4_16B_I8'),
('32X32X4_BF16', '32X32X4_2B_BF16'), ('16X16X4_BF16', '16X16X4_4B_BF16'), ('4X4X4_BF16', '4X4X4_16B_BF16'),
('4X4X4_F64', '4X4X4_4B_F64')]:
n = n.replace(pat, blk)
return n
_MFMA_ALIASES = {n: _mfma_alias(n) for n in [
'V_MFMA_F32_32X32X1F32', 'V_MFMA_F32_16X16X1F32', 'V_MFMA_F32_4X4X1F32', 'V_MFMA_F32_32X32X2F32', 'V_MFMA_F32_16X16X4F32',
'V_MFMA_F32_32X32X4F16', 'V_MFMA_F32_16X16X4F16', 'V_MFMA_F32_4X4X4F16', 'V_MFMA_F32_32X32X8F16', 'V_MFMA_F32_16X16X16F16',
'V_MFMA_F32_32X32X16F16', 'V_MFMA_F32_16X16X32F16',
'V_MFMA_I32_32X32X4I8', 'V_MFMA_I32_16X16X4I8', 'V_MFMA_I32_4X4X4I8', 'V_MFMA_I32_32X32X16I8', 'V_MFMA_I32_16X16X32I8',
'V_MFMA_F64_16X16X4F64', 'V_MFMA_F64_4X4X4F64',
'V_MFMA_F32_32X32X4BF16', 'V_MFMA_F32_16X16X4BF16', 'V_MFMA_F32_4X4X4BF16', 'V_MFMA_F32_32X32X8BF16', 'V_MFMA_F32_16X16X16BF16',
'V_MFMA_F32_32X32X16BF16', 'V_MFMA_F32_16X16X32BF16',
'V_MFMA_F32_32X32X4BF16_1K', 'V_MFMA_F32_16X16X4BF16_1K', 'V_MFMA_F32_4X4X4BF16_1K', 'V_MFMA_F32_32X32X8BF16_1K', 'V_MFMA_F32_16X16X16BF16_1K',
'V_MFMA_F32_16X16X8XF32', 'V_MFMA_F32_32X32X4XF32',
'V_SMFMAC_F32_16X16X32F16', 'V_SMFMAC_F32_32X32X16F16', 'V_SMFMAC_F32_16X16X32BF16', 'V_SMFMAC_F32_32X32X16BF16',
'V_SMFMAC_I32_16X16X64I8', 'V_SMFMAC_I32_32X32X32I8',
'V_SMFMAC_F32_16X16X64BF8BF8', 'V_SMFMAC_F32_16X16X64BF8FP8', 'V_SMFMAC_F32_16X16X64FP8BF8', 'V_SMFMAC_F32_16X16X64FP8FP8',
'V_SMFMAC_F32_32X32X32BF8BF8', 'V_SMFMAC_F32_32X32X32BF8FP8', 'V_SMFMAC_F32_32X32X32FP8BF8', 'V_SMFMAC_F32_32X32X32FP8FP8']}
# GFX90a-specific opcodes (different from gfx942 enum): map to raw opcode values
# These instructions use opcodes that don't match our gfx942-based enum
_MFMA_GFX90A_OPS = {
@@ -883,242 +913,82 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
# SDWA instructions (CDNA)
if mn.endswith('_sdwa') and arch == "cdna":
base_mn = mn[:-5] # strip _sdwa
# Get VOP1/VOP2/VOPC opcode
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOPCOp, SDWA
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
vopc_op = getattr(VOPCOp, base_mn.upper(), None)
vop1_op, vop2_op, vopc_op = getattr(VOP1Op, base_mn.upper(), None), getattr(VOP2Op, base_mn.upper(), None), getattr(VOPCOp, base_mn.upper(), None)
if vop1_op is None and vop2_op is None and vopc_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
# Parse operands: vdst, [vcc,] src0[, vsrc1]
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
vdst = args[0] # keep as v[N] for VGPRField
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
has_carry = base_mn in carry_out_ops
src0_idx = 2 if has_carry else 1
src1_idx = 3 if has_carry else 2
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
src0 = args[1] if len(args) > 1 else 'v[0]'
# Parse neg/abs/sext modifiers from src0_raw
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit() and src0_raw[1:3] != '0.'
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
if src0_neg_mod: src0_raw = src0_raw[1:]
if src0_abs_mod: src0_raw = src0_raw[1:-1]
if src0_sext_mod: src0_raw = src0_raw[5:-1]
# Extract src0 register number for RawImm
_SDWA_SGPR_MAP = {'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'ttmp0': 108, 'ttmp1': 109, 'ttmp2': 110, 'ttmp3': 111,
'ttmp4': 112, 'ttmp5': 113, 'ttmp6': 114, 'ttmp7': 115, 'ttmp8': 116, 'ttmp9': 117,
'ttmp10': 118, 'ttmp11': 119, 'ttmp12': 120, 'ttmp13': 121, 'ttmp14': 122, 'ttmp15': 123,
'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
'src_vccz': 251, 'src_execz': 252, 'src_scc': 253}
# Inline constant encoding for SDWA src0
_SDWA_INLINE_CONST = {'0': 128, '0.0': 128, '1': 129, '1.0': 242, '2': 130, '3': 131, '4': 132, '-1': 193, '-2': 194, '-3': 195, '-4': 196,
'0.5': 240, '-0.5': 241, '-1.0': 243, '2.0': 244, '-2.0': 245, '4.0': 246, '-4.0': 247}
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
s0 = 0
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
s0 = 1
elif src0_raw in _SDWA_SGPR_MAP: src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
elif src0_raw.startswith('ttmp') and src0_raw[4:].isdigit(): src0_val, s0 = 108 + int(src0_raw[4:]), 1
elif src0_raw in _SDWA_INLINE_CONST: src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
# Integer or float inline constant
if '.' in src0_raw:
src0_val, s0 = _SDWA_INLINE_CONST.get(src0_raw, (0, 0))
if src0_val == 0 and src0_raw != '0.0': s0 = 0
else: s0 = 1
else:
ival = int(src0_raw)
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
else: src0_val, s0 = 0, 0 # Not an inline constant
else: src0_val, s0 = 0, 0
# For VOP2, parse vsrc1 and its modifiers
vsrc1_val, src1_neg_mod, src1_abs_mod, src1_sext_mod, s1 = 0, False, False, False, 0
if vop2_op is not None and len(ops) > src1_idx:
src1_raw = ops[src1_idx].strip().lower()
# Parse neg/abs/sext modifiers
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit() and src1_raw[1:3] != '0.'
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
if src1_sext_mod: src1_raw = src1_raw[5:-1]
# Extract vsrc1 register number
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
s1 = 0
elif src1_raw in _SDWA_SGPR_MAP: vsrc1_val, s1 = _SDWA_SGPR_MAP[src1_raw], 1
elif src1_raw in _SDWA_INLINE_CONST: vsrc1_val, s1 = _SDWA_INLINE_CONST[src1_raw], 1
# Build SDWA kwargs
# VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63)
# VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode
# VOPC SDWA: vop_op = src1, vop2_op = 0x3e (62), vdst = VOPC opcode, dst_sel/dst_u/clmp/omod = sdst encoding
sdwa_kw = []
# Operand layout: vdst, [vcc,] src0[, vsrc1] - carry-out ops have vcc at index 1
carry_out = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
src0_idx, src1_idx = (2, 3) if base_mn in carry_out else (1, 2)
# VOPC SDWA: sdst at [0], src0 at [1], src1 at [2]
if vopc_op is not None:
# VOPC SDWA: opcode goes in vdst field, vop2_op=62
# Parse sdst from first operand (e.g., vcc, s[n:n+1], flat_scratch, ttmp[n:n+1])
_SDWA_SDST_MAP = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 128+102, 'flat_scratch_lo': 128+102,
'ttmp0': 128+108, 'ttmp2': 128+110, 'ttmp4': 128+112, 'ttmp6': 128+114,
'ttmp8': 128+116, 'ttmp10': 128+118, 'ttmp12': 128+120, 'ttmp14': 128+122}
_SDWA_SDST = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 230, 'flat_scratch_lo': 230}
sdst_raw = ops[0].strip().lower()
if sdst_raw in _SDWA_SDST_MAP: sdst_enc = _SDWA_SDST_MAP[sdst_raw]
elif sdst_raw.startswith('s[') and ':' in sdst_raw: sdst_enc = 128 + int(sdst_raw[2:].split(':')[0])
elif sdst_raw.startswith('s') and sdst_raw[1:].isdigit(): sdst_enc = 128 + int(sdst_raw[1:])
elif sdst_raw.startswith('ttmp[') and ':' in sdst_raw: sdst_enc = 128 + 108 + int(sdst_raw[5:].split(':')[0])
else: sdst_enc = 0 # Default: vcc
# For VOPC SDWA, src0 is ops[1], src1 is ops[2]
src0_raw = ops[1].strip().lower() if len(ops) > 1 else 'v0'
src1_raw = ops[2].strip().lower() if len(ops) > 2 else 'v0'
# Parse src0 with modifiers
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
if src0_neg_mod: src0_raw = src0_raw[1:]
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
if src0_abs_mod: src0_raw = src0_raw[1:-1]
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
if src0_sext_mod: src0_raw = src0_raw[5:-1]
# Extract src0 value and type
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
s0 = 0
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
s0 = 1
elif src0_raw in _SDWA_SGPR_MAP:
src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
elif src0_raw in _SDWA_INLINE_CONST:
src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
# Integer or float inline constant
if '.' in src0_raw:
src0_val = _SDWA_INLINE_CONST.get(src0_raw, 128)
s0 = 1
else:
ival = int(src0_raw)
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
else: src0_val, s0 = 0, 0
else: src0_val, s0 = 0, 0
# Parse src1 with modifiers
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
if src1_sext_mod: src1_raw = src1_raw[5:-1]
# Extract src1 value and type
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
s1 = 0
else: vsrc1_val, s1 = 0, 0
sdwa_kw.append(f'vop_op={vsrc1_val}')
sdwa_kw.append('vop2_op=62') # 0x3e indicates VOPC mode
sdwa_kw.append(f'vdst=RawImm({vopc_op.value})') # VOPC opcode in vdst
sdwa_kw.append(f'src0=RawImm({src0_val})')
# Encode sdst in dst_sel/dst_u/clmp/omod fields
sdwa_kw.append(f'dst_sel={sdst_enc & 7}')
sdwa_kw.append(f'dst_u={(sdst_enc >> 3) & 3}')
sdwa_kw.append(f'clmp={(sdst_enc >> 5) & 1}')
sdwa_kw.append(f'omod={(sdst_enc >> 6) & 3}')
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
if src0_sext_mod or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
sdst_enc = _SDWA_SDST.get(sdst_raw, 128 + int(sdst_raw[2:].split(':')[0]) if sdst_raw.startswith('s[') else
128 + int(sdst_raw[1:]) if sdst_raw.startswith('s') and sdst_raw[1:].isdigit() else
128 + 108 + int(sdst_raw[5:].split(':')[0]) if sdst_raw.startswith('ttmp[') else 0)
src0_raw, src0_neg, src0_abs, src0_sext = _parse_src_mods(ops[1].strip().lower() if len(ops) > 1 else 'v0')
src1_raw, src1_neg, src1_abs, src1_sext = _parse_src_mods(ops[2].strip().lower() if len(ops) > 2 else 'v0')
src0_val, s0 = _parse_sdwa_src(src0_raw)
vsrc1_val, s1 = _parse_sdwa_src(src1_raw)
sdwa_kw = [f'vop_op={vsrc1_val}', 'vop2_op=62', f'vdst=RawImm({vopc_op.value})', f'src0=RawImm({src0_val})',
f'dst_sel={sdst_enc & 7}', f'dst_u={(sdst_enc >> 3) & 3}', f'clmp={(sdst_enc >> 5) & 1}', f'omod={(sdst_enc >> 6) & 3}',
f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}', f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}']
if src0_sext or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
if src0_neg: sdwa_kw.append('src0_neg=1')
if src0_abs: sdwa_kw.append('src0_abs=1')
if s0: sdwa_kw.append('s0=1')
if src1_sext_mod or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
if src1_sext or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
if src1_neg: sdwa_kw.append('src1_neg=1')
if src1_abs: sdwa_kw.append('src1_abs=1')
if s1: sdwa_kw.append('s1=1')
return f"SDWA({', '.join(sdwa_kw)})"
elif vop1_op is not None:
sdwa_kw.append(f'vop_op={vop1_op.value}')
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
else:
sdwa_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 SDWA
sdwa_kw.append(f'vop2_op={vop2_op.value}')
sdwa_kw.append(f'vdst={vdst}')
sdwa_kw.append(f'src0=RawImm({src0_val})')
# Defaults: dst_sel=6 (DWORD), dst_unused=2 (UNUSED_PRESERVE), src0_sel=6 (DWORD), src1_sel=6 (DWORD)
sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}')
sdwa_kw.append(f'dst_u={sdwa_dst_unused if sdwa_dst_unused is not None else 2}')
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
if sdwa_src0_sext or src0_sext_mod: sdwa_kw.append('src0_sext=1')
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
# VOP1/VOP2 SDWA
src0_raw, src0_neg, src0_abs, src0_sext = _parse_src_mods(ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0')
src0_val, s0 = _parse_sdwa_src(src0_raw)
vsrc1_val, src1_neg, src1_abs, src1_sext, s1 = 0, False, False, False, 0
if vop2_op is not None and len(ops) > src1_idx:
src1_raw, src1_neg, src1_abs, src1_sext = _parse_src_mods(ops[src1_idx].strip().lower())
vsrc1_val, s1 = _parse_sdwa_src(src1_raw)
sdwa_kw = [f'vop_op={vop1_op.value if vop1_op else vsrc1_val}', f'vop2_op={63 if vop1_op else vop2_op.value}',
f'vdst={args[0]}', f'src0=RawImm({src0_val})', f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}',
f'dst_u={sdwa_dst_unused if sdwa_dst_unused is not None else 2}', f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}']
if src0_sext or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
if src0_neg: sdwa_kw.append('src0_neg=1')
if src0_abs: sdwa_kw.append('src0_abs=1')
if s0: sdwa_kw.append('s0=1')
# VOP2 SDWA src1 modifiers and defaults
if vop2_op is not None:
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
if sdwa_src1_sext or src1_sext_mod: sdwa_kw.append('src1_sext=1')
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
if src1_sext or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
if src1_neg: sdwa_kw.append('src1_neg=1')
if src1_abs: sdwa_kw.append('src1_abs=1')
if s1: sdwa_kw.append('s1=1')
# Add clamp/omod from kw if present
for k in kw:
if k.startswith('clmp='): sdwa_kw.append(k)
elif k.startswith('omod='): sdwa_kw.append(k)
if k.startswith('clmp=') or k.startswith('omod='): sdwa_kw.append(k)
return f"SDWA({', '.join(sdwa_kw)})"
# DPP instructions (CDNA)
if mn.endswith('_dpp') and arch == "cdna" and dpp_ctrl is not None:
base_mn = mn[:-4] # strip _dpp
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, DPP
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
vop1_op, vop2_op = getattr(VOP1Op, base_mn.upper(), None), getattr(VOP2Op, base_mn.upper(), None)
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown DPP instruction: {mn}")
# Parse operands: vdst, [vcc,] src0[, vsrc1]
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
vdst = args[0]
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
has_carry = base_mn in carry_out_ops
src0_idx = 2 if has_carry else 1
src1_idx = 3 if has_carry else 2
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
# Parse neg/abs modifiers for src0 (neg before abs for -|v1| case)
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
if src0_neg_mod: src0_raw = src0_raw[1:]
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
if src0_abs_mod: src0_raw = src0_raw[1:-1]
# Extract src0 VGPR number
if src0_raw.startswith('v') and src0_raw[1:].isdigit(): src0_val = int(src0_raw[1:])
elif 'v[' in src0_raw: src0_val = int(src0_raw.split('[')[1].split(']')[0])
else: src0_val = 0
# For VOP2, parse vsrc1 and its modifiers
vsrc1_val, src1_neg_mod, src1_abs_mod = 0, False, False
carry_out = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
src0_idx, src1_idx = (2, 3) if base_mn in carry_out else (1, 2)
src0_raw, src0_neg, src0_abs, _ = _parse_src_mods(ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0')
src0_val = int(src0_raw[1:]) if src0_raw.startswith('v') and src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0]) if 'v[' in src0_raw else 0
vsrc1_val, src1_neg, src1_abs = 0, False, False
if vop2_op is not None and len(ops) > src1_idx:
src1_raw = ops[src1_idx].strip().lower()
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
if src1_raw.startswith('v') and src1_raw[1:].isdigit(): vsrc1_val = int(src1_raw[1:])
elif 'v[' in src1_raw: vsrc1_val = int(src1_raw.split('[')[1].split(']')[0])
# Build DPP kwargs
# VOP1 DPP: vop_op = VOP1 opcode, vop2_op = 0x3f
# VOP2 DPP: vop_op = vsrc1, vop2_op = VOP2 opcode
dpp_kw = []
if vop1_op is not None:
dpp_kw.append(f'vop_op={vop1_op.value}')
dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
else:
dpp_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 DPP
dpp_kw.append(f'vop2_op={vop2_op.value}')
dpp_kw.append(f'vdst={vdst}')
dpp_kw.append(f'src0=RawImm({src0_val})')
dpp_kw.append(f'dpp_ctrl={dpp_ctrl}')
src1_raw, src1_neg, src1_abs, _ = _parse_src_mods(ops[src1_idx].strip().lower())
vsrc1_val = int(src1_raw[1:]) if src1_raw.startswith('v') and src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0]) if 'v[' in src1_raw else 0
dpp_kw = [f'vop_op={vop1_op.value if vop1_op else vsrc1_val}', f'vop2_op={63 if vop1_op else vop2_op.value}',
f'vdst={args[0]}', f'src0=RawImm({src0_val})', f'dpp_ctrl={dpp_ctrl}']
if dpp_bound_ctrl: dpp_kw.append('bound_ctrl=1')
if src0_neg_mod: dpp_kw.append('src0_neg=1')
if src0_abs_mod: dpp_kw.append('src0_abs=1')
if src1_neg_mod: dpp_kw.append('src1_neg=1')
if src1_abs_mod: dpp_kw.append('src1_abs=1')
# Default masks: if one is specified but not the other, the other defaults to 0xf
if src0_neg: dpp_kw.append('src0_neg=1')
if src0_abs: dpp_kw.append('src0_abs=1')
if src1_neg: dpp_kw.append('src1_neg=1')
if src1_abs: dpp_kw.append('src1_abs=1')
if dpp_bank_mask_specified or dpp_row_mask_specified:
dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 0xf}')
dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 0xf}')
dpp_kw.extend([f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 0xf}', f'row_mask={dpp_row_mask if dpp_row_mask is not None else 0xf}'])
return f"DPP({', '.join(dpp_kw)})"
# VOPD (RDNA3 only)