This commit is contained in:
George Hotz
2026-01-05 08:57:41 -08:00
parent bb5103fdb0
commit 893f34a80f

View File

@@ -557,10 +557,8 @@ SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b5
's_atc_probe', 's_atc_probe_buffer'}
# Lowercase special-operand spellings -> the uppercase identifier string substituted for them.
# Note the aliases: 'vcc' and 'exec' resolve to their *_LO names, and 'src_scc' resolves to 'SCC'.
SPEC_DSL = {
  'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO',
  'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
  'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO',
  'scc': 'SCC', 'src_scc': 'SCC',
}
SPEC_DSL_CDNA = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SRC_SCC',
'flat_scratch_lo': 'FLAT_SCRATCH_LO', 'flat_scratch_hi': 'FLAT_SCRATCH_HI', 'flat_scratch': 'FLAT_SCRATCH',
'xnack_mask_lo': 'XNACK_MASK_LO', 'xnack_mask_hi': 'XNACK_MASK_HI', 'xnack_mask': 'XNACK_MASK',
# CDNA flavour of the table: starts from every SPEC_DSL entry, then overrides 'src_scc' (-> 'SRC_SCC')
# and adds the flat_scratch / xnack_mask / vccz / execz / lds_direct spellings.
SPEC_DSL_CDNA = {
  **SPEC_DSL,
  'src_scc': 'SRC_SCC',
  'flat_scratch_lo': 'FLAT_SCRATCH_LO', 'flat_scratch_hi': 'FLAT_SCRATCH_HI', 'flat_scratch': 'FLAT_SCRATCH',
  'xnack_mask_lo': 'XNACK_MASK_LO', 'xnack_mask_hi': 'XNACK_MASK_HI', 'xnack_mask': 'XNACK_MASK',
  # both the 'src_'-prefixed and the bare spelling map to the same SRC_* name
  'src_vccz': 'SRC_VCCZ', 'src_execz': 'SRC_EXECZ', 'vccz': 'SRC_VCCZ', 'execz': 'SRC_EXECZ',
  'src_lds_direct': 'SRC_LDS_DIRECT', 'lds_direct': 'SRC_LDS_DIRECT'}
