This commit is contained in:
George Hotz
2026-01-04 20:51:01 -08:00
parent 85b28faf33
commit f2b11010e8
3 changed files with 203 additions and 34 deletions

View File

@@ -568,8 +568,8 @@ def _op2dsl(op: str, arch: str = "rdna3") -> str:
if lo in spec_dsl: return wrap(spec_dsl[lo])
if op in FLOATS: return wrap(op)
rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]")
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}]")
if m := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}:{m.group(3)}]")
if m := re.match(r'^([asvt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}]")
if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op
return wrap(op)
@@ -634,6 +634,11 @@ _CDNA_ALIASES = {
# VOP aliases (inverse of _CDNA_DISASM_ALIASES)
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32',
'v_mul_legacy_f32': 'v_fmac_f64', 'v_mac_f32': 'v_dot2c_f32_bf16', 'v_madmk_f32': 'v_fmamk_f32', 'v_madak_f32': 'v_fmaak_f32',
# VOPC: v_cmp_t_fXX -> v_cmp_tru_fXX for CDNA
'v_cmp_t_f16': 'v_cmp_tru_f16', 'v_cmp_t_f32': 'v_cmp_tru_f32', 'v_cmp_t_f64': 'v_cmp_tru_f64',
'v_cmpx_t_f16': 'v_cmpx_tru_f16', 'v_cmpx_t_f32': 'v_cmpx_tru_f32', 'v_cmpx_t_f64': 'v_cmpx_tru_f64',
# VOP1: flr/rpi -> floor/nearest for CDNA
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
}
def _apply_alias(text: str, arch: str = "rdna3") -> str:
@@ -655,7 +660,7 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]')
if m:
bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower()
is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot'))
is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_mad_mix', 'v_fma_mix'))
opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \
(bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits))
m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None
@@ -670,7 +675,8 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]'); opsel_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]')
opsel_hi, opsel_hi_count = (sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))), len(m.group(1).split(','))) if m else (None, 0)
m, text = _extract(text, r'\s+gds(?:\s|$)'); gds = 1 if m else None
m, text = _extract(text, r'\s+offset0:(\d+)'); offset0 = m.group(1) if m else None
m, text = _extract(text, r'\s+offset1:(\d+)'); offset1 = m.group(1) if m else None
@@ -721,11 +727,12 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
# SDWA instructions (CDNA)
if mn.endswith('_sdwa') and arch == "cdna":
base_mn = mn[:-5] # strip _sdwa
# Get VOP1/VOP2 opcode
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA
# Get VOP1/VOP2/VOPC opcode
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOPCOp, SDWA
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
vopc_op = getattr(VOPCOp, base_mn.upper(), None)
if vop1_op is None and vop2_op is None and vopc_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
# Parse operands: vdst, [vcc,] src0[, vsrc1]
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
vdst = args[0] # keep as v[N] for VGPRField
@@ -793,8 +800,85 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
# Build SDWA kwargs
# VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63)
# VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode
# VOPC SDWA: vop_op = src1, vop2_op = 0x3e (62), vdst = VOPC opcode, dst_sel/dst_u/clmp/omod = sdst encoding
sdwa_kw = []
if vop1_op is not None:
if vopc_op is not None:
# VOPC SDWA: opcode goes in vdst field, vop2_op=62
# Parse sdst from first operand (e.g., vcc, s[n:n+1], flat_scratch, ttmp[n:n+1])
_SDWA_SDST_MAP = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 128+102, 'flat_scratch_lo': 128+102,
'ttmp0': 128+108, 'ttmp2': 128+110, 'ttmp4': 128+112, 'ttmp6': 128+114,
'ttmp8': 128+116, 'ttmp10': 128+118, 'ttmp12': 128+120, 'ttmp14': 128+122}
sdst_raw = ops[0].strip().lower()
if sdst_raw in _SDWA_SDST_MAP: sdst_enc = _SDWA_SDST_MAP[sdst_raw]
elif sdst_raw.startswith('s[') and ':' in sdst_raw: sdst_enc = 128 + int(sdst_raw[2:].split(':')[0])
elif sdst_raw.startswith('s') and sdst_raw[1:].isdigit(): sdst_enc = 128 + int(sdst_raw[1:])
elif sdst_raw.startswith('ttmp[') and ':' in sdst_raw: sdst_enc = 128 + 108 + int(sdst_raw[5:].split(':')[0])
else: sdst_enc = 0 # Default: vcc
# For VOPC SDWA, src0 is ops[1], src1 is ops[2]
src0_raw = ops[1].strip().lower() if len(ops) > 1 else 'v0'
src1_raw = ops[2].strip().lower() if len(ops) > 2 else 'v0'
# Parse src0 with modifiers
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
if src0_neg_mod: src0_raw = src0_raw[1:]
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
if src0_abs_mod: src0_raw = src0_raw[1:-1]
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
if src0_sext_mod: src0_raw = src0_raw[5:-1]
# Extract src0 value and type
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
s0 = 0
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
s0 = 1
elif src0_raw in _SDWA_SGPR_MAP:
src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
elif src0_raw in _SDWA_INLINE_CONST:
src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
# Integer or float inline constant
if '.' in src0_raw:
src0_val = _SDWA_INLINE_CONST.get(src0_raw, 128)
s0 = 1
else:
ival = int(src0_raw)
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
else: src0_val, s0 = 0, 0
else: src0_val, s0 = 0, 0
# Parse src1 with modifiers
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
if src1_neg_mod: src1_raw = src1_raw[1:]
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
if src1_abs_mod: src1_raw = src1_raw[1:-1]
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
if src1_sext_mod: src1_raw = src1_raw[5:-1]
# Extract src1 value and type
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
s1 = 0
else: vsrc1_val, s1 = 0, 0
sdwa_kw.append(f'vop_op={vsrc1_val}')
sdwa_kw.append('vop2_op=62') # 0x3e indicates VOPC mode
sdwa_kw.append(f'vdst=RawImm({vopc_op.value})') # VOPC opcode in vdst
sdwa_kw.append(f'src0=RawImm({src0_val})')
# Encode sdst in dst_sel/dst_u/clmp/omod fields
sdwa_kw.append(f'dst_sel={sdst_enc & 7}')
sdwa_kw.append(f'dst_u={(sdst_enc >> 3) & 3}')
sdwa_kw.append(f'clmp={(sdst_enc >> 5) & 1}')
sdwa_kw.append(f'omod={(sdst_enc >> 6) & 3}')
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
if src0_sext_mod or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
if s0: sdwa_kw.append('s0=1')
if src1_sext_mod or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
if s1: sdwa_kw.append('s1=1')
return f"SDWA({', '.join(sdwa_kw)})"
elif vop1_op is not None:
sdwa_kw.append(f'vop_op={vop1_op.value}')
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
else:
@@ -902,14 +986,45 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
elif dst.startswith('ttmp') and dst[4:].isdigit(): dst_val = 108 + int(dst[4:])
else:
sgpr_map = {'vcc_lo': 106, 'vcc_hi': 107, 'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105}
'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
'null': 124} # null register for RDNA3
dst_val = sgpr_map.get(dst, int(dst) if dst.isdigit() else 0)
return f"v_readfirstlane_b32_e32(vdst=RawImm({dst_val}), src0={args[1]})"
if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
if mn in ('s_cbranch_join', 's_set_gpr_idx_idx'): return f"{mn}(ssrc0={args[0]}, sdst=RawImm(0))" # No destination, only source
if mn == 's_cbranch_g_fork': return f"{mn}(ssrc0={args[0]}, ssrc1={args[1]}, sdst=RawImm(0))" # Two sources, no dest
if mn == 's_set_gpr_idx_on': return f"{mn}(ssrc0={args[0]}, ssrc1=RawImm({int(args[1], 0)}))" # Mode bits as raw value
if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
if mn == 's_version': return f"{mn}(simm16={args[0]})"
if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
# SMEM: s_dcache_discard has swapped operand layout (saddr→sbase, soffset→sdata)
if arch == "cdna" and mn.startswith('s_dcache_discard'):
gs = ", glc=1" if glc else ""
# Syntax: s_dcache_discard saddr, soffset [offset:imm]
if off_val and len(ops) >= 2:
# SGPR + immediate offset: soe=1, imm=1, soffset=SGPR, offset=imm
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={off_val}, soffset={args[1]}, soe=1, imm=1{gs})"
if len(ops) >= 2 and re.match(r'^-?[0-9]|^-?0x', ops[1].strip().lower()):
# Immediate offset only: imm=1
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0), imm=1{gs})"
# SGPR offset only: imm=0, offset=SGPR
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0){gs})"
# SMEM: s_atomic_*/s_buffer_atomic_* uses offset field for SGPR (imm=0), not soffset
if arch == "cdna" and (mn.startswith('s_buffer_atomic') or (mn.startswith('s_atomic') and not mn.startswith('s_atc'))):
gs = ", glc=1" if glc else ""
if len(ops) >= 3:
# Syntax: s_atomic_* sdata, sbase, soffset [offset:imm]
if off_val:
# SGPR + immediate offset: soe=1, imm=1
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}, soe=1, imm=1{gs})"
if re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
# Immediate offset only: imm=1
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0), imm=1{gs})"
# SGPR offset only: imm=0, offset=SGPR
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs})"
# SMEM
if mn in SMEM_OPS or (arch == "cdna" and mn.startswith(('s_load_dword', 's_buffer_load_dword'))):
gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
@@ -924,6 +1039,9 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
if len(ops) >= 3:
# SGPR offset only: offset=SGPR index, soffset=0
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs}{ds})"
if len(ops) == 2:
# No offset specified: imm=1, offset=0
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset=0, soffset=RawImm(0), imm=1{gs}{ds})"
else:
# RDNA3 encoding
if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
@@ -1003,12 +1121,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
return f"{mn}(vdst=v[0], addr={addr_val}, saddr={saddr_val}{flat_mods})"
# For scratch, 'off' as vaddr means vaddr=0 (no offset), not null register
# For load: args=[vdst, addr, saddr], for store: args=[addr, data, saddr]
# For RDNA3 scratch with 'off' as vaddr, set sve=0 (no VGPR address)
if 'store' in pre:
addr_val = 'v[0]' if seg == 'scratch' and args[0] == 'OFF' else args[0]
return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
addr_off = seg == 'scratch' and args[0] == 'OFF'
addr_val = 'v[0]' if addr_off else args[0]
sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else ''
return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})"
else:
addr_val = 'v[0]' if seg == 'scratch' and args[1] == 'OFF' else args[1]
return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
addr_off = seg == 'scratch' and args[1] == 'OFF'
addr_val = 'v[0]' if addr_off else args[1]
sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else ''
return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})"
for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'):
if mn.startswith(pre):
seg = pre.split('_')[0] # 'flat', 'global', or 'scratch'
@@ -1034,15 +1157,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'load' in mn or ('read' in mn and 'read2' not in mn): return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'read2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'write2' in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
if 'xchg2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
if 'store' in mn and not _has(mn, 'cmp', 'xchg'):
return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
if _has(mn, 'cmpstore', 'mskor', 'wrap'):
if _has(mn, 'cmpst', 'mskor', 'wrap'):
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
@@ -1063,16 +1188,22 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
# For RDNA3 v_cmpx, destination is implicitly exec (126)
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2 and arch == 'rdna3': args = ['RawImm(126)'] + args
# v_cmp_*_e64 and v_cmpx_*_e64 have SGPR destination in vdst field - encode as RawImm
# For CDNA, v_cmpx also writes to SGPR pair (first operand)
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
if mn.startswith('v_cmp') and mn.endswith('_e64') and len(args) >= 1:
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
# For CDNA v_cmpx with 3 operands (sdst, src0, src1), convert sdst to RawImm
# For RDNA3, v_cmpx only has 2 operands (src0, src1) - already handled above
is_cmpx = 'cmpx' in mn
if not is_cmpx or arch == 'cdna':
dst = ops[0].strip().lower()
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
fn = mn.replace('.', '_')
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
@@ -1096,23 +1227,52 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
all_kw = list(kw)
if lit_s: all_kw.append(lit_s.lstrip(', '))
if opsel is not None: all_kw.append(f'opsel={opsel}')
if opsel_hi is not None: all_kw.append(f'opsel_hi={opsel_hi & 3}'); all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}')
if opsel_hi is not None:
all_kw.append(f'opsel_hi={opsel_hi & 3}')
if opsel_hi_count >= 3: all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}') # only set opsel_hi2 if 3 elements specified
if neg_lo is not None: all_kw.append(f'neg={neg_lo}')
if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}')
if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1'])
# For CDNA _e64 VOP instructions: use keyword args (VOP3 layout)
# Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> v_xxx(vdst=dst, src0=src0[, src1=src1[, src2=src2]])
# For v_nop_e64 (no operands), add _vop3=True marker to force VOP3 encoding
# Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> VOP3A with promoted opcode
# VOP1 to VOP3 promotion: VOP3 op = 384 + (VOP1_op - 64) for VOP1_op >= 64, else 256 + VOP1_op
if fn.endswith('_e64') and fn.startswith('v_') and arch == "cdna":
fn_base = fn[:-4] # strip _e64
fn_base = fn[:-4].upper() # strip _e64 and uppercase for enum lookup
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOP3AOp, VOP3BOp
# Check if this is a VOP3B instruction (has sdst for carry-out)
vop3b_op = getattr(VOP3BOp, fn_base, None)
if vop3b_op is not None:
# VOP3B: v_xxx_e64 vdst, sdst, src0, src1[, src2]
vop3_args = []
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
if len(args) >= 2: vop3_args.append(f'sdst={args[1]}')
if len(args) >= 3: vop3_args.append(f'src0={args[2]}')
if len(args) >= 4: vop3_args.append(f'src1={args[3]}')
if len(args) >= 5: vop3_args.append(f'src2={args[4]}')
a_str = ', '.join(vop3_args + all_kw)
return f"{fn[:-4]}({a_str})"
# Check if this is a VOP1 instruction that needs promotion
vop1_op = getattr(VOP1Op, fn_base, None)
vop2_op = getattr(VOP2Op, fn_base, None)
vop3a_op = getattr(VOP3AOp, fn_base, None)
if vop1_op is not None and vop3a_op is None:
# VOP1 -> VOP3 promotion: calculate promoted opcode
promoted_op = 384 + (vop1_op.value - 64) if vop1_op.value >= 64 else 256 + vop1_op.value
vop3_args = [f'op={promoted_op}']
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
if len(args) >= 2: vop3_args.append(f'src0={args[1]}')
if len(args) >= 3: vop3_args.append(f'src1={args[2]}')
if len(args) >= 4: vop3_args.append(f'src2={args[3]}')
return f"VOP3A({', '.join(vop3_args + all_kw)})"
# Otherwise try normal VOP3 lookup
vop3_args = ['_vop3=True'] # marker for asm() to force VOP3
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
if len(args) >= 2: vop3_args.append(f'src0={args[1]}')
if len(args) >= 3: vop3_args.append(f'src1={args[2]}')
if len(args) >= 4: vop3_args.append(f'src2={args[3]}')
a_str = ', '.join(vop3_args + all_kw)
return f"{fn_base}({a_str})"
return f"{fn[:-4]}({a_str})"
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"

View File

@@ -305,7 +305,10 @@ class Inst:
if isinstance(val, SrcMod):
mod_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
if val.neg and 'neg' in self._fields: self._or_field('neg', mod_bit)
if val.abs_ and 'abs' in self._fields: self._or_field('abs', mod_bit)
# abs can be in 'abs' field (VOP3A) or 'neg_hi' field (VOP3P uses neg_hi for abs)
if val.abs_:
if 'abs' in self._fields: self._or_field('abs', mod_bit)
elif 'neg_hi' in self._fields and 'abs' not in self._fields: self._or_field('neg_hi', mod_bit) # VOP3P uses neg_hi for abs
if isinstance(val, Reg) and val.hi and has_opsel:
self._or_field('opsel', {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0))
# Track literal value if needed
@@ -371,7 +374,9 @@ class Inst:
# Format-specific setup
if cls_name == 'FLAT' and 'sve' in self._fields:
seg = self._values.get('seg', 0)
if (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR): self._values['sve'] = 1
# Only auto-set sve=1 if not explicitly passed and conditions match (seg=1/scratch, addr is VGPR)
if 'sve' not in orig_args and (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR):
self._values['sve'] = 1
if cls_name == 'VOP3P':
op = orig_args.get('op')
if hasattr(op, 'value'): op = op.value

View File

@@ -69,13 +69,17 @@ def _make_test(f: str, arch: str, test_type: str):
self.assertEqual(decoded.to_bytes()[:len(data)], data)
print(f"{name}: {len(tests)} passed")
elif test_type == "asm":
passed, skipped = 0, 0
passed, failed, skipped = 0, 0, 0
for asm_text, expected in tests:
try:
self.assertEqual(asm(asm_text).to_bytes(), expected)
passed += 1
except: skipped += 1
print(f"{name}: {passed} passed, {skipped} skipped")
result = asm(asm_text, arch=arch)
if result.to_bytes() == expected:
passed += 1
else:
failed += 1
except:
skipped += 1
print(f"{name}: {passed} passed, {failed} failed, {skipped} skipped")
elif test_type == "disasm":
to_test = []
for _, data in tests: