mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
more asm
This commit is contained in:
@@ -540,6 +540,10 @@ FLOATS = {str(k): k for k in FLOAT_ENC} # Valid float literal strings: '0.5', '
|
||||
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
|
||||
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
|
||||
's_buffer_load_b32', 's_buffer_load_b64', 's_buffer_load_b128', 's_buffer_load_b256', 's_buffer_load_b512',
|
||||
's_scratch_load_dword', 's_scratch_load_dwordx2', 's_scratch_load_dwordx4',
|
||||
's_scratch_store_dword', 's_scratch_store_dwordx2', 's_scratch_store_dwordx4',
|
||||
's_store_dword', 's_store_dwordx2', 's_store_dwordx4',
|
||||
's_buffer_store_dword', 's_buffer_store_dwordx2', 's_buffer_store_dwordx4',
|
||||
's_atc_probe', 's_atc_probe_buffer'}
|
||||
SPEC_DSL = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', 'm0': 'M0',
|
||||
'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', 'src_scc': 'SCC'}
|
||||
@@ -671,6 +675,32 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
m, text = _extract(text, r'\s+offset0:(\d+)'); offset0 = m.group(1) if m else None
|
||||
m, text = _extract(text, r'\s+offset1:(\d+)'); offset1 = m.group(1) if m else None
|
||||
m, text = _extract(text, r'\s+lds(?:\s|$)'); lds = 1 if m else None
|
||||
# SDWA modifiers
|
||||
_SDWA_SEL = {'BYTE_0': 0, 'BYTE_1': 1, 'BYTE_2': 2, 'BYTE_3': 3, 'WORD_0': 4, 'WORD_1': 5, 'DWORD': 6}
|
||||
_SDWA_DST_UNUSED = {'UNUSED_PAD': 0, 'UNUSED_SEXT': 1, 'UNUSED_PRESERVE': 2}
|
||||
m, text = _extract(text, r'\s+dst_sel:(\w+)'); sdwa_dst_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
|
||||
m, text = _extract(text, r'\s+dst_unused:(\w+)'); sdwa_dst_unused = _SDWA_DST_UNUSED.get(m.group(1), 0) if m else None
|
||||
m, text = _extract(text, r'\s+src0_sel:(\w+)'); sdwa_src0_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
|
||||
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
|
||||
m, text = _extract(text, r'\s+sext\(src0\)'); sdwa_src0_sext = 1 if m else None
|
||||
m, text = _extract(text, r'\s+sext\(src1\)'); sdwa_src1_sext = 1 if m else None
|
||||
# DPP modifiers
|
||||
m, text = _extract(text, r'\s+quad_perm:\[(\d+),(\d+),(\d+),(\d+)\]')
|
||||
dpp_ctrl = int(m.group(1)) | (int(m.group(2)) << 2) | (int(m.group(3)) << 4) | (int(m.group(4)) << 6) if m else None
|
||||
m, text = _extract(text, r'\s+row_shl:(\d+)'); dpp_ctrl = 0x100 | int(m.group(1)) if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_shr:(\d+)'); dpp_ctrl = 0x110 | int(m.group(1)) if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_ror:(\d+)'); dpp_ctrl = 0x120 | int(m.group(1)) if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+wave_shl:1'); dpp_ctrl = 0x130 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+wave_rol:1'); dpp_ctrl = 0x134 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+wave_shr:1'); dpp_ctrl = 0x138 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+wave_ror:1'); dpp_ctrl = 0x13c if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_mirror(?:\s|$)'); dpp_ctrl = 0x140 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_half_mirror(?:\s|$)'); dpp_ctrl = 0x141 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_bcast:15(?:\s|$)'); dpp_ctrl = 0x142 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_bcast:31(?:\s|$)'); dpp_ctrl = 0x143 if m else dpp_ctrl
|
||||
m, text = _extract(text, r'\s+row_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_row_mask = int(m.group(1), 0) if m else None; dpp_row_mask_specified = m is not None
|
||||
m, text = _extract(text, r'\s+bank_mask:(0x[0-9a-fA-F]+|\d+)'); dpp_bank_mask = int(m.group(1), 0) if m else None; dpp_bank_mask_specified = m is not None
|
||||
m, text = _extract(text, r'\s+bound_ctrl:([01])'); dpp_bound_ctrl = 1 if m else None # bound_ctrl:0 or bound_ctrl:1 both set bit to 1
|
||||
if waitexp: kw.append(f'waitexp={waitexp}')
|
||||
|
||||
parts = text.replace(',', ' ').split()
|
||||
@@ -688,6 +718,169 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
elif re.match(r'^0x[0-9a-f]+$|^\d+$', p): return f"s_waitcnt(simm16={int(p, 0)})"
|
||||
return f"s_waitcnt(simm16={waitcnt(vm, exp, lgkm)})"
|
||||
|
||||
# SDWA instructions (CDNA)
|
||||
if mn.endswith('_sdwa') and arch == "cdna":
|
||||
base_mn = mn[:-5] # strip _sdwa
|
||||
# Get VOP1/VOP2 opcode
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, SDWA
|
||||
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
|
||||
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
|
||||
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
|
||||
# Parse operands: vdst, [vcc,] src0[, vsrc1]
|
||||
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
|
||||
vdst = args[0] # keep as v[N] for VGPRField
|
||||
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
|
||||
has_carry = base_mn in carry_out_ops
|
||||
src0_idx = 2 if has_carry else 1
|
||||
src1_idx = 3 if has_carry else 2
|
||||
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
|
||||
src0 = args[1] if len(args) > 1 else 'v[0]'
|
||||
# Parse neg/abs/sext modifiers from src0_raw
|
||||
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit() and src0_raw[1:3] != '0.'
|
||||
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
|
||||
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
|
||||
if src0_neg_mod: src0_raw = src0_raw[1:]
|
||||
if src0_abs_mod: src0_raw = src0_raw[1:-1]
|
||||
if src0_sext_mod: src0_raw = src0_raw[5:-1]
|
||||
# Extract src0 register number for RawImm
|
||||
_SDWA_SGPR_MAP = {'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
|
||||
'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'ttmp0': 108, 'ttmp1': 109, 'ttmp2': 110, 'ttmp3': 111,
|
||||
'ttmp4': 112, 'ttmp5': 113, 'ttmp6': 114, 'ttmp7': 115, 'ttmp8': 116, 'ttmp9': 117,
|
||||
'ttmp10': 118, 'ttmp11': 119, 'ttmp12': 120, 'ttmp13': 121, 'ttmp14': 122, 'ttmp15': 123,
|
||||
'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
|
||||
'src_vccz': 251, 'src_execz': 252, 'src_scc': 253}
|
||||
# Inline constant encoding for SDWA src0
|
||||
_SDWA_INLINE_CONST = {'0': 128, '0.0': 128, '1': 129, '1.0': 242, '2': 130, '3': 131, '4': 132, '-1': 193, '-2': 194, '-3': 195, '-4': 196,
|
||||
'0.5': 240, '-0.5': 241, '-1.0': 243, '2.0': 244, '-2.0': 245, '4.0': 246, '-4.0': 247}
|
||||
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
|
||||
s0 = 0
|
||||
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
|
||||
s0 = 1
|
||||
elif src0_raw in _SDWA_SGPR_MAP: src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
|
||||
elif src0_raw.startswith('ttmp') and src0_raw[4:].isdigit(): src0_val, s0 = 108 + int(src0_raw[4:]), 1
|
||||
elif src0_raw in _SDWA_INLINE_CONST: src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
|
||||
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
|
||||
# Integer or float inline constant
|
||||
if '.' in src0_raw:
|
||||
src0_val, s0 = _SDWA_INLINE_CONST.get(src0_raw, (0, 0))
|
||||
if src0_val == 0 and src0_raw != '0.0': s0 = 0
|
||||
else: s0 = 1
|
||||
else:
|
||||
ival = int(src0_raw)
|
||||
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
|
||||
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
|
||||
else: src0_val, s0 = 0, 0 # Not an inline constant
|
||||
else: src0_val, s0 = 0, 0
|
||||
# For VOP2, parse vsrc1 and its modifiers
|
||||
vsrc1_val, src1_neg_mod, src1_abs_mod, src1_sext_mod, s1 = 0, False, False, False, 0
|
||||
if vop2_op is not None and len(ops) > src1_idx:
|
||||
src1_raw = ops[src1_idx].strip().lower()
|
||||
# Parse neg/abs/sext modifiers
|
||||
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit() and src1_raw[1:3] != '0.'
|
||||
if src1_neg_mod: src1_raw = src1_raw[1:]
|
||||
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
|
||||
if src1_abs_mod: src1_raw = src1_raw[1:-1]
|
||||
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
|
||||
if src1_sext_mod: src1_raw = src1_raw[5:-1]
|
||||
# Extract vsrc1 register number
|
||||
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
|
||||
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
|
||||
s1 = 0
|
||||
elif src1_raw in _SDWA_SGPR_MAP: vsrc1_val, s1 = _SDWA_SGPR_MAP[src1_raw], 1
|
||||
elif src1_raw in _SDWA_INLINE_CONST: vsrc1_val, s1 = _SDWA_INLINE_CONST[src1_raw], 1
|
||||
# Build SDWA kwargs
|
||||
# VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63)
|
||||
# VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode
|
||||
sdwa_kw = []
|
||||
if vop1_op is not None:
|
||||
sdwa_kw.append(f'vop_op={vop1_op.value}')
|
||||
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
else:
|
||||
sdwa_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 SDWA
|
||||
sdwa_kw.append(f'vop2_op={vop2_op.value}')
|
||||
sdwa_kw.append(f'vdst={vdst}')
|
||||
sdwa_kw.append(f'src0=RawImm({src0_val})')
|
||||
# Defaults: dst_sel=6 (DWORD), dst_unused=2 (UNUSED_PRESERVE), src0_sel=6 (DWORD), src1_sel=6 (DWORD)
|
||||
sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}')
|
||||
sdwa_kw.append(f'dst_u={sdwa_dst_unused if sdwa_dst_unused is not None else 2}')
|
||||
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
|
||||
if sdwa_src0_sext or src0_sext_mod: sdwa_kw.append('src0_sext=1')
|
||||
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
|
||||
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
|
||||
if s0: sdwa_kw.append('s0=1')
|
||||
# VOP2 SDWA src1 modifiers and defaults
|
||||
if vop2_op is not None:
|
||||
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
|
||||
if sdwa_src1_sext or src1_sext_mod: sdwa_kw.append('src1_sext=1')
|
||||
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
|
||||
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
|
||||
if s1: sdwa_kw.append('s1=1')
|
||||
# Add clamp/omod from kw if present
|
||||
for k in kw:
|
||||
if k.startswith('clmp='): sdwa_kw.append(k)
|
||||
elif k.startswith('omod='): sdwa_kw.append(k)
|
||||
return f"SDWA({', '.join(sdwa_kw)})"
|
||||
|
||||
# DPP instructions (CDNA)
|
||||
if mn.endswith('_dpp') and arch == "cdna" and dpp_ctrl is not None:
|
||||
base_mn = mn[:-4] # strip _dpp
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, DPP
|
||||
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
|
||||
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
|
||||
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown DPP instruction: {mn}")
|
||||
# Parse operands: vdst, [vcc,] src0[, vsrc1]
|
||||
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
|
||||
vdst = args[0]
|
||||
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
|
||||
has_carry = base_mn in carry_out_ops
|
||||
src0_idx = 2 if has_carry else 1
|
||||
src1_idx = 3 if has_carry else 2
|
||||
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
|
||||
# Parse neg/abs modifiers for src0 (neg before abs for -|v1| case)
|
||||
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
|
||||
if src0_neg_mod: src0_raw = src0_raw[1:]
|
||||
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
|
||||
if src0_abs_mod: src0_raw = src0_raw[1:-1]
|
||||
# Extract src0 VGPR number
|
||||
if src0_raw.startswith('v') and src0_raw[1:].isdigit(): src0_val = int(src0_raw[1:])
|
||||
elif 'v[' in src0_raw: src0_val = int(src0_raw.split('[')[1].split(']')[0])
|
||||
else: src0_val = 0
|
||||
# For VOP2, parse vsrc1 and its modifiers
|
||||
vsrc1_val, src1_neg_mod, src1_abs_mod = 0, False, False
|
||||
if vop2_op is not None and len(ops) > src1_idx:
|
||||
src1_raw = ops[src1_idx].strip().lower()
|
||||
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
|
||||
if src1_neg_mod: src1_raw = src1_raw[1:]
|
||||
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
|
||||
if src1_abs_mod: src1_raw = src1_raw[1:-1]
|
||||
if src1_raw.startswith('v') and src1_raw[1:].isdigit(): vsrc1_val = int(src1_raw[1:])
|
||||
elif 'v[' in src1_raw: vsrc1_val = int(src1_raw.split('[')[1].split(']')[0])
|
||||
# Build DPP kwargs
|
||||
# VOP1 DPP: vop_op = VOP1 opcode, vop2_op = 0x3f
|
||||
# VOP2 DPP: vop_op = vsrc1, vop2_op = VOP2 opcode
|
||||
dpp_kw = []
|
||||
if vop1_op is not None:
|
||||
dpp_kw.append(f'vop_op={vop1_op.value}')
|
||||
dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
else:
|
||||
dpp_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 DPP
|
||||
dpp_kw.append(f'vop2_op={vop2_op.value}')
|
||||
dpp_kw.append(f'vdst={vdst}')
|
||||
dpp_kw.append(f'src0=RawImm({src0_val})')
|
||||
dpp_kw.append(f'dpp_ctrl={dpp_ctrl}')
|
||||
if dpp_bound_ctrl: dpp_kw.append('bound_ctrl=1')
|
||||
if src0_neg_mod: dpp_kw.append('src0_neg=1')
|
||||
if src0_abs_mod: dpp_kw.append('src0_abs=1')
|
||||
if src1_neg_mod: dpp_kw.append('src1_neg=1')
|
||||
if src1_abs_mod: dpp_kw.append('src1_abs=1')
|
||||
# Default masks: if one is specified but not the other, the other defaults to 0xf
|
||||
if dpp_bank_mask_specified or dpp_row_mask_specified:
|
||||
dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 0xf}')
|
||||
dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 0xf}')
|
||||
return f"DPP({', '.join(dpp_kw)})"
|
||||
|
||||
# VOPD (RDNA3 only)
|
||||
if '::' in text:
|
||||
xp, yp = text.split('::')
|
||||
@@ -745,7 +938,20 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
fmt_num = None
|
||||
if fmt_val is not None:
|
||||
if fmt_val.isdigit(): fmt_num = int(fmt_val)
|
||||
else: fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
|
||||
else:
|
||||
fmt_num = BUF_FMT.get(fmt_val.replace(' ', '')) or _parse_buf_fmt_combo(fmt_val)
|
||||
# CDNA-style: BUF_DATA_FORMAT_X or BUF_NUM_FORMAT_X (or comma-separated pair)
|
||||
if fmt_num is None and arch == "cdna":
|
||||
_dfmt = {'INVALID': 0, '8': 1, '16': 2, '8_8': 3, '32': 4, '16_16': 5, '10_11_11': 6, '11_11_10': 7,
|
||||
'10_10_10_2': 8, '2_10_10_10': 9, '8_8_8_8': 10, '32_32': 11, '16_16_16_16': 12,
|
||||
'32_32_32': 13, '32_32_32_32': 14, 'RESERVED_15': 15}
|
||||
_nfmt = {'UNORM': 0, 'SNORM': 1, 'USCALED': 2, 'SSCALED': 3, 'UINT': 4, 'SINT': 5, 'RESERVED_6': 6, 'FLOAT': 7}
|
||||
parts = [p.strip() for p in fmt_val.split(',')]
|
||||
dfmt, nfmt = 1, 0 # defaults
|
||||
for p in parts:
|
||||
if p.startswith('BUF_DATA_FORMAT_'): dfmt = _dfmt.get(p[16:], 1)
|
||||
elif p.startswith('BUF_NUM_FORMAT_'): nfmt = _nfmt.get(p[15:], 0)
|
||||
fmt_num = dfmt | (nfmt << 4)
|
||||
# Handle special no-arg buffer ops
|
||||
if mn in ('buffer_gl0_inv', 'buffer_gl1_inv', 'buffer_wbl2', 'buffer_inv'): return f"{mn}()"
|
||||
# Build modifiers string - CDNA uses sc0/nt for glc/slc
|
||||
@@ -755,7 +961,9 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
else:
|
||||
buf_mods = "".join([f", offset={off_val}" if off_val else "", ", glc=1" if glc else "", ", dlc=1" if dlc else "",
|
||||
", slc=1" if slc else "", ", tfe=1" if tfe else "", ", offen=1" if offen else "", ", idxen=1" if idxen else ""])
|
||||
if is_tbuf and fmt_num is not None: buf_mods = f", format={fmt_num}" + buf_mods
|
||||
# Default format for tbuffer is dfmt=1, nfmt=0.
|
||||
# The format field is encoded as (dfmt | (nfmt << 4)), so dfmt=1, nfmt=0 -> format=1
|
||||
if is_tbuf: buf_mods = f", format={fmt_num if fmt_num is not None else 1}" + buf_mods
|
||||
# Handle LDS mode: first operand is 'off' meaning no vdata, it goes to LDS
|
||||
if len(ops) >= 1 and ops[0].strip().lower() == 'off':
|
||||
# LDS mode: buffer_load_format_x off, srsrc, soffset -> no vdata, just vaddr=off
|
||||
@@ -774,23 +982,38 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
return f"{mn}(vdata={args[0]}, vaddr={vaddr_val}, srsrc={srsrc_val}, soffset={soff_val}{buf_mods})"
|
||||
|
||||
# FLAT/GLOBAL/SCRATCH load/store/atomic - saddr needs RawImm for off/null
|
||||
# CDNA: flat uses saddr=0 for off, RDNA: uses saddr=124 (NULL)
|
||||
# CDNA: flat uses saddr=0 for off, global/scratch use saddr=0x7F (127) for off
|
||||
# RDNA: uses saddr=124 (NULL)
|
||||
# CDNA: uses sc0/sc1 for glc/slc
|
||||
saddr_off = 'RawImm(0)' if arch == "cdna" else 'RawImm(124)'
|
||||
def _saddr(a): return saddr_off if a in ('OFF', 'NULL') else a
|
||||
def _saddr_off(seg): return 'RawImm(0)' if arch == 'cdna' and seg == 'flat' else ('RawImm(127)' if arch == 'cdna' else 'RawImm(124)')
|
||||
def _saddr(a, seg='global'): return _saddr_off(seg) if a in ('OFF', 'NULL') else a
|
||||
if arch == "cdna":
|
||||
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', sc0=1' if glc else ''}{', nt=1' if slc else ''}"
|
||||
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', sc0=1' if glc else ''}{', nt=1' if slc else ''}{', lds=1' if lds else ''}"
|
||||
else:
|
||||
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}"
|
||||
flat_mods = f"{f', offset={off_val}' if off_val else ''}{', glc=1' if glc else ''}{', slc=1' if slc else ''}{', dlc=1' if dlc else ''}{', lds=1' if lds else ''}"
|
||||
for pre, flds in [('flat_load','vdst,addr,saddr'), ('global_load','vdst,addr,saddr'), ('scratch_load','vdst,addr,saddr'),
|
||||
('flat_store','addr,data,saddr'), ('global_store','addr,data,saddr'), ('scratch_store','addr,data,saddr')]:
|
||||
if mn.startswith(pre) and len(args) >= 2:
|
||||
f0, f1, f2 = flds.split(',')
|
||||
return f"{mn}({f0}={args[0]}, {f1}={args[1]}{f', {f2}={_saddr(args[2])}' if len(args) >= 3 else f', saddr={saddr_off}'}{flat_mods})"
|
||||
seg = pre.split('_')[0] # 'flat', 'global', or 'scratch'
|
||||
# LDS mode: args=[addr, saddr], vdst=0, data goes to LDS
|
||||
if lds and 'load' in pre:
|
||||
addr_val = 'v[0]' if seg == 'scratch' and args[0] == 'OFF' else args[0]
|
||||
saddr_val = _saddr(args[1], seg) if len(args) >= 2 else _saddr_off(seg)
|
||||
return f"{mn}(vdst=v[0], addr={addr_val}, saddr={saddr_val}{flat_mods})"
|
||||
# For scratch, 'off' as vaddr means vaddr=0 (no offset), not null register
|
||||
# For load: args=[vdst, addr, saddr], for store: args=[addr, data, saddr]
|
||||
if 'store' in pre:
|
||||
addr_val = 'v[0]' if seg == 'scratch' and args[0] == 'OFF' else args[0]
|
||||
return f"{mn}({f0}={addr_val}, {f1}={args[1]}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
|
||||
else:
|
||||
addr_val = 'v[0]' if seg == 'scratch' and args[1] == 'OFF' else args[1]
|
||||
return f"{mn}({f0}={args[0]}, {f1}={addr_val}{f', {f2}={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
|
||||
for pre in ('flat_atomic', 'global_atomic', 'scratch_atomic'):
|
||||
if mn.startswith(pre):
|
||||
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3])}' if len(args) >= 4 else ', saddr=RawImm(124)'}{flat_mods})"
|
||||
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2])}' if len(args) >= 3 else ', saddr=RawImm(124)'}{flat_mods})"
|
||||
seg = pre.split('_')[0] # 'flat', 'global', or 'scratch'
|
||||
if glc and len(args) >= 3: return f"{mn}(vdst={args[0]}, addr={args[1]}, data={args[2]}{f', saddr={_saddr(args[3], seg)}' if len(args) >= 4 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
|
||||
if len(args) >= 2: return f"{mn}(addr={args[0]}, data={args[1]}{f', saddr={_saddr(args[2], seg)}' if len(args) >= 3 else f', saddr={_saddr_off(seg)}'}{flat_mods})"
|
||||
|
||||
# DS instructions
|
||||
if mn.startswith('ds_'):
|
||||
@@ -823,10 +1046,17 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}, data1={args[3]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}, data1={args[2]}{off_kw})"
|
||||
return f"{mn}(vdst={args[0]}, addr={args[1]}, data0={args[2]}{off_kw})" if '_rtn' in mn else f"{mn}(addr={args[0]}, data0={args[1]}{off_kw})"
|
||||
|
||||
# v_fmaak/v_fmamk literal extraction
|
||||
# v_fmaak/v_fmamk literal handling
|
||||
# RDNA3: use literal= keyword arg; CDNA: keep literal in positional args for _e32 variant
|
||||
# v_fmamk_e32(vdst, src0, K, vsrc1); v_fmaak_e32(vdst, src0, vsrc1, K)
|
||||
lit_s = ""
|
||||
if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
|
||||
elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
|
||||
if arch == "cdna":
|
||||
# For CDNA, keep the literal K as a positional arg; the assembly operand order already matches the _e32 signature: fmamk(vdst, src0, K, vsrc1), fmaak(vdst, src0, vsrc1, K)
|
||||
if mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: args = [args[0], args[1], args[2], args[3]] # already correct order
|
||||
elif mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: args = [args[0], args[1], args[2], args[3]] # already correct order
|
||||
else:
|
||||
if mn in ('v_fmaak_f32', 'v_fmaak_f16') and len(args) == 4: lit_s, args = f", literal={args[3].strip()}", args[:3]
|
||||
elif mn in ('v_fmamk_f32', 'v_fmamk_f16') and len(args) == 4: lit_s, args = f", literal={args[2].strip()}", [args[0], args[1], args[3]]
|
||||
|
||||
# VCC ops cleanup
|
||||
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
|
||||
@@ -834,9 +1064,9 @@ def get_dsl(text: str, arch: str = "rdna3") -> str:
|
||||
if mn.replace('_e64', '') in vcc_ops and mn.endswith('_e64'): mn = mn.replace('_e64', '')
|
||||
if mn.startswith('v_cmp') and not mn.endswith('_e64') and len(args) >= 3 and ops[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): args = args[1:]
|
||||
if 'cmpx' in mn and mn.endswith('_e64') and len(args) == 2: args = ['RawImm(126)'] + args
|
||||
# v_cmp_*_e64 has SGPR destination in vdst field - encode as RawImm
|
||||
# v_cmp_*_e64 and v_cmpx_*_e64 have SGPR destination in vdst field - encode as RawImm
|
||||
_SGPR_NAMES = {'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'null': 124, 'm0': 125, 'exec_lo': 126, 'exec_hi': 127}
|
||||
if mn.startswith('v_cmp') and 'cmpx' not in mn and mn.endswith('_e64') and len(args) >= 1:
|
||||
if mn.startswith('v_cmp') and mn.endswith('_e64') and len(args) >= 1:
|
||||
dst = ops[0].strip().lower()
|
||||
if dst.startswith('s') and dst[1:].isdigit(): args[0] = f'RawImm({int(dst[1:])})'
|
||||
elif dst.startswith('s[') and ':' in dst: args[0] = f'RawImm({int(dst[2:].split(":")[0])})'
|
||||
@@ -911,8 +1141,14 @@ def asm(text: str, arch: str = "rdna3") -> Inst:
|
||||
# - has _vop3=True marker (from _e64 instructions without operands)
|
||||
uses_vop3_kwargs = 'vdst=' in dsl or 'src0=' in dsl or '_vop3=True' in dsl
|
||||
if arch == "cdna" and (m := re.match(r'^(v_\w+)(\(.*\))$', dsl)) and not m.group(1).endswith('_e64') and not uses_vop3_kwargs:
|
||||
e32_name = f"{m.group(1)}_e32"
|
||||
if e32_name in ns: return eval(f"{e32_name}{m.group(2)}", ns)
|
||||
fn_name, args_str = m.group(1), m.group(2)
|
||||
e32_name = f"{fn_name}_e32"
|
||||
# VOP2 carry ops: v_add_co_u32(vdst, vcc, src0, vsrc1) -> v_add_co_u32_e32(vdst, src0, vsrc1)
|
||||
# Strip VCC argument (2nd arg) for VOP2 carry operations when using _e32
|
||||
if e32_name in ns and fn_name in _VOP2_CARRY_OUT | _VOP2_CARRY_INOUT:
|
||||
args_match = re.match(r'\(([^,]+),\s*[^,]+,\s*(.+)\)$', args_str)
|
||||
if args_match: args_str = f"({args_match.group(1)}, {args_match.group(2)})"
|
||||
if e32_name in ns: return eval(f"{e32_name}{args_str}", ns)
|
||||
# For CDNA, _e64 suffix maps to base name (VOP3)
|
||||
if arch == "cdna" and (m := re.match(r'^(v_\w+)_e64(\(.*\))$', dsl)):
|
||||
base_name = m.group(1)
|
||||
|
||||
@@ -974,8 +974,8 @@ v_sub_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_F16)
|
||||
v_subrev_f16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_F16)
|
||||
v_mul_f16_e32 = functools.partial(VOP2, VOP2Op.V_MUL_F16)
|
||||
v_mac_f16_e32 = functools.partial(VOP2, VOP2Op.V_MAC_F16)
|
||||
v_madmk_f16_e32 = functools.partial(VOP2, VOP2Op.V_MADMK_F16)
|
||||
v_madak_f16_e32 = functools.partial(VOP2, VOP2Op.V_MADAK_F16)
|
||||
def v_madmk_f16_e32(vdst, src0, K, vsrc1):
  """Construct a VOP2 instruction with opcode V_MADMK_F16.

  The constant K is the third positional parameter here (between src0 and
  vsrc1); it is forwarded to VOP2 through the `literal` keyword while
  vdst, src0 and vsrc1 are passed positionally.
  """
  opcode = VOP2Op.V_MADMK_F16
  return VOP2(opcode, vdst, src0, vsrc1, literal=K)
|
||||
def v_madak_f16_e32(vdst, src0, vsrc1, K):
  """Construct a VOP2 instruction with opcode V_MADAK_F16.

  The constant K is the last positional parameter here (after vsrc1); it is
  forwarded to VOP2 through the `literal` keyword while vdst, src0 and
  vsrc1 are passed positionally.
  """
  opcode = VOP2Op.V_MADAK_F16
  return VOP2(opcode, vdst, src0, vsrc1, literal=K)
|
||||
v_add_u16_e32 = functools.partial(VOP2, VOP2Op.V_ADD_U16)
|
||||
v_sub_u16_e32 = functools.partial(VOP2, VOP2Op.V_SUB_U16)
|
||||
v_subrev_u16_e32 = functools.partial(VOP2, VOP2Op.V_SUBREV_U16)
|
||||
|
||||
Reference in New Issue
Block a user