@@ -615,25 +613,18 @@ def _parse_src_mods(raw: str) -> tuple[str, bool, bool, bool]:
if sext: raw = raw[5:-1]
return raw, neg, abs_, sext
# Inverse of SPECIAL_GPRS_CDNA (its values become the keys here), extended with four spellings given explicit
# numbers. Presumably SPECIAL_GPRS_CDNA maps encoding -> register name, making this name -> encoding -- TODO confirm.
# Built once at import time so _parse_sdwa_src can do a plain dict lookup.
_SGPR_BY_NAME = {
  **{name: enc for enc, name in SPECIAL_GPRS_CDNA.items()},
  'src_vccz': 251, 'src_execz': 252, 'src_scc': 253, 'vcc': 106,
}
def _parse_sdwa_src(raw: str) -> tuple[int, int]:
"""Parse SDWA source operand. Returns (value, s_flag) where s_flag=1 for SGPR/literal."""
# VGPRs: v0, v[0]
if raw.startswith('v') and (raw[1:].isdigit() or raw[1] == '['):
return int(raw[1:].split('[')[0]) if raw[1:].isdigit() else int(raw.split('[')[1].split(']')[0]), 0
# SGPRs: s0, s[0], s[0:1]
if raw.startswith('s') and (raw[1:].isdigit() or raw[1] == '['):
return int(raw[1:].split('[')[0]) if raw[1:].isdigit() else int(raw.split('[')[1].split(':')[0]), 1
# TTMPs: ttmp0, ttmp[0]
if raw.startswith('v') and (raw[1:].isdigit() or raw[1] == '['): return int(raw.split('[')[1].split(']')[0]) if '[' in raw else int(raw[1:]), 0
if raw.startswith('s') and (raw[1:].isdigit() or raw[1] == '['): return int(raw.split('[')[1].split(':')[0]) if '[' in raw else int(raw[1:]), 1
if raw.startswith('ttmp') and raw[4:].isdigit(): return 108 + int(raw[4:]), 1
# Special registers from SPECIAL_GPRS_CDNA (reverse lookup)
_SGPR_REV = {v: k for k, v in SPECIAL_GPRS_CDNA.items()}
_SGPR_REV.update({'src_vccz': 251, 'src_execz': 252, 'src_scc': 253, 'vcc': 106}) # extras not in SPECIAL_GPRS_CDNA
if raw in _SGPR_REV: return _SGPR_REV[raw], 1
if raw in _SGPR_BY_NAME: return _SGPR_BY_NAME[raw], 1
# Inline constants: integers 0-64 -> 128+N, -1 to -16 -> 192+abs(N), floats use FLOAT_ENC
if raw.lstrip('-').replace('.', '', 1).isdigit():
if '.' in raw:
try: return FLOAT_ENC.get(float(raw), 128), 1
except ValueError: return 128, 1
if '.' in raw: return FLOAT_ENC.get(float(raw), 128), 1
ival = int(raw)
if 0 <= ival <= 64: return 128 + ival, 1
if -16 <= ival < 0: return 192 + (-ival), 1
@@ -716,19 +707,16 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
(bits[0] | (bits[1] << 1) | (bits[2] << 3)) if len(bits) == 3 else sum(b << i for i, b in enumerate(bits))
m, text = _extract(text, r'\s+wait_exp:(\d+)'); waitexp = m.group(1) if m else None
m, text = _extract(text, r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)'); off_val = m.group(1) if m else None
m, text = _extract(text, r'\s+dlc(?:\s|$)'); dlc = 1 if m else None
m, text = _extract(text, r'\s+glc(?:\s|$)'); glc = 1 if m else None
m, text = _extract(text, r'\s+slc(?:\s|$)'); slc = 1 if m else None
# GFX942-specific modifiers: sc0, sc1, nt (and their negations)
m, text = _extract(text, r'\s+sc0(?:\s|$)'); sc0 = 1 if m else None
m, text = _extract(text, r'\s+nosc0(?:\s|$)'); sc0 = 0 if m else sc0
m, text = _extract(text, r'\s+sc1(?:\s|$)'); sc1 = 1 if m else None
m, text = _extract(text, r'\s+nosc1(?:\s|$)'); sc1 = 0 if m else sc1
m, text = _extract(text, r'\s+(?<!no)nt(?:\s|$)'); nt = 1 if m else None
m, text = _extract(text, r'\s+nont(?:\s|$)'); nt = 0 if m else nt
m, text = _extract(text, r'\s+tfe(?:\s|$)'); tfe = 1 if m else None
m, text = _extract(text, r'\s+offen(?:\s|$)'); offen = 1 if m else None
m, text = _extract(text, r'\s+idxen(?:\s|$)'); idxen = 1 if m else None
# Flag modifiers: extract presence/absence
flags = {}
for f in ('dlc', 'glc', 'slc', 'tfe', 'offen', 'idxen', 'gds', 'lds'):
m, text = _extract(text, rf'\s+{f}(?:\s|$)'); flags[f] = 1 if m else None
dlc, glc, slc, tfe, offen, idxen, gds, lds = [flags[f] for f in ('dlc', 'glc', 'slc', 'tfe', 'offen', 'idxen', 'gds', 'lds')]
# GFX942: sc0, sc1, nt with negation variants
for f in ('sc0', 'sc1', 'nt'):
m, text = _extract(text, rf'\s+{f}(?:\s|$)'); flags[f] = 1 if m else None
m, text = _extract(text, rf'\s+no{f}(?:\s|$)'); flags[f] = 0 if m else flags[f]
sc0, sc1, nt = flags['sc0'], flags['sc1'], flags['nt']
m, text = _extract(text, r'\s+format:\[([^\]]+)\]'); fmt_val = m.group(1) if m else None
m, text = _extract(text, r'\s+format:(\d+)'); fmt_val = m.group(1) if m and not fmt_val else fmt_val
# dfmt:N, nfmt:N can appear as comma-separated items (CDNA ACC style) or as space-separated modifiers
@@ -740,10 +728,8 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
m, text = _extract(text, r'\s+neg_hi:\[([^\]]+)\]'); neg_hi = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
m, text = _extract(text, r'\s+op_sel_hi:\[([^\]]+)\]')
opsel_hi, opsel_hi_count = (sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))), len(m.group(1).split(','))) if m else (None, 0)
m, text = _extract(text, r'\s+gds(?:\s|$)'); gds = 1 if m else None
m, text = _extract(text, r'\s+offset0:(\d+)'); offset0 = m.group(1) if m else None
m, text = _extract(text, r'\s+offset1:(\d+)'); offset1 = m.group(1) if m else None
m, text = _extract(text, r'\s+lds(?:\s|$)'); lds = 1 if m else None
# MAI instruction modifiers (MFMA cbsz/abid/blgp)
m, text = _extract(text, r'\s+cbsz:(\d+)'); cbsz = int(m.group(1)) if m else None
m, text = _extract(text, r'\s+abid:(\d+)'); abid = int(m.group(1)) if m else None
@@ -758,24 +744,17 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _sel(m.group(1)) if m else None
m, text = _extract(text, r'\s+sext\(src0\)'); sdwa_src0_sext = 1 if m else None
m, text = _extract(text, r'\s+sext\(src1\)'); sdwa_src1_sext = 1 if m else None
# DPP modifiers
# DPP modifiers: quad_perm, row_shl/shr/ror, wave_shl/rol/shr/ror, row_mirror, row_bcast, etc.
m, text = _extract(text, r'\s+quad_perm:\[(\d+),(\d+),(\d+),(\d+)\]')
dpp_ctrl = int(m.group(1)) | (int(m.group(2)) << 2) | (int(m.group(3)) << 4) | (int(m.group(4)) << 6) if m else None
m, text = _extract(text, r'\s+row_shl:(\d+)'); dpp_ctrl = 0x100 | int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+row_shr:(\d+)'); dpp_ctrl = 0x110 | int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+row_ror:(\d+)'); dpp_ctrl = 0x120 | int(m.group(1)) if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_shl:1'); dpp_ctrl = 0x130 if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_rol:1'); dpp_ctrl = 0x134 if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_shr:1'); dpp_ctrl = 0x138 if m else dpp_ctrl
m, text = _extract(text, r'\s+wave_ror:1'); dpp_ctrl = 0x13c if m else dpp_ctrl
m, text = _extract(text, r'\s+row_mirror(?:\s|$)'); dpp_ctrl = 0x140 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_half_mirror(?:\s|$)'); dpp_ctrl = 0x141 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_bcast:15(?:\s|$)'); dpp_ctrl = 0x142 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_bcast:31(?:\s|$)'); dpp_ctrl = 0x143 if m else dpp_ctrl
m, text = _extract(text, r'\s+row_newbcast:(\d+)'); dpp_ctrl = 0x150 + int(m.group(1)) if m else dpp_ctrl
for pat, base in [('row_shl', 0x100), ('row_shr', 0x110), ('row_ror', 0x120), ('row_newbcast', 0x150)]:
m, text = _extract(text, rf'\s+{pat}:(\d+)'); dpp_ctrl = base + int(m.group(1)) if m else dpp_ctrl
for pat, val in [('wave_shl:1', 0x130), ('wave_rol:1', 0x134), ('wave_shr:1', 0x138), ('wave_ror:1', 0x13c),
('row_mirror', 0x140), ('row_half_mirror', 0x141), ('row_bcast:15', 0x142), ('row_bcast:31', 0x143)]:
m, text = _extract(text, rf'\s+{pat}(?:\s|$)'); dpp_ctrl = val if m else dpp_ctrl
m, text = _extract(text, r'\s+row_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_row_mask = int(m.group(1), 0) if m else None; dpp_row_mask_specified = m is not None
m, text = _extract(text, r'\s+bank_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_bank_mask = int(m.group(1), 0) if m else None; dpp_bank_mask_specified = m is not None
m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None # bound_ctrl:0 or bound_ctrl:1 both set bit to 1
m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None
if waitexp: kw.append(f'waitexp={waitexp}')
parts = text.replace(',', ' ').split()