mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
simpler
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_GPRS_CDNA, SPECIAL_PAIRS, SPECIAL_PAIRS_CDNA, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp, MTBUFOp)
|
||||
@@ -605,6 +605,40 @@ def _extract(text: str, pat: str, flags=re.I):
|
||||
if m := re.search(pat, text, flags): return m, text[:m.start()] + ' ' + text[m.end():]
|
||||
return None, text
|
||||
|
||||
def _parse_src_mods(raw: str) -> tuple[str, bool, bool, bool]:
|
||||
"""Parse neg/abs/sext modifiers from operand string. Returns (stripped_op, neg, abs_, sext)."""
|
||||
neg = raw.startswith('-') and not raw[1:2].isdigit() and raw[1:3] != '0.'
|
||||
if neg: raw = raw[1:]
|
||||
abs_ = raw.startswith('|') and raw.endswith('|')
|
||||
if abs_: raw = raw[1:-1]
|
||||
sext = raw.startswith('sext(') and raw.endswith(')')
|
||||
if sext: raw = raw[5:-1]
|
||||
return raw, neg, abs_, sext
|
||||
|
||||
def _parse_sdwa_src(raw: str) -> tuple[int, int]:
|
||||
"""Parse SDWA source operand. Returns (value, s_flag) where s_flag=1 for SGPR/literal."""
|
||||
# VGPRs: v0, v[0]
|
||||
if raw.startswith('v') and (raw[1:].isdigit() or raw[1] == '['):
|
||||
return int(raw[1:].split('[')[0]) if raw[1:].isdigit() else int(raw.split('[')[1].split(']')[0]), 0
|
||||
# SGPRs: s0, s[0], s[0:1]
|
||||
if raw.startswith('s') and (raw[1:].isdigit() or raw[1] == '['):
|
||||
return int(raw[1:].split('[')[0]) if raw[1:].isdigit() else int(raw.split('[')[1].split(':')[0]), 1
|
||||
# TTMPs: ttmp0, ttmp[0]
|
||||
if raw.startswith('ttmp') and raw[4:].isdigit(): return 108 + int(raw[4:]), 1
|
||||
# Special registers from SPECIAL_GPRS_CDNA (reverse lookup)
|
||||
_SGPR_REV = {v: k for k, v in SPECIAL_GPRS_CDNA.items()}
|
||||
_SGPR_REV.update({'src_vccz': 251, 'src_execz': 252, 'src_scc': 253, 'vcc': 106}) # extras not in SPECIAL_GPRS_CDNA
|
||||
if raw in _SGPR_REV: return _SGPR_REV[raw], 1
|
||||
# Inline constants: integers 0-64 -> 128+N, -1 to -16 -> 192+abs(N), floats use FLOAT_ENC
|
||||
if raw.lstrip('-').replace('.', '', 1).isdigit():
|
||||
if '.' in raw:
|
||||
try: return FLOAT_ENC.get(float(raw), 128), 1
|
||||
except ValueError: return 128, 1
|
||||
ival = int(raw)
|
||||
if 0 <= ival <= 64: return 128 + ival, 1
|
||||
if -16 <= ival < 0: return 192 + (-ival), 1
|
||||
return 0, 0
|
||||
|
||||
# Instruction aliases: LLVM uses different names for some instructions
|
||||
_ALIASES = {
|
||||
'v_cmp_tru_f16': 'v_cmp_t_f16', 'v_cmp_tru_f32': 'v_cmp_t_f32', 'v_cmp_tru_f64': 'v_cmp_t_f64',
|
||||
@@ -716,13 +750,12 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
m, text = _extract(text, r'\s+blgp:(\d+)'); blgp = int(m.group(1)) if m else None
|
||||
# MFMA neg:[x,y,z] modifier -> sets neg field (same as blgp for MFMA)
|
||||
m, text = _extract(text, r'\s+neg:\[([^\]]+)\]'); mfma_neg = sum(int(x.strip()) << i for i, x in enumerate(m.group(1).split(','))) if m else None
|
||||
# SDWA modifiers
|
||||
_SDWA_SEL = {'BYTE_0': 0, 'BYTE_1': 1, 'BYTE_2': 2, 'BYTE_3': 3, 'WORD_0': 4, 'WORD_1': 5, 'DWORD': 6}
|
||||
_SDWA_DST_UNUSED = {'UNUSED_PAD': 0, 'UNUSED_SEXT': 1, 'UNUSED_PRESERVE': 2}
|
||||
m, text = _extract(text, r'\s+dst_sel:(\w+)'); sdwa_dst_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
|
||||
m, text = _extract(text, r'\s+dst_unused:(\w+)'); sdwa_dst_unused = _SDWA_DST_UNUSED.get(m.group(1), 0) if m else None
|
||||
m, text = _extract(text, r'\s+src0_sel:(\w+)'); sdwa_src0_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
|
||||
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _SDWA_SEL.get(m.group(1), 6) if m else None
|
||||
# SDWA modifiers: sel values are BYTE_0-3=0-3, WORD_0-1=4-5, DWORD=6; dst_unused PAD=0, SEXT=1, PRESERVE=2
|
||||
def _sel(s): return {'BYTE_0': 0, 'BYTE_1': 1, 'BYTE_2': 2, 'BYTE_3': 3, 'WORD_0': 4, 'WORD_1': 5, 'DWORD': 6}.get(s, 6)
|
||||
m, text = _extract(text, r'\s+dst_sel:(\w+)'); sdwa_dst_sel = _sel(m.group(1)) if m else None
|
||||
m, text = _extract(text, r'\s+dst_unused:(\w+)'); sdwa_dst_unused = {'UNUSED_PAD': 0, 'UNUSED_SEXT': 1, 'UNUSED_PRESERVE': 2}.get(m.group(1), 0) if m else None
|
||||
m, text = _extract(text, r'\s+src0_sel:(\w+)'); sdwa_src0_sel = _sel(m.group(1)) if m else None
|
||||
m, text = _extract(text, r'\s+src1_sel:(\w+)'); sdwa_src1_sel = _sel(m.group(1)) if m else None
|
||||
m, text = _extract(text, r'\s+sext\(src0\)'); sdwa_src0_sext = 1 if m else None
|
||||
m, text = _extract(text, r'\s+sext\(src1\)'); sdwa_src1_sext = 1 if m else None
|
||||
# DPP modifiers
|
||||
@@ -765,40 +798,37 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP3POp, VOP1Op
|
||||
# Handle aliases: v_accvgpr_read_b32 -> v_accvgpr_read, v_accvgpr_write_b32 -> v_accvgpr_write
|
||||
fn = mn.replace('_b32', '').upper()
|
||||
# MFMA instruction name mapping: LLVM names to enum names
|
||||
# LLVM: v_mfma_f32_32x32x1f32 -> Enum: V_MFMA_F32_32X32X1_2B_F32
|
||||
# NOTE: gfx90a uses different opcodes than gfx942 for some instructions
|
||||
_MFMA_ALIASES = {
|
||||
'V_MFMA_F32_32X32X1F32': 'V_MFMA_F32_32X32X1_2B_F32', 'V_MFMA_F32_16X16X1F32': 'V_MFMA_F32_16X16X1_4B_F32',
|
||||
'V_MFMA_F32_4X4X1F32': 'V_MFMA_F32_4X4X1_16B_F32', 'V_MFMA_F32_32X32X2F32': 'V_MFMA_F32_32X32X2_F32',
|
||||
'V_MFMA_F32_16X16X4F32': 'V_MFMA_F32_16X16X4_F32',
|
||||
'V_MFMA_F32_32X32X4F16': 'V_MFMA_F32_32X32X4_2B_F16', 'V_MFMA_F32_16X16X4F16': 'V_MFMA_F32_16X16X4_4B_F16',
|
||||
'V_MFMA_F32_4X4X4F16': 'V_MFMA_F32_4X4X4_16B_F16', 'V_MFMA_F32_32X32X8F16': 'V_MFMA_F32_32X32X8_F16',
|
||||
'V_MFMA_F32_16X16X16F16': 'V_MFMA_F32_16X16X16_F16',
|
||||
'V_MFMA_I32_32X32X4I8': 'V_MFMA_I32_32X32X4_2B_I8', 'V_MFMA_I32_16X16X4I8': 'V_MFMA_I32_16X16X4_4B_I8',
|
||||
'V_MFMA_I32_4X4X4I8': 'V_MFMA_I32_4X4X4_16B_I8',
|
||||
'V_MFMA_F64_16X16X4F64': 'V_MFMA_F64_16X16X4_F64', 'V_MFMA_F64_4X4X4F64': 'V_MFMA_F64_4X4X4_4B_F64',
|
||||
# GFX942-specific aliases
|
||||
'V_MFMA_I32_32X32X16I8': 'V_MFMA_I32_32X32X16_I8', 'V_MFMA_I32_16X16X32I8': 'V_MFMA_I32_16X16X32_I8',
|
||||
'V_MFMA_F32_32X32X4BF16': 'V_MFMA_F32_32X32X4_2B_BF16', 'V_MFMA_F32_16X16X4BF16': 'V_MFMA_F32_16X16X4_4B_BF16',
|
||||
'V_MFMA_F32_4X4X4BF16': 'V_MFMA_F32_4X4X4_16B_BF16',
|
||||
'V_MFMA_F32_32X32X8BF16': 'V_MFMA_F32_32X32X8_BF16', 'V_MFMA_F32_16X16X16BF16': 'V_MFMA_F32_16X16X16_BF16',
|
||||
# _1K variants map to same opcodes as non-1K (the _1K is for compatibility)
|
||||
'V_MFMA_F32_32X32X4BF16_1K': 'V_MFMA_F32_32X32X4_2B_BF16', 'V_MFMA_F32_16X16X4BF16_1K': 'V_MFMA_F32_16X16X4_4B_BF16',
|
||||
'V_MFMA_F32_4X4X4BF16_1K': 'V_MFMA_F32_4X4X4_16B_BF16',
|
||||
'V_MFMA_F32_32X32X8BF16_1K': 'V_MFMA_F32_32X32X8_BF16', 'V_MFMA_F32_16X16X16BF16_1K': 'V_MFMA_F32_16X16X16_BF16',
|
||||
# XF32 aliases
|
||||
'V_MFMA_F32_16X16X8XF32': 'V_MFMA_F32_16X16X8_XF32', 'V_MFMA_F32_32X32X4XF32': 'V_MFMA_F32_32X32X4_XF32',
|
||||
# SMFMAC aliases (LLVM name -> enum name)
|
||||
'V_SMFMAC_F32_16X16X32F16': 'V_SMFMAC_F32_16X16X32_F16', 'V_SMFMAC_F32_32X32X16F16': 'V_SMFMAC_F32_32X32X16_F16',
|
||||
'V_SMFMAC_F32_16X16X32BF16': 'V_SMFMAC_F32_16X16X32_BF16', 'V_SMFMAC_F32_32X32X16BF16': 'V_SMFMAC_F32_32X32X16_BF16',
|
||||
'V_SMFMAC_I32_16X16X64I8': 'V_SMFMAC_I32_16X16X64_I8', 'V_SMFMAC_I32_32X32X32I8': 'V_SMFMAC_I32_32X32X32_I8',
|
||||
# FP8/BF8 SMFMAC aliases
|
||||
'V_SMFMAC_F32_16X16X64BF8BF8': 'V_SMFMAC_F32_16X16X64_BF8_BF8', 'V_SMFMAC_F32_16X16X64BF8FP8': 'V_SMFMAC_F32_16X16X64_BF8_FP8',
|
||||
'V_SMFMAC_F32_16X16X64FP8BF8': 'V_SMFMAC_F32_16X16X64_FP8_BF8', 'V_SMFMAC_F32_16X16X64FP8FP8': 'V_SMFMAC_F32_16X16X64_FP8_FP8',
|
||||
'V_SMFMAC_F32_32X32X32BF8BF8': 'V_SMFMAC_F32_32X32X32_BF8_BF8', 'V_SMFMAC_F32_32X32X32BF8FP8': 'V_SMFMAC_F32_32X32X32_BF8_FP8',
|
||||
'V_SMFMAC_F32_32X32X32FP8BF8': 'V_SMFMAC_F32_32X32X32_FP8_BF8', 'V_SMFMAC_F32_32X32X32FP8FP8': 'V_SMFMAC_F32_32X32X32_FP8_FP8',
|
||||
}
|
||||
# MFMA/SMFMAC name mapping: LLVM v_mfma_f32_32x32x1f32 -> enum V_MFMA_F32_32X32X1_2B_F32
|
||||
def _mfma_alias(n):
|
||||
n = n.replace('_1K', '') # Strip _1K suffix first (same opcodes)
|
||||
# FP8/BF8 pairs: insert underscore between AND before: 64BF8BF8 -> 64_BF8_BF8
|
||||
for t in ('BF8BF8', 'BF8FP8', 'FP8BF8', 'FP8FP8'):
|
||||
if t in n: n = n.replace(t, '_' + t[:3] + '_' + t[3:])
|
||||
# Insert underscore before dtype suffix
|
||||
for t in ('F32', 'F16', 'BF16', 'I8', 'F64', 'XF32'):
|
||||
n = re.sub(rf'(X\d+)({t})$', rf'\1_{t}', n)
|
||||
# Block sizes for specific shapes
|
||||
for pat, blk in [('32X32X1_', '32X32X1_2B_'), ('16X16X1_', '16X16X1_4B_'), ('4X4X1_', '4X4X1_16B_'),
|
||||
('32X32X4_F16', '32X32X4_2B_F16'), ('16X16X4_F16', '16X16X4_4B_F16'), ('4X4X4_F16', '4X4X4_16B_F16'),
|
||||
('32X32X4_I8', '32X32X4_2B_I8'), ('16X16X4_I8', '16X16X4_4B_I8'), ('4X4X4_I8', '4X4X4_16B_I8'),
|
||||
('32X32X4_BF16', '32X32X4_2B_BF16'), ('16X16X4_BF16', '16X16X4_4B_BF16'), ('4X4X4_BF16', '4X4X4_16B_BF16'),
|
||||
('4X4X4_F64', '4X4X4_4B_F64')]:
|
||||
n = n.replace(pat, blk)
|
||||
return n
|
||||
_MFMA_ALIASES = {n: _mfma_alias(n) for n in [
|
||||
'V_MFMA_F32_32X32X1F32', 'V_MFMA_F32_16X16X1F32', 'V_MFMA_F32_4X4X1F32', 'V_MFMA_F32_32X32X2F32', 'V_MFMA_F32_16X16X4F32',
|
||||
'V_MFMA_F32_32X32X4F16', 'V_MFMA_F32_16X16X4F16', 'V_MFMA_F32_4X4X4F16', 'V_MFMA_F32_32X32X8F16', 'V_MFMA_F32_16X16X16F16',
|
||||
'V_MFMA_F32_32X32X16F16', 'V_MFMA_F32_16X16X32F16',
|
||||
'V_MFMA_I32_32X32X4I8', 'V_MFMA_I32_16X16X4I8', 'V_MFMA_I32_4X4X4I8', 'V_MFMA_I32_32X32X16I8', 'V_MFMA_I32_16X16X32I8',
|
||||
'V_MFMA_F64_16X16X4F64', 'V_MFMA_F64_4X4X4F64',
|
||||
'V_MFMA_F32_32X32X4BF16', 'V_MFMA_F32_16X16X4BF16', 'V_MFMA_F32_4X4X4BF16', 'V_MFMA_F32_32X32X8BF16', 'V_MFMA_F32_16X16X16BF16',
|
||||
'V_MFMA_F32_32X32X16BF16', 'V_MFMA_F32_16X16X32BF16',
|
||||
'V_MFMA_F32_32X32X4BF16_1K', 'V_MFMA_F32_16X16X4BF16_1K', 'V_MFMA_F32_4X4X4BF16_1K', 'V_MFMA_F32_32X32X8BF16_1K', 'V_MFMA_F32_16X16X16BF16_1K',
|
||||
'V_MFMA_F32_16X16X8XF32', 'V_MFMA_F32_32X32X4XF32',
|
||||
'V_SMFMAC_F32_16X16X32F16', 'V_SMFMAC_F32_32X32X16F16', 'V_SMFMAC_F32_16X16X32BF16', 'V_SMFMAC_F32_32X32X16BF16',
|
||||
'V_SMFMAC_I32_16X16X64I8', 'V_SMFMAC_I32_32X32X32I8',
|
||||
'V_SMFMAC_F32_16X16X64BF8BF8', 'V_SMFMAC_F32_16X16X64BF8FP8', 'V_SMFMAC_F32_16X16X64FP8BF8', 'V_SMFMAC_F32_16X16X64FP8FP8',
|
||||
'V_SMFMAC_F32_32X32X32BF8BF8', 'V_SMFMAC_F32_32X32X32BF8FP8', 'V_SMFMAC_F32_32X32X32FP8BF8', 'V_SMFMAC_F32_32X32X32FP8FP8']}
|
||||
# GFX90a-specific opcodes (different from gfx942 enum): map to raw opcode values
|
||||
# These instructions use opcodes that don't match our gfx942-based enum
|
||||
_MFMA_GFX90A_OPS = {
|
||||
@@ -883,242 +913,82 @@ def get_dsl(text: str, arch: str = "rdna3", gfx942: bool = False) -> str:
|
||||
# SDWA instructions (CDNA)
|
||||
if mn.endswith('_sdwa') and arch == "cdna":
|
||||
base_mn = mn[:-5] # strip _sdwa
|
||||
# Get VOP1/VOP2/VOPC opcode
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, VOPCOp, SDWA
|
||||
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
|
||||
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
|
||||
vopc_op = getattr(VOPCOp, base_mn.upper(), None)
|
||||
vop1_op, vop2_op, vopc_op = getattr(VOP1Op, base_mn.upper(), None), getattr(VOP2Op, base_mn.upper(), None), getattr(VOPCOp, base_mn.upper(), None)
|
||||
if vop1_op is None and vop2_op is None and vopc_op is None: raise ValueError(f"unknown SDWA instruction: {mn}")
|
||||
# Parse operands: vdst, [vcc,] src0[, vsrc1]
|
||||
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
|
||||
vdst = args[0] # keep as v[N] for VGPRField
|
||||
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
|
||||
has_carry = base_mn in carry_out_ops
|
||||
src0_idx = 2 if has_carry else 1
|
||||
src1_idx = 3 if has_carry else 2
|
||||
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
|
||||
src0 = args[1] if len(args) > 1 else 'v[0]'
|
||||
# Parse neg/abs/sext modifiers from src0_raw
|
||||
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit() and src0_raw[1:3] != '0.'
|
||||
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
|
||||
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
|
||||
if src0_neg_mod: src0_raw = src0_raw[1:]
|
||||
if src0_abs_mod: src0_raw = src0_raw[1:-1]
|
||||
if src0_sext_mod: src0_raw = src0_raw[5:-1]
|
||||
# Extract src0 register number for RawImm
|
||||
_SDWA_SGPR_MAP = {'flat_scratch_lo': 102, 'flat_scratch_hi': 103, 'xnack_mask_lo': 104, 'xnack_mask_hi': 105,
|
||||
'vcc_lo': 106, 'vcc_hi': 107, 'vcc': 106, 'ttmp0': 108, 'ttmp1': 109, 'ttmp2': 110, 'ttmp3': 111,
|
||||
'ttmp4': 112, 'ttmp5': 113, 'ttmp6': 114, 'ttmp7': 115, 'ttmp8': 116, 'ttmp9': 117,
|
||||
'ttmp10': 118, 'ttmp11': 119, 'ttmp12': 120, 'ttmp13': 121, 'ttmp14': 122, 'ttmp15': 123,
|
||||
'm0': 124, 'exec_lo': 126, 'exec_hi': 127,
|
||||
'src_vccz': 251, 'src_execz': 252, 'src_scc': 253}
|
||||
# Inline constant encoding for SDWA src0
|
||||
_SDWA_INLINE_CONST = {'0': 128, '0.0': 128, '1': 129, '1.0': 242, '2': 130, '3': 131, '4': 132, '-1': 193, '-2': 194, '-3': 195, '-4': 196,
|
||||
'0.5': 240, '-0.5': 241, '-1.0': 243, '2.0': 244, '-2.0': 245, '4.0': 246, '-4.0': 247}
|
||||
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
|
||||
s0 = 0
|
||||
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
|
||||
s0 = 1
|
||||
elif src0_raw in _SDWA_SGPR_MAP: src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
|
||||
elif src0_raw.startswith('ttmp') and src0_raw[4:].isdigit(): src0_val, s0 = 108 + int(src0_raw[4:]), 1
|
||||
elif src0_raw in _SDWA_INLINE_CONST: src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
|
||||
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
|
||||
# Integer or float inline constant
|
||||
if '.' in src0_raw:
|
||||
src0_val, s0 = _SDWA_INLINE_CONST.get(src0_raw, (0, 0))
|
||||
if src0_val == 0 and src0_raw != '0.0': s0 = 0
|
||||
else: s0 = 1
|
||||
else:
|
||||
ival = int(src0_raw)
|
||||
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
|
||||
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
|
||||
else: src0_val, s0 = 0, 0 # Not an inline constant
|
||||
else: src0_val, s0 = 0, 0
|
||||
# For VOP2, parse vsrc1 and its modifiers
|
||||
vsrc1_val, src1_neg_mod, src1_abs_mod, src1_sext_mod, s1 = 0, False, False, False, 0
|
||||
if vop2_op is not None and len(ops) > src1_idx:
|
||||
src1_raw = ops[src1_idx].strip().lower()
|
||||
# Parse neg/abs/sext modifiers
|
||||
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit() and src1_raw[1:3] != '0.'
|
||||
if src1_neg_mod: src1_raw = src1_raw[1:]
|
||||
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
|
||||
if src1_abs_mod: src1_raw = src1_raw[1:-1]
|
||||
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
|
||||
if src1_sext_mod: src1_raw = src1_raw[5:-1]
|
||||
# Extract vsrc1 register number
|
||||
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
|
||||
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
|
||||
s1 = 0
|
||||
elif src1_raw in _SDWA_SGPR_MAP: vsrc1_val, s1 = _SDWA_SGPR_MAP[src1_raw], 1
|
||||
elif src1_raw in _SDWA_INLINE_CONST: vsrc1_val, s1 = _SDWA_INLINE_CONST[src1_raw], 1
|
||||
# Build SDWA kwargs
|
||||
# VOP1 SDWA: vop_op = VOP1 opcode, vop2_op = 0x3f (63)
|
||||
# VOP2 SDWA: vop_op = vsrc1, vop2_op = VOP2 opcode
|
||||
# VOPC SDWA: vop_op = src1, vop2_op = 0x3e (62), vdst = VOPC opcode, dst_sel/dst_u/clmp/omod = sdst encoding
|
||||
sdwa_kw = []
|
||||
# Operand layout: vdst, [vcc,] src0[, vsrc1] - carry-out ops have vcc at index 1
|
||||
carry_out = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
|
||||
src0_idx, src1_idx = (2, 3) if base_mn in carry_out else (1, 2)
|
||||
# VOPC SDWA: sdst at [0], src0 at [1], src1 at [2]
|
||||
if vopc_op is not None:
|
||||
# VOPC SDWA: opcode goes in vdst field, vop2_op=62
|
||||
# Parse sdst from first operand (e.g., vcc, s[n:n+1], flat_scratch, ttmp[n:n+1])
|
||||
_SDWA_SDST_MAP = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 128+102, 'flat_scratch_lo': 128+102,
|
||||
'ttmp0': 128+108, 'ttmp2': 128+110, 'ttmp4': 128+112, 'ttmp6': 128+114,
|
||||
'ttmp8': 128+116, 'ttmp10': 128+118, 'ttmp12': 128+120, 'ttmp14': 128+122}
|
||||
_SDWA_SDST = {'vcc': 0, 'vcc_lo': 0, 'flat_scratch': 230, 'flat_scratch_lo': 230}
|
||||
sdst_raw = ops[0].strip().lower()
|
||||
if sdst_raw in _SDWA_SDST_MAP: sdst_enc = _SDWA_SDST_MAP[sdst_raw]
|
||||
elif sdst_raw.startswith('s[') and ':' in sdst_raw: sdst_enc = 128 + int(sdst_raw[2:].split(':')[0])
|
||||
elif sdst_raw.startswith('s') and sdst_raw[1:].isdigit(): sdst_enc = 128 + int(sdst_raw[1:])
|
||||
elif sdst_raw.startswith('ttmp[') and ':' in sdst_raw: sdst_enc = 128 + 108 + int(sdst_raw[5:].split(':')[0])
|
||||
else: sdst_enc = 0 # Default: vcc
|
||||
# For VOPC SDWA, src0 is ops[1], src1 is ops[2]
|
||||
src0_raw = ops[1].strip().lower() if len(ops) > 1 else 'v0'
|
||||
src1_raw = ops[2].strip().lower() if len(ops) > 2 else 'v0'
|
||||
# Parse src0 with modifiers
|
||||
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
|
||||
if src0_neg_mod: src0_raw = src0_raw[1:]
|
||||
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
|
||||
if src0_abs_mod: src0_raw = src0_raw[1:-1]
|
||||
src0_sext_mod = src0_raw.startswith('sext(') and src0_raw.endswith(')')
|
||||
if src0_sext_mod: src0_raw = src0_raw[5:-1]
|
||||
# Extract src0 value and type
|
||||
if src0_raw.startswith('v') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0])
|
||||
s0 = 0
|
||||
elif src0_raw.startswith('s') and (src0_raw[1:].isdigit() or src0_raw[1] == '['):
|
||||
src0_val = int(src0_raw[1:].split('[')[0]) if src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(':')[0])
|
||||
s0 = 1
|
||||
elif src0_raw in _SDWA_SGPR_MAP:
|
||||
src0_val, s0 = _SDWA_SGPR_MAP[src0_raw], 1
|
||||
elif src0_raw in _SDWA_INLINE_CONST:
|
||||
src0_val, s0 = _SDWA_INLINE_CONST[src0_raw], 1
|
||||
elif src0_raw.lstrip('-').replace('.', '', 1).isdigit():
|
||||
# Integer or float inline constant
|
||||
if '.' in src0_raw:
|
||||
src0_val = _SDWA_INLINE_CONST.get(src0_raw, 128)
|
||||
s0 = 1
|
||||
else:
|
||||
ival = int(src0_raw)
|
||||
if 0 <= ival <= 64: src0_val, s0 = 128 + ival, 1
|
||||
elif -16 <= ival < 0: src0_val, s0 = 192 + (-ival), 1
|
||||
else: src0_val, s0 = 0, 0
|
||||
else: src0_val, s0 = 0, 0
|
||||
# Parse src1 with modifiers
|
||||
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
|
||||
if src1_neg_mod: src1_raw = src1_raw[1:]
|
||||
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
|
||||
if src1_abs_mod: src1_raw = src1_raw[1:-1]
|
||||
src1_sext_mod = src1_raw.startswith('sext(') and src1_raw.endswith(')')
|
||||
if src1_sext_mod: src1_raw = src1_raw[5:-1]
|
||||
# Extract src1 value and type
|
||||
if src1_raw.startswith('v') and (src1_raw[1:].isdigit() or src1_raw[1] == '['):
|
||||
vsrc1_val = int(src1_raw[1:].split('[')[0]) if src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0])
|
||||
s1 = 0
|
||||
else: vsrc1_val, s1 = 0, 0
|
||||
sdwa_kw.append(f'vop_op={vsrc1_val}')
|
||||
sdwa_kw.append('vop2_op=62') # 0x3e indicates VOPC mode
|
||||
sdwa_kw.append(f'vdst=RawImm({vopc_op.value})') # VOPC opcode in vdst
|
||||
sdwa_kw.append(f'src0=RawImm({src0_val})')
|
||||
# Encode sdst in dst_sel/dst_u/clmp/omod fields
|
||||
sdwa_kw.append(f'dst_sel={sdst_enc & 7}')
|
||||
sdwa_kw.append(f'dst_u={(sdst_enc >> 3) & 3}')
|
||||
sdwa_kw.append(f'clmp={(sdst_enc >> 5) & 1}')
|
||||
sdwa_kw.append(f'omod={(sdst_enc >> 6) & 3}')
|
||||
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
|
||||
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
|
||||
if src0_sext_mod or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
|
||||
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
|
||||
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
|
||||
sdst_enc = _SDWA_SDST.get(sdst_raw, 128 + int(sdst_raw[2:].split(':')[0]) if sdst_raw.startswith('s[') else
|
||||
128 + int(sdst_raw[1:]) if sdst_raw.startswith('s') and sdst_raw[1:].isdigit() else
|
||||
128 + 108 + int(sdst_raw[5:].split(':')[0]) if sdst_raw.startswith('ttmp[') else 0)
|
||||
src0_raw, src0_neg, src0_abs, src0_sext = _parse_src_mods(ops[1].strip().lower() if len(ops) > 1 else 'v0')
|
||||
src1_raw, src1_neg, src1_abs, src1_sext = _parse_src_mods(ops[2].strip().lower() if len(ops) > 2 else 'v0')
|
||||
src0_val, s0 = _parse_sdwa_src(src0_raw)
|
||||
vsrc1_val, s1 = _parse_sdwa_src(src1_raw)
|
||||
sdwa_kw = [f'vop_op={vsrc1_val}', 'vop2_op=62', f'vdst=RawImm({vopc_op.value})', f'src0=RawImm({src0_val})',
|
||||
f'dst_sel={sdst_enc & 7}', f'dst_u={(sdst_enc >> 3) & 3}', f'clmp={(sdst_enc >> 5) & 1}', f'omod={(sdst_enc >> 6) & 3}',
|
||||
f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}', f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}']
|
||||
if src0_sext or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
|
||||
if src0_neg: sdwa_kw.append('src0_neg=1')
|
||||
if src0_abs: sdwa_kw.append('src0_abs=1')
|
||||
if s0: sdwa_kw.append('s0=1')
|
||||
if src1_sext_mod or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
|
||||
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
|
||||
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
|
||||
if src1_sext or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
|
||||
if src1_neg: sdwa_kw.append('src1_neg=1')
|
||||
if src1_abs: sdwa_kw.append('src1_abs=1')
|
||||
if s1: sdwa_kw.append('s1=1')
|
||||
return f"SDWA({', '.join(sdwa_kw)})"
|
||||
elif vop1_op is not None:
|
||||
sdwa_kw.append(f'vop_op={vop1_op.value}')
|
||||
sdwa_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
else:
|
||||
sdwa_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 SDWA
|
||||
sdwa_kw.append(f'vop2_op={vop2_op.value}')
|
||||
sdwa_kw.append(f'vdst={vdst}')
|
||||
sdwa_kw.append(f'src0=RawImm({src0_val})')
|
||||
# Defaults: dst_sel=6 (DWORD), dst_unused=2 (UNUSED_PRESERVE), src0_sel=6 (DWORD), src1_sel=6 (DWORD)
|
||||
sdwa_kw.append(f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}')
|
||||
sdwa_kw.append(f'dst_u={sdwa_dst_unused if sdwa_dst_unused is not None else 2}')
|
||||
sdwa_kw.append(f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}')
|
||||
if sdwa_src0_sext or src0_sext_mod: sdwa_kw.append('src0_sext=1')
|
||||
if src0_neg_mod: sdwa_kw.append('src0_neg=1')
|
||||
if src0_abs_mod: sdwa_kw.append('src0_abs=1')
|
||||
# VOP1/VOP2 SDWA
|
||||
src0_raw, src0_neg, src0_abs, src0_sext = _parse_src_mods(ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0')
|
||||
src0_val, s0 = _parse_sdwa_src(src0_raw)
|
||||
vsrc1_val, src1_neg, src1_abs, src1_sext, s1 = 0, False, False, False, 0
|
||||
if vop2_op is not None and len(ops) > src1_idx:
|
||||
src1_raw, src1_neg, src1_abs, src1_sext = _parse_src_mods(ops[src1_idx].strip().lower())
|
||||
vsrc1_val, s1 = _parse_sdwa_src(src1_raw)
|
||||
sdwa_kw = [f'vop_op={vop1_op.value if vop1_op else vsrc1_val}', f'vop2_op={63 if vop1_op else vop2_op.value}',
|
||||
f'vdst={args[0]}', f'src0=RawImm({src0_val})', f'dst_sel={sdwa_dst_sel if sdwa_dst_sel is not None else 6}',
|
||||
f'dst_u={sdwa_dst_unused if sdwa_dst_unused is not None else 2}', f'src0_sel={sdwa_src0_sel if sdwa_src0_sel is not None else 6}']
|
||||
if src0_sext or sdwa_src0_sext: sdwa_kw.append('src0_sext=1')
|
||||
if src0_neg: sdwa_kw.append('src0_neg=1')
|
||||
if src0_abs: sdwa_kw.append('src0_abs=1')
|
||||
if s0: sdwa_kw.append('s0=1')
|
||||
# VOP2 SDWA src1 modifiers and defaults
|
||||
if vop2_op is not None:
|
||||
sdwa_kw.append(f'src1_sel={sdwa_src1_sel if sdwa_src1_sel is not None else 6}')
|
||||
if sdwa_src1_sext or src1_sext_mod: sdwa_kw.append('src1_sext=1')
|
||||
if src1_neg_mod: sdwa_kw.append('src1_neg=1')
|
||||
if src1_abs_mod: sdwa_kw.append('src1_abs=1')
|
||||
if src1_sext or sdwa_src1_sext: sdwa_kw.append('src1_sext=1')
|
||||
if src1_neg: sdwa_kw.append('src1_neg=1')
|
||||
if src1_abs: sdwa_kw.append('src1_abs=1')
|
||||
if s1: sdwa_kw.append('s1=1')
|
||||
# Add clamp/omod from kw if present
|
||||
for k in kw:
|
||||
if k.startswith('clmp='): sdwa_kw.append(k)
|
||||
elif k.startswith('omod='): sdwa_kw.append(k)
|
||||
if k.startswith('clmp=') or k.startswith('omod='): sdwa_kw.append(k)
|
||||
return f"SDWA({', '.join(sdwa_kw)})"
|
||||
|
||||
# DPP instructions (CDNA)
|
||||
if mn.endswith('_dpp') and arch == "cdna" and dpp_ctrl is not None:
|
||||
base_mn = mn[:-4] # strip _dpp
|
||||
from extra.assembly.amd.autogen.cdna.ins import VOP1Op, VOP2Op, DPP
|
||||
vop1_op = getattr(VOP1Op, base_mn.upper(), None)
|
||||
vop2_op = getattr(VOP2Op, base_mn.upper(), None)
|
||||
vop1_op, vop2_op = getattr(VOP1Op, base_mn.upper(), None), getattr(VOP2Op, base_mn.upper(), None)
|
||||
if vop1_op is None and vop2_op is None: raise ValueError(f"unknown DPP instruction: {mn}")
|
||||
# Parse operands: vdst, [vcc,] src0[, vsrc1]
|
||||
# For carry-out ops (v_add_co_u32, etc.), vcc is at ops[1], src0 is at ops[2], vsrc1 is at ops[3]
|
||||
vdst = args[0]
|
||||
carry_out_ops = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
|
||||
has_carry = base_mn in carry_out_ops
|
||||
src0_idx = 2 if has_carry else 1
|
||||
src1_idx = 3 if has_carry else 2
|
||||
src0_raw = ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0'
|
||||
# Parse neg/abs modifiers for src0 (neg before abs for -|v1| case)
|
||||
src0_neg_mod = src0_raw.startswith('-') and not src0_raw[1:2].isdigit()
|
||||
if src0_neg_mod: src0_raw = src0_raw[1:]
|
||||
src0_abs_mod = src0_raw.startswith('|') and src0_raw.endswith('|')
|
||||
if src0_abs_mod: src0_raw = src0_raw[1:-1]
|
||||
# Extract src0 VGPR number
|
||||
if src0_raw.startswith('v') and src0_raw[1:].isdigit(): src0_val = int(src0_raw[1:])
|
||||
elif 'v[' in src0_raw: src0_val = int(src0_raw.split('[')[1].split(']')[0])
|
||||
else: src0_val = 0
|
||||
# For VOP2, parse vsrc1 and its modifiers
|
||||
vsrc1_val, src1_neg_mod, src1_abs_mod = 0, False, False
|
||||
carry_out = {'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32', 'v_addc_co_u32', 'v_subb_co_u32', 'v_subbrev_co_u32'}
|
||||
src0_idx, src1_idx = (2, 3) if base_mn in carry_out else (1, 2)
|
||||
src0_raw, src0_neg, src0_abs, _ = _parse_src_mods(ops[src0_idx].strip().lower() if len(ops) > src0_idx else 'v0')
|
||||
src0_val = int(src0_raw[1:]) if src0_raw.startswith('v') and src0_raw[1:].isdigit() else int(src0_raw.split('[')[1].split(']')[0]) if 'v[' in src0_raw else 0
|
||||
vsrc1_val, src1_neg, src1_abs = 0, False, False
|
||||
if vop2_op is not None and len(ops) > src1_idx:
|
||||
src1_raw = ops[src1_idx].strip().lower()
|
||||
src1_neg_mod = src1_raw.startswith('-') and not src1_raw[1:2].isdigit()
|
||||
if src1_neg_mod: src1_raw = src1_raw[1:]
|
||||
src1_abs_mod = src1_raw.startswith('|') and src1_raw.endswith('|')
|
||||
if src1_abs_mod: src1_raw = src1_raw[1:-1]
|
||||
if src1_raw.startswith('v') and src1_raw[1:].isdigit(): vsrc1_val = int(src1_raw[1:])
|
||||
elif 'v[' in src1_raw: vsrc1_val = int(src1_raw.split('[')[1].split(']')[0])
|
||||
# Build DPP kwargs
|
||||
# VOP1 DPP: vop_op = VOP1 opcode, vop2_op = 0x3f
|
||||
# VOP2 DPP: vop_op = vsrc1, vop2_op = VOP2 opcode
|
||||
dpp_kw = []
|
||||
if vop1_op is not None:
|
||||
dpp_kw.append(f'vop_op={vop1_op.value}')
|
||||
dpp_kw.append('vop2_op=63') # 0x3f indicates VOP1 mode
|
||||
else:
|
||||
dpp_kw.append(f'vop_op={vsrc1_val}') # vsrc1 goes in vop_op for VOP2 DPP
|
||||
dpp_kw.append(f'vop2_op={vop2_op.value}')
|
||||
dpp_kw.append(f'vdst={vdst}')
|
||||
dpp_kw.append(f'src0=RawImm({src0_val})')
|
||||
dpp_kw.append(f'dpp_ctrl={dpp_ctrl}')
|
||||
src1_raw, src1_neg, src1_abs, _ = _parse_src_mods(ops[src1_idx].strip().lower())
|
||||
vsrc1_val = int(src1_raw[1:]) if src1_raw.startswith('v') and src1_raw[1:].isdigit() else int(src1_raw.split('[')[1].split(']')[0]) if 'v[' in src1_raw else 0
|
||||
dpp_kw = [f'vop_op={vop1_op.value if vop1_op else vsrc1_val}', f'vop2_op={63 if vop1_op else vop2_op.value}',
|
||||
f'vdst={args[0]}', f'src0=RawImm({src0_val})', f'dpp_ctrl={dpp_ctrl}']
|
||||
if dpp_bound_ctrl: dpp_kw.append('bound_ctrl=1')
|
||||
if src0_neg_mod: dpp_kw.append('src0_neg=1')
|
||||
if src0_abs_mod: dpp_kw.append('src0_abs=1')
|
||||
if src1_neg_mod: dpp_kw.append('src1_neg=1')
|
||||
if src1_abs_mod: dpp_kw.append('src1_abs=1')
|
||||
# Default masks: if one is specified but not the other, the other defaults to 0xf
|
||||
if src0_neg: dpp_kw.append('src0_neg=1')
|
||||
if src0_abs: dpp_kw.append('src0_abs=1')
|
||||
if src1_neg: dpp_kw.append('src1_neg=1')
|
||||
if src1_abs: dpp_kw.append('src1_abs=1')
|
||||
if dpp_bank_mask_specified or dpp_row_mask_specified:
|
||||
dpp_kw.append(f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 0xf}')
|
||||
dpp_kw.append(f'row_mask={dpp_row_mask if dpp_row_mask is not None else 0xf}')
|
||||
dpp_kw.extend([f'bank_mask={dpp_bank_mask if dpp_bank_mask is not None else 0xf}', f'row_mask={dpp_row_mask if dpp_row_mask is not None else 0xf}'])
|
||||
return f"DPP({', '.join(dpp_kw)})"
|
||||
|
||||
# VOPD (RDNA3 only)
|
||||
|
||||
Reference in New Issue
Block a user