diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index b6070ccce2..64fe21348e 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -568,8 +568,8 @@ def _op2dsl(op: str, arch: str = "rdna3") -> str: if lo in spec_dsl: return wrap(spec_dsl[lo]) if op in FLOATS: return wrap(op) rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'} - if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]") - if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}]") + if m := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}:{m.group(3)}]") + if m := re.match(r'^([asvt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}]") if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op return wrap(op) @@ -634,6 +634,11 @@ _CDNA_ALIASES = { # VOP aliases (inverse of _CDNA_DISASM_ALIASES) 'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32', 'v_mul_legacy_f32': 'v_fmac_f64', 'v_mac_f32': 'v_dot2c_f32_bf16', 'v_madmk_f32': 'v_fmamk_f32', 'v_madak_f32': 'v_fmaak_f32', + # VOPC: v_cmp_t_fXX -> v_cmp_tru_fXX for CDNA + 'v_cmp_t_f16': 'v_cmp_tru_f16', 'v_cmp_t_f32': 'v_cmp_tru_f32', 'v_cmp_t_f64': 'v_cmp_tru_f64', + 'v_cmpx_t_f16': 'v_cmpx_tru_f16', 'v_cmpx_t_f32': 'v_cmpx_tru_f32', 'v_cmpx_t_f64': 'v_cmpx_tru_f64', + # VOP1: flr/rpi -> floor/nearest for CDNA + 'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32', } def _apply_alias(text: str, arch: str = "rdna3") -> str: @@ -655,7 +660,7 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]') if m: bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower() - is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot')) + is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_mad_mix', 'v_fma_mix')) opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \ (bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits)) m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None @@ -670,7 +675,8 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None - m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]'); opsel_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None + m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]') + opsel_hi, opsel_hi_count = (sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))), len(m.group(1).split(','))) if m else (None, 0) m, text = _extract(text, r'\s+gds(?:\s|$)'); gds = 1 if m else None m, text = _extract(text, r'\s+offset0:(\d+)'); offset0 = m.group(1) if m else None m, text = _extract(text, r'\s+offset1:(\d+)'); offset1 = m.group(1) if m else None @@ -721,11 +727,12 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: # SDWA instructions (CDNA) if mn.endswith('_sdwa') and arch == "cdna": base_mn = mn[:-5] # strip _sdwa - # Get VOP1/VOP2 opcode - from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA + # Get VOP1/VOP2/VOPC opcode + from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOPCOp, SDWA vop1_op = getattr(VOP1Op, base_mn.upper(), None) vop2_op = getattr(VOP2Op, base_mn.upper(), None) - if vop1_op is None and vop2_op is None: raise ValueError(f"unknown SDWA instruction: {mn}") + vopc_op = getattr(VOPCOp, base_mn.upper(), None) + if vop1_op is None and vop2_op is None and vopc_op is None: raise ValueError(f"unknown SDWA instruction: {mn}") # Parse operands: vdst, [vcc,] src0[, vsrc1] # For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3] vdst = args[0] # keep as v[N] for VGPRField @@ -793,8 +800,85 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: # Build SDWA kwargs # VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63) # VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode + # VOPC SDWA: vop_op = src1, vop2_op = 0x3e (62), vdst = VOPC opcode, dst_sel/dst_u/clmp/omod = sdst encoding sdwa_kw = [] - if vop1_op is not None: + if vopc_op is not None: + # VOPC SDWA: opcode goes in vdst field, vop2_op=62 + # Parse sdst from first operand (e.g., vcc, s[n:n+1], flat_scratch, ttmp[n:n+1]) + _SDWA_SDST_MAP = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 128+102, 'flat_scratch_lo': 128+102, + 'ttmp0': 128+108, 'ttmp2': 128+110, 'ttmp4': 128+112, 'ttmp6': 128+114, + 'ttmp8': 128+116, 'ttmp10': 128+118, 'ttmp12': 128+120, 'ttmp14': 128+122} + sdst_raw = ops[0].strip().lower() + if sdst_raw in _SDWA_SDST_MAP: sdst_enc = _SDWA_SDST_MAP[sdst_raw] + elif sdst_raw.startswith('s[') and ':' in sdst_raw: sdst_enc = 128 + int(sdst_raw[2:].split(':')[0]) + elif sdst_raw.startswith('s') and sdst_raw[1:].isdigit(): sdst_enc = 128 + int(sdst_raw[1:]) + elif sdst_raw.startswith('ttmp[') and ':' in sdst_raw: sdst_enc = 128 + 108 + int(sdst_raw[5:].split(':')[0]) + else: sdst_enc = 0 # Default: vcc + # For VOPC SDWA, src0 is ops[1], src1 is ops[2] + src0_raw = ops[1].strip().lower() if len(ops) > 1 else 'v0' + src1_raw = ops[2].strip().lower() if len(ops) > 2 else 'v0' + # Parse src0 with modifiers + src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit() + if src0_neg_mod: src0_raw = src0_raw[1:] + src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|') + if src0_abs_mod: src0_raw = src0_raw[1:-1] + src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')') + if src0_sext_mod: src0_raw = src0_raw[5:-1] + # Extract src0 value and type + if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['): + src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0]) + s0 = 0 + elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['): + src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0]) + s0 = 1 + elif src0_raw in _SDWA_SGPR_MAP: + src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1 + elif src0_raw in _SDWA_INLINE_CONST: + src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1 + elif src0_raw.lstrip('-').replace('.', '', 1).isdigit(): + # Integer or float inline constant + if '.' in src0_raw: + src0_val = _SDWA_INLINE_CONST.get(src0_raw, 128) + s0 = 1 + else: + ival = int(src0_raw) + if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1 + elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1 + else: src0_val, s0 = 0, 0 + else: src0_val, s0 = 0, 0 + # Parse src1 with modifiers + src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit() + if src1_neg_mod: src1_raw = src1_raw[1:] + src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|') + if src1_abs_mod: src1_raw = src1_raw[1:-1] + src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')') + if src1_sext_mod: src1_raw = src1_raw[5:-1] + # Extract src1 value and type + if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['): + vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0]) + s1 = 0 + else: vsrc1_val, s1 = 0, 0 + sdwa_kw.append(f'vop_op={vsrc1_val}') + sdwa_kw.append('vop2_op=62') # 0x3e indicates VOPC mode + sdwa_kw.append(f'vdst=RawImm({vopc_op.value})') # VOPC opcode in vdst + sdwa_kw.append(f'src0=RawImm({src0_val})') + # Encode sdst in dst_sel/dst_u/clmp/omod fields + sdwa_kw.append(f'dst_sel={sdst_enc & 7}') + sdwa_kw.append(f'dst_u={(sdst_enc >> 3) & 3}') + sdwa_kw.append(f'clmp={(sdst_enc >> 5) & 1}') + sdwa_kw.append(f'omod={(sdst_enc >> 6) & 3}') + sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}') + sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}') + if src0_sext_mod or sdwa_src0_sext: sdwa_kw.append('src0_sext=1') + if src0_neg_mod: sdwa_kw.append('src0_neg=1') + if src0_abs_mod: sdwa_kw.append('src0_abs=1') + if s0: sdwa_kw.append('s0=1') + if src1_sext_mod or sdwa_src1_sext: sdwa_kw.append('src1_sext=1') + if src1_neg_mod: sdwa_kw.append('src1_neg=1') + if src1_abs_mod: sdwa_kw.append('src1_abs=1') + if s1: sdwa_kw.append('s1=1') + return f"SDWA({', '.join(sdwa_kw)})" + elif vop1_op is not None: sdwa_kw.append(f'vop_op={vop1_op.value}') sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode else: @@ -902,14 +986,45 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: elif dst.startswith('ttmp') and dst[4:].isdigit(): dst_val = 108 + int(dst[4:]) else: sgpr_map = {'vcc_lo': 106, 'vcc_hi': 107, 'm0': 124, 'exec_lo': 126, 'exec_hi': 127, - 'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105} + 'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105, + 'null': 124} # null register for RDNA3 dst_val = sgpr_map.get(dst, int(dst) if dst.isdigit() else 0) return f"v_readfirstlane_b32_e32(vdst=RawImm({dst_val}), src0={args[1]})" if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})" + if mn in ('s_cbranch_join', 's_set_gpr_idx_idx'): return f"{mn}(ssrc0={args[0]}, sdst=RawImm(0))" # No destination, only source + if mn == 's_cbranch_g_fork': return f"{mn}(ssrc0={args[0]}, ssrc1={args[1]}, sdst=RawImm(0))" # Two sources, no dest + if mn == 's_set_gpr_idx_on': return f"{mn}(ssrc0={args[0]}, ssrc1=RawImm({int(args[1], 0)}))" # Mode bits as raw value if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))" if mn == 's_version': return f"{mn}(simm16={args[0]})" if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})" + # SMEM: s_dcache_discard has swapped operand layout (saddr→sbase, soffset→sdata) + if arch == "cdna" and mn.startswith('s_dcache_discard'): + gs = ", glc=1" if glc else "" + # Syntax: s_dcache_discard saddr, soffset [offset:imm] + if off_val and len(ops) >= 2: + # SGPR + immediate offset: soe=1, imm=1, soffset=SGPR, offset=imm + return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={off_val}, soffset={args[1]}, soe=1, imm=1{gs})" + if len(ops) >= 2 and re.match(r'^-?[0-9]|^-?0x', ops[1].strip().lower()): + # Immediate offset only: imm=1 + return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0), imm=1{gs})" + # SGPR offset only: imm=0, offset=SGPR + return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0){gs})" + + # SMEM: s_atomic_*/s_buffer_atomic_* uses offset field for SGPR (imm=0), not soffset + if arch == "cdna" and (mn.startswith('s_buffer_atomic') or (mn.startswith('s_atomic') and not mn.startswith('s_atc'))): + gs = ", glc=1" if glc else "" + if len(ops) >= 3: + # Syntax: s_atomic_* sdata, sbase, soffset [offset:imm] + if off_val: + # SGPR + immediate offset: soe=1, imm=1 + return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}, soe=1, imm=1{gs})" + if re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()): + # Immediate offset only: imm=1 + return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0), imm=1{gs})" + # SGPR offset only: imm=0, offset=SGPR + return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs})" + # SMEM if mn in SMEM_OPS or (arch == "cdna" and mn.startswith(('s_load_dword', 's_buffer_load_dword'))): gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else "" @@ -924,6 +1039,9 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: if len(ops) >= 3: # SGPR offset only: offset=SGPR index, soffset=0 return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs}{ds})" + if len(ops) == 2: + # No offset specified: imm=1, offset=0 + return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset=0, soffset=RawImm(0), imm=1{gs}{ds})" else: # RDNA3 encoding if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()): @@ -1003,12 +1121,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: return f"{mn}(vdst=v[0], addr={addr_val}, saddr={saddr_val}{flat_mods})" # For scratch, 'off' as vaddr means vaddr=0 (no offset), not null register # For load: args=[vdst, addr, saddr], for store: args=[addr, data, saddr] + # For RDNA3 scratch with 'off' as vaddr, set sve=0 (no VGPR address) if 'store' in pre: - addr_val = 'v[0]' if seg == 'scratch' and args[0] == 'OFF' else args[0] - return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})" + addr_off = seg == 'scratch' and args[0] == 'OFF' + addr_val = 'v[0]' if addr_off else args[0] + sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else '' + return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})" else: - addr_val = 'v[0]' if seg == 'scratch' and args[1] == 'OFF' else args[1] - return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})" + addr_off = seg == 'scratch' and args[1] == 'OFF' + addr_val = 'v[0]' if addr_off else args[1] + sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else '' + return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})" for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'): if mn.startswith(pre): seg = pre.split('_')[0] # 'flat', 'global', or 'scratch' @@ -1034,15 +1157,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" - if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" + if 'load' in mn or ('read' in mn and 'read2' not in mn): return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" + if 'read2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" if 'write2' in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" + if 'xchg2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" if 'store' in mn and not _has(mn, 'cmp', 'xchg'): return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})" if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" - if _has(mn, 'cmpstore', 'mskor', 'wrap'): + if _has(mn, 'cmpst', 'mskor', 'wrap'): return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})" return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})" @@ -1063,16 +1188,22 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]] if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '') if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:] - if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args + # For RDNA3 v_cmpx, destination is implicitly exec (126) + if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2 and arch == 'rdna3': args = ['RawImm(126)'] + args # v_cmp_*_e64 and v_cmpx_*_e64 have SGPR destination in vdst field - encode as RawImm + # For CDNA, v_cmpx also writes to SGPR pair (first operand) _SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127} if mn.startswith('v_cmp') and mn.endswith('_e64') and len(args) >= 1: - dst = ops[0].strip().lower() - if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})' - elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})' - elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})' - elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})' - elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})' + # For CDNA v_cmpx with 3 operands (sdst, src0, src1), convert sdst to RawImm + # For RDNA3, v_cmpx only has 2 operands (src0, src1) - already handled above + is_cmpx = 'cmpx' in mn + if not is_cmpx or arch == 'cdna': + dst = ops[0].strip().lower() + if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})' + elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})' + elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})' + elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})' + elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})' fn = mn.replace('.', '_') if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args] @@ -1096,23 +1227,52 @@ def get_dsl(text: str, arch: str = "rdna3") -> str: all_kw = list(kw) if lit_s: all_kw.append(lit_s.lstrip(', ')) if opsel is not None: all_kw.append(f'opsel={opsel}') - if opsel_hi is not None: all_kw.append(f'opsel_hi={opsel_hi & 3}'); all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}') + if opsel_hi is not None: + all_kw.append(f'opsel_hi={opsel_hi & 3}') + if opsel_hi_count >= 3: all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}') # only set opsel_hi2 if 3 elements specified if neg_lo is not None: all_kw.append(f'neg={neg_lo}') if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}') if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1']) # For CDNA _e64 VOP instructions: use keyword args (VOP3 layout) - # Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> v_xxx(vdst=dst, src0=src0[, src1=src1[, src2=src2]]) - # For v_nop_e64 (no operands), add _vop3=True marker to force VOP3 encoding + # Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> VOP3A with promoted opcode + # VOP1 to VOP3 promotion: VOP3 op = 384 + (VOP1_op - 64) for VOP1_op >= 64, else 256 + VOP1_op if fn.endswith('_e64') and fn.startswith('v_') and arch == "cdna": - fn_base = fn[:-4] # strip _e64 + fn_base = fn[:-4].upper() # strip _e64 and uppercase for enum lookup + from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOP3AOp, VOP3BOp + # Check if this is a VOP3B instruction (has sdst for carry-out) + vop3b_op = getattr(VOP3BOp, fn_base, None) + if vop3b_op is not None: + # VOP3B: v_xxx_e64 vdst, sdst, src0, src1[, src2] + vop3_args = [] + if len(args) >= 1: vop3_args.append(f'vdst={args[0]}') + if len(args) >= 2: vop3_args.append(f'sdst={args[1]}') + if len(args) >= 3: vop3_args.append(f'src0={args[2]}') + if len(args) >= 4: vop3_args.append(f'src1={args[3]}') + if len(args) >= 5: vop3_args.append(f'src2={args[4]}') + a_str = ', '.join(vop3_args + all_kw) + return f"{fn[:-4]}({a_str})" + # Check if this is a VOP1 instruction that needs promotion + vop1_op = getattr(VOP1Op, fn_base, None) + vop2_op = getattr(VOP2Op, fn_base, None) + vop3a_op = getattr(VOP3AOp, fn_base, None) + if vop1_op is not None and vop3a_op is None: + # VOP1 -> VOP3 promotion: calculate promoted opcode + promoted_op = 384 + (vop1_op.value - 64) if vop1_op.value >= 64 else 256 + vop1_op.value + vop3_args = [f'op={promoted_op}'] + if len(args) >= 1: vop3_args.append(f'vdst={args[0]}') + if len(args) >= 2: vop3_args.append(f'src0={args[1]}') + if len(args) >= 3: vop3_args.append(f'src1={args[2]}') + if len(args) >= 4: vop3_args.append(f'src2={args[3]}') + return f"VOP3A({', '.join(vop3_args + all_kw)})" + # Otherwise try normal VOP3 lookup vop3_args = ['_vop3=True'] # marker for asm() to force VOP3 if len(args) >= 1: vop3_args.append(f'vdst={args[0]}') if len(args) >= 2: vop3_args.append(f'src0={args[1]}') if len(args) >= 3: vop3_args.append(f'src1={args[2]}') if len(args) >= 4: vop3_args.append(f'src2={args[3]}') a_str = ', '.join(vop3_args + all_kw) - return f"{fn_base}({a_str})" + return f"{fn[:-4]}({a_str})" a_str, kw_str = ', '.join(args), ', '.join(all_kw) return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})" diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 9edc6aa101..190ddce8b9 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -305,7 +305,10 @@ class Inst: if isinstance(val, SrcMod): mod_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0) if val.neg and 'neg' in self._fields: self._or_field('neg', mod_bit) - if val.abs_ and 'abs' in self._fields: self._or_field('abs', mod_bit) + # abs can be in 'abs' field (VOP3A) or 'neg_hi' field (VOP3P uses neg_hi for abs) + if val.abs_: + if 'abs' in self._fields: self._or_field('abs', mod_bit) + elif 'neg_hi' in self._fields and 'abs' not in self._fields: self._or_field('neg_hi', mod_bit) # VOP3P uses neg_hi for abs if isinstance(val, Reg) and val.hi and has_opsel: self._or_field('opsel', {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)) # Track literal value if needed @@ -371,7 +374,9 @@ class Inst: # Format-specific setup if cls_name == 'FLAT' and 'sve' in self._fields: seg = self._values.get('seg', 0) - if (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR): self._values['sve'] = 1 + # Only auto-set sve=1 if not explicitly passed and conditions match (seg=1/scratch, addr is VGPR) + if 'sve' not in orig_args and (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR): + self._values['sve'] = 1 if cls_name == 'VOP3P': op = orig_args.get('op') if hasattr(op, 'value'): op = op.value diff --git a/extra/assembly/amd/test/test_llvm.py b/extra/assembly/amd/test/test_llvm.py index f29adefeec..39bcded112 100644 --- a/extra/assembly/amd/test/test_llvm.py +++ b/extra/assembly/amd/test/test_llvm.py @@ -69,13 +69,17 @@ def _make_test(f: str, arch: str, test_type: str): self.assertEqual(decoded.to_bytes()[:len(data)], data) print(f"{name}: {len(tests)} passed") elif test_type == "asm": - passed, skipped = 0, 0 + passed, failed, skipped = 0, 0, 0 for asm_text, expected in tests: try: - self.assertEqual(asm(asm_text).to_bytes(), expected) - passed += 1 - except: skipped += 1 - print(f"{name}: {passed} passed, {skipped} skipped") + result = asm(asm_text, arch=arch) + if result.to_bytes() == expected: + passed += 1 + else: + failed += 1 + except: + skipped += 1 + print(f"{name}: {passed} passed, {failed} failed, {skipped} skipped") elif test_type == "disasm": to_test = [] for _, data in tests: