mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 06:34:03 -05:00
no skip
This commit is contained in:
@@ -568,8 +568,8 @@ def _op2dsl(op: str, arch: str = "rdna3") -> str:
|
||||
if lo in spec_dsl: return wrap(spec_dsl[lo])
|
||||
if op in FLOATS: return wrap(op)
|
||||
rp = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}
|
||||
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}:{m.group(3)}]")
|
||||
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp[m.group(1)]}[{m.group(2)}]")
|
||||
if m := re.match(r'^([asvt](?:tmp)?)\[(\d+):(\d+)\]$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}:{m.group(3)}]")
|
||||
if m := re.match(r'^([asvt](?:tmp)?)(\d+)$', lo): return wrap(f"{rp.get(m.group(1), m.group(1))}[{m.group(2)}]")
|
||||
if re.match(r'^-?\d+$|^-?0x[0-9a-fA-F]+$', op): return f"SrcMod({op}, neg={neg}, abs_={abs_})" if neg or abs_ else op
|
||||
return wrap(op)
|
||||
|
||||
@@ -634,6 +634,11 @@ _CDNA_ALIASES = {
|
||||
# VOP aliases (inverse of _CDNA_DISASM_ALIASES)
|
||||
'v_cvt_pkrtz_f16_f32': 'v_cvt_pk_rtz_f16_f32',
|
||||
'v_mul_legacy_f32': 'v_fmac_f64', 'v_mac_f32': 'v_dot2c_f32_bf16', 'v_madmk_f32': 'v_fmamk_f32', 'v_madak_f32': 'v_fmaak_f32',
|
||||
# VOPC: v_cmp_t_fXX -> v_cmp_tru_fXX for CDNA
|
||||
'v_cmp_t_f16': 'v_cmp_tru_f16', 'v_cmp_t_f32': 'v_cmp_tru_f32', 'v_cmp_t_f64': 'v_cmp_tru_f64',
|
||||
'v_cmpx_t_f16': 'v_cmpx_tru_f16', 'v_cmpx_t_f32': 'v_cmpx_tru_f32', 'v_cmpx_t_f64': 'v_cmpx_tru_f64',
|
||||
# VOP1: flr/rpi -> floor/nearest for CDNA
|
||||
'v_cvt_flr_i32_f32': 'v_cvt_floor_i32_f32', 'v_cvt_rpi_i32_f32': 'v_cvt_nearest_i32_f32',
|
||||
}
|
||||
|
||||
def _apply_alias(text: str, arch: str = "rdna3") -> str:
|
||||
@@ -655,7 +660,7 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
opsel, m, text = None, *_extract(text, r'\s+op_sel:\[([^\]]+)\]')
|
||||
if m:
|
||||
bits, mn = [int(x.strip()) for x in m.group(1).split(',')], text.split()[0].lower()
|
||||
is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot'))
|
||||
is3p = mn.startswith(('v_pk_', 'v_wmma_', 'v_dot', 'v_mad_mix', 'v_fma_mix'))
|
||||
opsel = (bits[0] | (bits[1] << 1) | (bits[2] << 2)) if len(bits) == 3 and is3p else \
|
||||
(bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits))
|
||||
m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None
|
||||
@@ -670,7 +675,8 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
|
||||
m, text = _extract(text, r'\s+neg_lo:\[([^\]]+)\]'); neg_lo = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]'); opsel_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]')
|
||||
opsel_hi, opsel_hi_count = (sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))), len(m.group(1).split(','))) if m else (None, 0)
|
||||
m, text = _extract(text, r'\s+gds(?:\s|$)'); gds = 1 if m else None
|
||||
m, text = _extract(text, r'\s+offset0:(\d+)'); offset0 = m.group(1) if m else None
|
||||
m, text = _extract(text, r'\s+offset1:(\d+)'); offset1 = m.group(1) if m else None
|
||||
@@ -721,11 +727,12 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
# SDWA instructions (CDNA)
|
||||
if mn.endswith('_sdwa') and arch == "cdna":
|
||||
base_mn = mn[:-5] # strip _sdwa
|
||||
# Get VOP1/VOP2 opcode
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA
|
||||
# Get VOP1/VOP2/VOPC opcode
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOPCOp, SDWA
|
||||
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
|
||||
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
|
||||
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
|
||||
vopc_op = getattr(VOPCOp, base_mn.upper(), None)
|
||||
if vop1_op is None and vop2_op is None and vopc_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
|
||||
# Parse operands: vdst, [vcc,] src0[, vsrc1]
|
||||
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
|
||||
vdst = args[0] # keep as v[N] for VGPRField
|
||||
@@ -793,8 +800,85 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
# Build SDWA kwargs
|
||||
# VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63)
|
||||
# VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode
|
||||
# VOPC SDWA: vop_op = src1, vop2_op = 0x3e (62), vdst = VOPC opcode, dst_sel/dst_u/clmp/omod = sdst encoding
|
||||
sdwa_kw = []
|
||||
if vop1_op is not None:
|
||||
if vopc_op is not None:
|
||||
# VOPC SDWA: opcode goes in vdst field, vop2_op=62
|
||||
# Parse sdst from first operand (e.g., vcc, s[n:n+1], flat_scratch, ttmp[n:n+1])
|
||||
_SDWA_SDST_MAP = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 128+102, 'flat_scratch_lo': 128+102,
|
||||
'ttmp0': 128+108, 'ttmp2': 128+110, 'ttmp4': 128+112, 'ttmp6': 128+114,
|
||||
'ttmp8': 128+116, 'ttmp10': 128+118, 'ttmp12': 128+120, 'ttmp14': 128+122}
|
||||
sdst_raw = ops[0].strip().lower()
|
||||
if sdst_raw in _SDWA_SDST_MAP: sdst_enc = _SDWA_SDST_MAP[sdst_raw]
|
||||
elif sdst_raw.startswith('s[') and ':' in sdst_raw: sdst_enc = 128 + int(sdst_raw[2:].split(':')[0])
|
||||
elif sdst_raw.startswith('s') and sdst_raw[1:].isdigit(): sdst_enc = 128 + int(sdst_raw[1:])
|
||||
elif sdst_raw.startswith('ttmp[') and ':' in sdst_raw: sdst_enc = 128 + 108 + int(sdst_raw[5:].split(':')[0])
|
||||
else: sdst_enc = 0 # Default: vcc
|
||||
# For VOPC SDWA, src0 is ops[1], src1 is ops[2]
|
||||
src0_raw = ops[1].strip().lower() if len(ops) > 1 else 'v0'
|
||||
src1_raw = ops[2].strip().lower() if len(ops) > 2 else 'v0'
|
||||
# Parse src0 with modifiers
|
||||
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
|
||||
if src0_neg_mod: src0_raw = src0_raw[1:]
|
||||
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
|
||||
if src0_abs_mod: src0_raw = src0_raw[1:-1]
|
||||
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
|
||||
if src0_sext_mod: src0_raw = src0_raw[5:-1]
|
||||
# Extract src0 value and type
|
||||
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
|
||||
s0 = 0
|
||||
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
|
||||
s0 = 1
|
||||
elif src0_raw in _SDWA_SGPR_MAP:
|
||||
src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
|
||||
elif src0_raw in _SDWA_INLINE_CONST:
|
||||
src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
|
||||
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
|
||||
# Integer or float inline constant
|
||||
if '.' in src0_raw:
|
||||
src0_val = _SDWA_INLINE_CONST.get(src0_raw, 128)
|
||||
s0 = 1
|
||||
else:
|
||||
ival = int(src0_raw)
|
||||
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
|
||||
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
|
||||
else: src0_val, s0 = 0, 0
|
||||
else: src0_val, s0 = 0, 0
|
||||
# Parse src1 with modifiers
|
||||
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
|
||||
if src1_neg_mod: src1_raw = src1_raw[1:]
|
||||
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
|
||||
if src1_abs_mod: src1_raw = src1_raw[1:-1]
|
||||
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
|
||||
if src1_sext_mod: src1_raw = src1_raw[5:-1]
|
||||
# Extract src1 value and type
|
||||
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
|
||||
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
|
||||
s1 = 0
|
||||
else: vsrc1_val, s1 = 0, 0
|
||||
sdwa_kw.append(f'vop_op={vsrc1_val}')
|
||||
sdwa_kw.append('vop2_op=62') # 0x3e indicates VOPC mode
|
||||
sdwa_kw.append(f'vdst=RawImm({vopc_op.value})') # VOPC opcode in vdst
|
||||
sdwa_kw.append(f'src0=RawImm({src0_val})')
|
||||
# Encode sdst in dst_sel/dst_u/clmp/omod fields
|
||||
sdwa_kw.append(f'dst_sel={sdst_enc & 7}')
|
||||
sdwa_kw.append(f'dst_u={(sdst_enc >> 3) & 3}')
|
||||
sdwa_kw.append(f'clmp={(sdst_enc >> 5) & 1}')
|
||||
sdwa_kw.append(f'omod={(sdst_enc >> 6) & 3}')
|
||||
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
|
||||
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
|
||||
if src0_sext_mod or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
|
||||
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
|
||||
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
|
||||
if s0: sdwa_kw.append('s0=1')
|
||||
if src1_sext_mod or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
|
||||
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
|
||||
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
|
||||
if s1: sdwa_kw.append('s1=1')
|
||||
return f"SDWA({', '.join(sdwa_kw)})"
|
||||
elif vop1_op is not None:
|
||||
sdwa_kw.append(f'vop_op={vop1_op.value}')
|
||||
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
else:
|
||||
@@ -902,14 +986,45 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
elif dst.startswith('ttmp') and dst[4:].isdigit(): dst_val = 108 + int(dst[4:])
|
||||
else:
|
||||
sgpr_map = {'vcc_lo': 106, 'vcc_hi': 107, 'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
|
||||
'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105}
|
||||
'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
|
||||
'null': 124} # null register for RDNA3
|
||||
dst_val = sgpr_map.get(dst, int(dst) if dst.isdigit() else 0)
|
||||
return f"v_readfirstlane_b32_e32(vdst=RawImm({dst_val}), src0={args[1]})"
|
||||
if mn in ('s_setpc_b64', 's_rfe_b64'): return f"{mn}(ssrc0={args[0]})"
|
||||
if mn in ('s_cbranch_join', 's_set_gpr_idx_idx'): return f"{mn}(ssrc0={args[0]}, sdst=RawImm(0))" # No destination, only source
|
||||
if mn == 's_cbranch_g_fork': return f"{mn}(ssrc0={args[0]}, ssrc1={args[1]}, sdst=RawImm(0))" # Two sources, no dest
|
||||
if mn == 's_set_gpr_idx_on': return f"{mn}(ssrc0={args[0]}, ssrc1=RawImm({int(args[1], 0)}))" # Mode bits as raw value
|
||||
if mn in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'): return f"{mn}(sdst={args[0]}, ssrc0=RawImm({args[1].strip()}))"
|
||||
if mn == 's_version': return f"{mn}(simm16={args[0]})"
|
||||
if mn == 's_setreg_b32': return f"{mn}(simm16={args[0]}, sdst={args[1]})"
|
||||
|
||||
# SMEM: s_dcache_discard has swapped operand layout (saddr→sbase, soffset→sdata)
|
||||
if arch == "cdna" and mn.startswith('s_dcache_discard'):
|
||||
gs = ", glc=1" if glc else ""
|
||||
# Syntax: s_dcache_discard saddr, soffset [offset:imm]
|
||||
if off_val and len(ops) >= 2:
|
||||
# SGPR + immediate offset: soe=1, imm=1, soffset=SGPR, offset=imm
|
||||
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={off_val}, soffset={args[1]}, soe=1, imm=1{gs})"
|
||||
if len(ops) >= 2 and re.match(r'^-?[0-9]|^-?0x', ops[1].strip().lower()):
|
||||
# Immediate offset only: imm=1
|
||||
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0), imm=1{gs})"
|
||||
# SGPR offset only: imm=0, offset=SGPR
|
||||
return f"{mn}(sbase={args[0]}, sdata=RawImm(0), offset={args[1]}, soffset=RawImm(0){gs})"
|
||||
|
||||
# SMEM: s_atomic_*/s_buffer_atomic_* uses offset field for SGPR (imm=0), not soffset
|
||||
if arch == "cdna" and (mn.startswith('s_buffer_atomic') or (mn.startswith('s_atomic') and not mn.startswith('s_atc'))):
|
||||
gs = ", glc=1" if glc else ""
|
||||
if len(ops) >= 3:
|
||||
# Syntax: s_atomic_* sdata, sbase, soffset [offset:imm]
|
||||
if off_val:
|
||||
# SGPR + immediate offset: soe=1, imm=1
|
||||
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={off_val}, soffset={args[2]}, soe=1, imm=1{gs})"
|
||||
if re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
|
||||
# Immediate offset only: imm=1
|
||||
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0), imm=1{gs})"
|
||||
# SGPR offset only: imm=0, offset=SGPR
|
||||
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs})"
|
||||
|
||||
# SMEM
|
||||
if mn in SMEM_OPS or (arch == "cdna" and mn.startswith(('s_load_dword', 's_buffer_load_dword'))):
|
||||
gs, ds = ", glc=1" if glc else "", ", dlc=1" if dlc else ""
|
||||
@@ -924,6 +1039,9 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
if len(ops) >= 3:
|
||||
# SGPR offset only: offset=SGPR index, soffset=0
|
||||
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset={args[2]}, soffset=RawImm(0){gs}{ds})"
|
||||
if len(ops) == 2:
|
||||
# No offset specified: imm=1, offset=0
|
||||
return f"{mn}(sdata={args[0]}, sbase={args[1]}, offset=0, soffset=RawImm(0), imm=1{gs}{ds})"
|
||||
else:
|
||||
# RDNA3 encoding
|
||||
if len(ops) >= 3 and re.match(r'^-?[0-9]|^-?0x', ops[2].strip().lower()):
|
||||
@@ -1003,12 +1121,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
return f"{mn}(vdst=v[0], addr={addr_val}, saddr={saddr_val}{flat_mods})"
|
||||
# For scratch, 'off' as vaddr means vaddr=0 (no offset), not null register
|
||||
# For load: args=[vdst, addr, saddr], for store: args=[addr, data, saddr]
|
||||
# For RDNA3 scratch with 'off' as vaddr, set sve=0 (no VGPR address)
|
||||
if 'store' in pre:
|
||||
addr_val = 'v[0]' if seg == 'scratch' and args[0] == 'OFF' else args[0]
|
||||
return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
|
||||
addr_off = seg == 'scratch' and args[0] == 'OFF'
|
||||
addr_val = 'v[0]' if addr_off else args[0]
|
||||
sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else ''
|
||||
return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})"
|
||||
else:
|
||||
addr_val = 'v[0]' if seg == 'scratch' and args[1] == 'OFF' else args[1]
|
||||
return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
|
||||
addr_off = seg == 'scratch' and args[1] == 'OFF'
|
||||
addr_val = 'v[0]' if addr_off else args[1]
|
||||
sve_mod = ', sve=0' if addr_off and arch == 'rdna3' else ''
|
||||
return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{sve_mod}{flat_mods})"
|
||||
for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'):
|
||||
if mn.startswith(pre):
|
||||
seg = pre.split('_')[0] # 'flat', 'global', or 'scratch'
|
||||
@@ -1034,15 +1157,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
if 'load' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
|
||||
if 'store' in mn and 'xchg' not in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
|
||||
if 'load' in mn: return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
|
||||
if 'load' in mn or ('read' in mn and 'read2' not in mn): return f"{mn}(vdst={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
|
||||
if 'read2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
|
||||
if 'write2' in mn: return f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
if 'xchg2' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
if 'store' in mn and not _has(mn, 'cmp', 'xchg'):
|
||||
return f"{mn}(data0={args[0]}{off_kw})" if 'addtid' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
|
||||
if 'swizzle' in mn or 'ordered_count' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}{off_kw})"
|
||||
if 'permute' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
|
||||
if 'bvh' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})"
|
||||
if 'condxchg' in mn: return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})"
|
||||
if _has(mn, 'cmpstore', 'mskor', 'wrap'):
|
||||
if _has(mn, 'cmpst', 'mskor', 'wrap'):
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
|
||||
|
||||
@@ -1063,16 +1188,22 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
if mn.replace('_e32', '') in vcc_ops and len(args) >= 5: mn, args = mn.replace('_e32', '') + '_e32', [args[0], args[2], args[3]]
|
||||
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
|
||||
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
|
||||
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
|
||||
# For RDNA3 v_cmpx, destination is implicitly exec (126)
|
||||
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2 and arch == 'rdna3': args = ['RawImm(126)'] + args
|
||||
# v_cmp_*_e64 and v_cmpx_*_e64 have SGPR destination in vdst field - encode as RawImm
|
||||
# For CDNA, v_cmpx also writes to SGPR pair (first operand)
|
||||
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
|
||||
if mn.startswith('v_cmp') and mn.endswith('_e64') and len(args) >= 1:
|
||||
dst = ops[0].strip().lower()
|
||||
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
|
||||
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
|
||||
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
|
||||
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
|
||||
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
|
||||
# For CDNA v_cmpx with 3 operands (sdst, src0, src1), convert sdst to RawImm
|
||||
# For RDNA3, v_cmpx only has 2 operands (src0, src1) - already handled above
|
||||
is_cmpx = 'cmpx' in mn
|
||||
if not is_cmpx or arch == 'cdna':
|
||||
dst = ops[0].strip().lower()
|
||||
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
|
||||
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
|
||||
elif dst.startswith('ttmp') and dst[4:].isdigit(): args[0] = f'RawImm({108 + int(dst[4:])})'
|
||||
elif dst.startswith('ttmp[') and ':' in dst: args[0] = f'RawImm({108 + int(dst[5:].split(":")[0])})'
|
||||
elif dst in _SGPR_NAMES: args[0] = f'RawImm({_SGPR_NAMES[dst]})'
|
||||
|
||||
fn = mn.replace('.', '_')
|
||||
if opsel is not None: args = [re.sub(r'\.[hl]$', '', a) for a in args]
|
||||
@@ -1096,23 +1227,52 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
all_kw = list(kw)
|
||||
if lit_s: all_kw.append(lit_s.lstrip(', '))
|
||||
if opsel is not None: all_kw.append(f'opsel={opsel}')
|
||||
if opsel_hi is not None: all_kw.append(f'opsel_hi={opsel_hi & 3}'); all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}')
|
||||
if opsel_hi is not None:
|
||||
all_kw.append(f'opsel_hi={opsel_hi & 3}')
|
||||
if opsel_hi_count >= 3: all_kw.append(f'opsel_hi2={(opsel_hi >> 2) & 1}') # only set opsel_hi2 if 3 elements specified
|
||||
if neg_lo is not None: all_kw.append(f'neg={neg_lo}')
|
||||
if neg_hi is not None: all_kw.append(f'neg_hi={neg_hi}')
|
||||
if 'bvh' in mn and 'intersect_ray' in mn: all_kw.extend(['dmask=15', 'unrm=1', 'r128=1'])
|
||||
|
||||
# For CDNA _e64 VOP instructions: use keyword args (VOP3 layout)
|
||||
# Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> v_xxx(vdst=dst, src0=src0[, src1=src1[, src2=src2]])
|
||||
# For v_nop_e64 (no operands), add _vop3=True marker to force VOP3 encoding
|
||||
# Pattern: v_xxx_e64 dst, src0[, src1[, src2]] -> VOP3A with promoted opcode
|
||||
# VOP1 to VOP3 promotion: VOP3 op = 384 + (VOP1_op - 64) for VOP1_op >= 64, else 256 + VOP1_op
|
||||
if fn.endswith('_e64') and fn.startswith('v_') and arch == "cdna":
|
||||
fn_base = fn[:-4] # strip _e64
|
||||
fn_base = fn[:-4].upper() # strip _e64 and uppercase for enum lookup
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOP3AOp, VOP3BOp
|
||||
# Check if this is a VOP3B instruction (has sdst for carry-out)
|
||||
vop3b_op = getattr(VOP3BOp, fn_base, None)
|
||||
if vop3b_op is not None:
|
||||
# VOP3B: v_xxx_e64 vdst, sdst, src0, src1[, src2]
|
||||
vop3_args = []
|
||||
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
|
||||
if len(args) >= 2: vop3_args.append(f'sdst={args[1]}')
|
||||
if len(args) >= 3: vop3_args.append(f'src0={args[2]}')
|
||||
if len(args) >= 4: vop3_args.append(f'src1={args[3]}')
|
||||
if len(args) >= 5: vop3_args.append(f'src2={args[4]}')
|
||||
a_str = ', '.join(vop3_args + all_kw)
|
||||
return f"{fn[:-4]}({a_str})"
|
||||
# Check if this is a VOP1 instruction that needs promotion
|
||||
vop1_op = getattr(VOP1Op, fn_base, None)
|
||||
vop2_op = getattr(VOP2Op, fn_base, None)
|
||||
vop3a_op = getattr(VOP3AOp, fn_base, None)
|
||||
if vop1_op is not None and vop3a_op is None:
|
||||
# VOP1 -> VOP3 promotion: calculate promoted opcode
|
||||
promoted_op = 384 + (vop1_op.value - 64) if vop1_op.value >= 64 else 256 + vop1_op.value
|
||||
vop3_args = [f'op={promoted_op}']
|
||||
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
|
||||
if len(args) >= 2: vop3_args.append(f'src0={args[1]}')
|
||||
if len(args) >= 3: vop3_args.append(f'src1={args[2]}')
|
||||
if len(args) >= 4: vop3_args.append(f'src2={args[3]}')
|
||||
return f"VOP3A({', '.join(vop3_args + all_kw)})"
|
||||
# Otherwise try normal VOP3 lookup
|
||||
vop3_args = ['_vop3=True'] # marker for asm() to force VOP3
|
||||
if len(args) >= 1: vop3_args.append(f'vdst={args[0]}')
|
||||
if len(args) >= 2: vop3_args.append(f'src0={args[1]}')
|
||||
if len(args) >= 3: vop3_args.append(f'src1={args[2]}')
|
||||
if len(args) >= 4: vop3_args.append(f'src2={args[3]}')
|
||||
a_str = ', '.join(vop3_args + all_kw)
|
||||
return f"{fn_base}({a_str})"
|
||||
return f"{fn[:-4]}({a_str})"
|
||||
|
||||
a_str, kw_str = ', '.join(args), ', '.join(all_kw)
|
||||
return f"{fn}({a_str}, {kw_str})" if kw_str and a_str else f"{fn}({kw_str})" if kw_str else f"{fn}({a_str})"
|
||||
|
||||
@@ -305,7 +305,10 @@ class Inst:
|
||||
if isinstance(val, SrcMod):
|
||||
mod_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
|
||||
if val.neg and 'neg' in self._fields: self._or_field('neg', mod_bit)
|
||||
if val.abs_ and 'abs' in self._fields: self._or_field('abs', mod_bit)
|
||||
# abs can be in 'abs' field (VOP3A) or 'neg_hi' field (VOP3P uses neg_hi for abs)
|
||||
if val.abs_:
|
||||
if 'abs' in self._fields: self._or_field('abs', mod_bit)
|
||||
elif 'neg_hi' in self._fields and 'abs' not in self._fields: self._or_field('neg_hi', mod_bit) # VOP3P uses neg_hi for abs
|
||||
if isinstance(val, Reg) and val.hi and has_opsel:
|
||||
self._or_field('opsel', {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0))
|
||||
# Track literal value if needed
|
||||
@@ -371,7 +374,9 @@ class Inst:
|
||||
# Format-specific setup
|
||||
if cls_name == 'FLAT' and 'sve' in self._fields:
|
||||
seg = self._values.get('seg', 0)
|
||||
if (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR): self._values['sve'] = 1
|
||||
# Only auto-set sve=1 if not explicitly passed and conditions match (seg=1/scratch, addr is VGPR)
|
||||
if 'sve' not in orig_args and (seg.val if isinstance(seg, RawImm) else seg) == 1 and isinstance(orig_args.get('addr'), VGPR):
|
||||
self._values['sve'] = 1
|
||||
if cls_name == 'VOP3P':
|
||||
op = orig_args.get('op')
|
||||
if hasattr(op, 'value'): op = op.value
|
||||
|
||||
@@ -69,13 +69,17 @@ def _make_test(f: str, arch: str, test_type: str):
|
||||
self.assertEqual(decoded.to_bytes()[:len(data)], data)
|
||||
print(f"{name}: {len(tests)} passed")
|
||||
elif test_type == "asm":
|
||||
passed, skipped = 0, 0
|
||||
passed, failed, skipped = 0, 0, 0
|
||||
for asm_text, expected in tests:
|
||||
try:
|
||||
self.assertEqual(asm(asm_text).to_bytes(), expected)
|
||||
passed += 1
|
||||
except: skipped += 1
|
||||
print(f"{name}: {passed} passed, {skipped} skipped")
|
||||
result = asm(asm_text, arch=arch)
|
||||
if result.to_bytes() == expected:
|
||||
passed += 1
|
||||
else:
|
||||
failed += 1
|
||||
except:
|
||||
skipped += 1
|
||||
print(f"{name}: {passed} passed, {failed} failed, {skipped} skipped")
|
||||
elif test_type == "disasm":
|
||||
to_test = []
|
||||
for _, data in tests:
|
||||
|
||||
Reference in New Issue
Block a user