mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
assembly/amd: have asm go through the dsl (#13886)
* assembly/amd: have asm go through the dsl * lil
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# RDNA3 assembler and disassembler
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, FLOAT_ENC, SRC_FIELDS, unwrap
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, FLOAT_ENC, SRC_FIELDS, unwrap
|
||||
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
|
||||
|
||||
# Decoding helpers
|
||||
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
|
||||
@@ -464,7 +465,8 @@ def disasm(inst: Inst) -> str:
|
||||
return f"{op_name} {', '.join(ops)}" if ops else op_name
|
||||
|
||||
# Assembler
|
||||
SPECIAL_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'scc': RawImm(253)}
|
||||
SPECIAL_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125),
|
||||
'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)}
|
||||
FLOAT_CONSTS = {'0.5': 0.5, '-0.5': -0.5, '1.0': 1.0, '-1.0': -1.0, '2.0': 2.0, '-2.0': -2.0, '4.0': 4.0, '-4.0': -4.0}
|
||||
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
|
||||
|
||||
@@ -504,102 +506,255 @@ SOPK_IMM_ONLY = {'s_version'}
|
||||
SOPK_IMM_FIRST = {'s_setreg_b32'}
|
||||
SOPK_UNSUPPORTED = {'s_setreg_imm32_b32'}
|
||||
|
||||
def asm(text: str) -> Inst:
|
||||
from extra.assembly.amd.autogen import rdna3 as autogen
|
||||
def _operand_to_dsl(op: str) -> str:
|
||||
"""Transform a single operand from LLVM assembly syntax to DSL expression string."""
|
||||
op = op.strip()
|
||||
# Handle negation prefix
|
||||
neg = False
|
||||
if op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')):
|
||||
neg, op = True, op[1:]
|
||||
# Handle abs modifier: |x| or abs(x)
|
||||
abs_ = False
|
||||
if op.startswith('|') and op.endswith('|'):
|
||||
abs_, op = True, op[1:-1]
|
||||
elif op.startswith('abs(') and op.endswith(')'):
|
||||
abs_, op = True, op[4:-1]
|
||||
# Handle .h/.l suffix for 16-bit ops
|
||||
hi_suffix = ""
|
||||
if op.endswith('.h'): hi_suffix, op = ".h", op[:-2]
|
||||
elif op.endswith('.l'): hi_suffix, op = ".l", op[:-2]
|
||||
op_lower = op.lower()
|
||||
|
||||
# Helper to apply modifiers
|
||||
def apply_mods(base: str) -> str:
|
||||
if not neg and not abs_: return f"{base}{hi_suffix}"
|
||||
if abs_: return f"{'-' if neg else ''}abs({base}){hi_suffix}"
|
||||
return f"-{base}{hi_suffix}"
|
||||
|
||||
# Special registers - vcc maps to VCC_LO (64-bit alias)
|
||||
special_map = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF',
|
||||
'm0': 'M0', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC',
|
||||
'src_scc': 'SCC'}
|
||||
if op_lower in special_map: return apply_mods(special_map[op_lower])
|
||||
# Float constants
|
||||
float_map = {'0.5': '0.5', '-0.5': '-0.5', '1.0': '1.0', '-1.0': '-1.0', '2.0': '2.0', '-2.0': '-2.0', '4.0': '4.0', '-4.0': '-4.0'}
|
||||
if op in float_map: return apply_mods(float_map[op])
|
||||
# Register range: v[0:3], s[4:7]
|
||||
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op_lower):
|
||||
prefix = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}[m.group(1)]
|
||||
return apply_mods(f"{prefix}[{m.group(2)}:{m.group(3)}]")
|
||||
# Single register: v0, s1, ttmp5
|
||||
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op_lower):
|
||||
prefix = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}[m.group(1)]
|
||||
return apply_mods(f"{prefix}[{m.group(2)}]")
|
||||
# Integer literals (decimal or hex) - use SrcMod wrapper when modifiers present
|
||||
if re.match(r'^-?\d+$', op) or re.match(r'^-?0x([0-9a-fA-F]+)$', op):
|
||||
if neg or abs_:
|
||||
return f"SrcMod({op}, neg={neg}, abs_={abs_})"
|
||||
return op
|
||||
# hwreg(name, offset, size) -> pass through
|
||||
if op_lower.startswith('hwreg('): return apply_mods(op)
|
||||
# sendmsg(...) -> pass through
|
||||
if op_lower.startswith('sendmsg('): return apply_mods(op)
|
||||
# Fallback: return as-is
|
||||
return apply_mods(op)
|
||||
|
||||
def _parse_operands(op_str: str) -> list[str]:
|
||||
"""Parse comma-separated operands, respecting brackets and pipes."""
|
||||
operands, current, depth, in_pipe = [], "", 0, False
|
||||
for ch in op_str:
|
||||
if ch in '[(': depth += 1
|
||||
elif ch in '])': depth -= 1
|
||||
elif ch == '|': in_pipe = not in_pipe
|
||||
if ch == ',' and depth == 0 and not in_pipe:
|
||||
operands.append(current.strip())
|
||||
current = ""
|
||||
else:
|
||||
current += ch
|
||||
if current.strip(): operands.append(current.strip())
|
||||
return operands
|
||||
|
||||
def _unwrap_dsl(s: str) -> str:
|
||||
"""Unwrap a DSL expression to get the raw value for literals."""
|
||||
if re.match(r'^-?\d+$', s): return s
|
||||
if re.match(r'^-?0x[0-9a-fA-F]+$', s): return s
|
||||
return s
|
||||
|
||||
def get_dsl(text: str) -> str:
|
||||
"""Transform LLVM-style assembly instruction to Python DSL expression string."""
|
||||
text = text.strip()
|
||||
clamp = 'clamp' in text.lower()
|
||||
if clamp: text = re.sub(r'\s+clamp\s*$', '', text, flags=re.I)
|
||||
modifiers = {}
|
||||
if m := re.search(r'\s+wait_exp:(\d+)', text, re.I): modifiers['waitexp'] = int(m.group(1)); text = text[:m.start()] + text[m.end():]
|
||||
# Extract and remove trailing modifiers (must happen before operand parsing)
|
||||
kwargs = []
|
||||
# Extract mul:N and div:N modifiers (omod)
|
||||
omod_val = 0
|
||||
if m := re.search(r'\s+mul:2(?:\s|$)', text, re.I):
|
||||
omod_val = 1; text = text[:m.start()] + text[m.end():]
|
||||
elif m := re.search(r'\s+mul:4(?:\s|$)', text, re.I):
|
||||
omod_val = 2; text = text[:m.start()] + text[m.end():]
|
||||
elif m := re.search(r'\s+div:2(?:\s|$)', text, re.I):
|
||||
omod_val = 3; text = text[:m.start()] + text[m.end():]
|
||||
if omod_val: kwargs.append(f'omod={omod_val}')
|
||||
# Extract clamp modifier
|
||||
if m := re.search(r'\s+clamp(?:\s|$)', text, re.I):
|
||||
kwargs.append('clmp=1')
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
# Extract op_sel:[...] modifier - interpretation depends on format:
|
||||
# VOP3: [src0, src1, dst] or [src0, src1, src2, dst] -> bits 0, 1, (2), 3
|
||||
# VOP3P/WMMA: [src0, src1, src2] -> bits 0, 1, 2 (no dst bit, 3-source ops)
|
||||
opsel_explicit = None
|
||||
if m := re.search(r'\s+op_sel:\[([^\]]+)\]', text, re.I):
|
||||
bits = [int(x.strip()) for x in m.group(1).split(',')]
|
||||
# Check if this is a VOP3P instruction (v_pk_*, v_wmma_*, v_dot*)
|
||||
mnemonic = text.split()[0].lower()
|
||||
is_vop3p = mnemonic.startswith(('v_pk_', 'v_wmma_', 'v_dot'))
|
||||
if len(bits) == 3:
|
||||
if is_vop3p:
|
||||
# VOP3P: [src0, src1, src2] -> bits 0, 1, 2
|
||||
opsel_explicit = bits[0] | (bits[1] << 1) | (bits[2] << 2)
|
||||
else:
|
||||
# VOP3: [src0, src1, dst] -> bits 0, 1, 3
|
||||
opsel_explicit = bits[0] | (bits[1] << 1) | (bits[2] << 3)
|
||||
else:
|
||||
opsel_explicit = sum(b << i for i, b in enumerate(bits))
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
if m := re.search(r'\s+wait_exp:(\d+)', text, re.I):
|
||||
kwargs.append(f'waitexp={m.group(1)}')
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
# Extract offset:N for FLAT/GLOBAL/SCRATCH/SMEM (can be hex or decimal)
|
||||
offset_val = None
|
||||
if m := re.search(r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)', text, re.I):
|
||||
offset_val = m.group(1)
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
# Extract dlc modifier (before glc to avoid partial match issues)
|
||||
dlc_val = None
|
||||
if m := re.search(r'\s+dlc(?:\s|$)', text, re.I):
|
||||
dlc_val = 1
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
# Extract glc modifier
|
||||
glc_val = None
|
||||
if m := re.search(r'\s+glc(?:\s|$)', text, re.I):
|
||||
glc_val = 1
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
# Extract neg_lo:[...] and neg_hi:[...] for VOP3P
|
||||
neg_lo_val = None
|
||||
if m := re.search(r'\s+neg_lo:\[([^\]]+)\]', text, re.I):
|
||||
bits = [int(x.strip()) for x in m.group(1).split(',')]
|
||||
neg_lo_val = sum(b << i for i, b in enumerate(bits))
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
neg_hi_val = None
|
||||
if m := re.search(r'\s+neg_hi:\[([^\]]+)\]', text, re.I):
|
||||
bits = [int(x.strip()) for x in m.group(1).split(',')]
|
||||
neg_hi_val = sum(b << i for i, b in enumerate(bits))
|
||||
text = text[:m.start()] + text[m.end():]
|
||||
parts = text.replace(',', ' ').split()
|
||||
if not parts: raise ValueError("empty instruction")
|
||||
mnemonic, op_str = parts[0].lower(), text[len(parts[0]):].strip()
|
||||
# Handle s_waitcnt specially before operand parsing
|
||||
# Handle s_waitcnt specially
|
||||
if mnemonic == 's_waitcnt':
|
||||
vmcnt, expcnt, lgkmcnt = 0x3f, 0x7, 0x3f
|
||||
for part in op_str.replace(',', ' ').split():
|
||||
if m := re.match(r'vmcnt\((\d+)\)', part): vmcnt = int(m.group(1))
|
||||
elif m := re.match(r'expcnt\((\d+)\)', part): expcnt = int(m.group(1))
|
||||
elif m := re.match(r'lgkmcnt\((\d+)\)', part): lgkmcnt = int(m.group(1))
|
||||
elif re.match(r'^0x[0-9a-f]+$|^\d+$', part): return autogen.s_waitcnt(simm16=int(part, 0))
|
||||
return autogen.s_waitcnt(simm16=waitcnt(vmcnt, expcnt, lgkmcnt))
|
||||
# Handle VOPD dual-issue instructions: opx dst, src :: opy dst, src
|
||||
elif re.match(r'^0x[0-9a-f]+$|^\d+$', part): return f"s_waitcnt(simm16={int(part, 0)})"
|
||||
wc = waitcnt(vmcnt, expcnt, lgkmcnt)
|
||||
return f"s_waitcnt(simm16={wc})"
|
||||
# Handle VOPD dual-issue: opx dst, src :: opy dst, src
|
||||
if '::' in text:
|
||||
x_part, y_part = text.split('::')
|
||||
x_parts, y_parts = x_part.strip().replace(',', ' ').split(), y_part.strip().replace(',', ' ').split()
|
||||
opx_name, opy_name = x_parts[0].upper(), y_parts[0].upper()
|
||||
opx, opy = autogen.VOPDOp[opx_name], autogen.VOPDOp[opy_name]
|
||||
x_ops, y_ops = [parse_operand(p)[0] for p in x_parts[1:]], [parse_operand(p)[0] for p in y_parts[1:]]
|
||||
vdstx, srcx0 = x_ops[0], x_ops[1] if len(x_ops) > 1 else 0
|
||||
vsrcx1 = x_ops[2] if len(x_ops) > 2 else VGPR(0)
|
||||
vdsty, srcy0 = y_ops[0], y_ops[1] if len(y_ops) > 1 else 0
|
||||
vsrcy1 = y_ops[2] if len(y_ops) > 2 else VGPR(0)
|
||||
# Handle fmaak/fmamk literals (4th operand on x or y side)
|
||||
x_ops = [_operand_to_dsl(p) for p in x_parts[1:]]
|
||||
y_ops = [_operand_to_dsl(p) for p in y_parts[1:]]
|
||||
vdstx, srcx0 = x_ops[0], x_ops[1] if len(x_ops) > 1 else '0'
|
||||
vsrcx1 = x_ops[2] if len(x_ops) > 2 else 'v[0]'
|
||||
vdsty, srcy0 = y_ops[0], y_ops[1] if len(y_ops) > 1 else '0'
|
||||
vsrcy1 = y_ops[2] if len(y_ops) > 2 else 'v[0]'
|
||||
lit = None
|
||||
if 'fmaak' in opx_name.lower() and len(x_ops) > 3: lit = unwrap(x_ops[3])
|
||||
elif 'fmamk' in opx_name.lower() and len(x_ops) > 3: lit, vsrcx1 = unwrap(x_ops[2]), x_ops[3]
|
||||
elif 'fmaak' in opy_name.lower() and len(y_ops) > 3: lit = unwrap(y_ops[3])
|
||||
elif 'fmamk' in opy_name.lower() and len(y_ops) > 3: lit, vsrcy1 = unwrap(y_ops[2]), y_ops[3]
|
||||
return autogen.VOPD(opx, opy, vdstx=vdstx, vdsty=vdsty, srcx0=srcx0, vsrcx1=vsrcx1, srcy0=srcy0, vsrcy1=vsrcy1, literal=lit)
|
||||
operands, current, depth, in_pipe = [], "", 0, False
|
||||
for ch in op_str:
|
||||
if ch in '[(': depth += 1
|
||||
elif ch in '])': depth -= 1
|
||||
elif ch == '|': in_pipe = not in_pipe
|
||||
if ch == ',' and depth == 0 and not in_pipe: operands.append(current.strip()); current = ""
|
||||
else: current += ch
|
||||
if current.strip(): operands.append(current.strip())
|
||||
parsed = [parse_operand(op) for op in operands]
|
||||
values = [p[0] for p in parsed]
|
||||
neg_bits = sum((1 << (i-1)) for i, p in enumerate(parsed) if i > 0 and p[1])
|
||||
abs_bits = sum((1 << (i-1)) for i, p in enumerate(parsed) if i > 0 and p[2])
|
||||
opsel_bits = (8 if len(parsed) > 0 and parsed[0][3] else 0) | sum((1 << i) for i, p in enumerate(parsed[1:4]) if p[3])
|
||||
lit = None
|
||||
if mnemonic in ('v_fmaak_f32', 'v_fmaak_f16') and len(values) == 4: lit, values = unwrap(values[3]), values[:3]
|
||||
elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(values) == 4: lit, values = unwrap(values[2]), [values[0], values[1], values[3]]
|
||||
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32', 'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'}
|
||||
if mnemonic.replace('_e32', '') in vcc_ops and len(values) >= 5: values = [values[0], values[2], values[3]]
|
||||
# v_cmp_*_e32: strip implicit vcc_lo dest. v_cmp_*_e64: keep vdst (vcc_lo encodes to 106)
|
||||
if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
|
||||
values = values[1:]
|
||||
# CMPX instructions with _e64 suffix: prepend implicit EXEC_LO destination (vdst=126)
|
||||
if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(values) == 2:
|
||||
values = [VGPR(126, 1)] + values
|
||||
# Recalculate modifiers: parsed[0]=src0, parsed[1]=src1 (no vdst in user input)
|
||||
neg_bits = sum((1 << i) for i, p in enumerate(parsed[:3]) if p[1])
|
||||
abs_bits = sum((1 << i) for i, p in enumerate(parsed[:3]) if p[2])
|
||||
opsel_bits = sum((1 << i) for i, p in enumerate(parsed[:2]) if p[3])
|
||||
vop3sd_ops = {'v_div_scale_f32', 'v_div_scale_f64'}
|
||||
if mnemonic in vop3sd_ops and len(parsed) >= 5:
|
||||
neg_bits = sum((1 << i) for i, p in enumerate(parsed[2:5]) if p[1])
|
||||
abs_bits = sum((1 << i) for i, p in enumerate(parsed[2:5]) if p[2])
|
||||
if 'fmaak' in opx_name.lower() and len(x_ops) > 3: lit = x_ops[3]
|
||||
elif 'fmamk' in opx_name.lower() and len(x_ops) > 3: lit, vsrcx1 = x_ops[2], x_ops[3]
|
||||
elif 'fmaak' in opy_name.lower() and len(y_ops) > 3: lit = y_ops[3]
|
||||
elif 'fmamk' in opy_name.lower() and len(y_ops) > 3: lit, vsrcy1 = y_ops[2], y_ops[3]
|
||||
lit_str = f", literal={lit}" if lit else ""
|
||||
return f"VOPD(VOPDOp.{opx_name}, VOPDOp.{opy_name}, vdstx={vdstx}, vdsty={vdsty}, srcx0={srcx0}, vsrcx1={vsrcx1}, srcy0={srcy0}, vsrcy1={vsrcy1}{lit_str})"
|
||||
operands = _parse_operands(op_str)
|
||||
dsl_args = [_operand_to_dsl(op) for op in operands]
|
||||
# Handle special instructions
|
||||
if mnemonic in SOPK_UNSUPPORTED: raise ValueError(f"unsupported instruction: {mnemonic}")
|
||||
elif mnemonic in SOP1_SRC_ONLY:
|
||||
return getattr(autogen, mnemonic)(ssrc0=values[0])
|
||||
elif mnemonic in SOP1_MSG_IMM:
|
||||
return getattr(autogen, mnemonic)(sdst=values[0], ssrc0=RawImm(unwrap(values[1])))
|
||||
elif mnemonic in SOPK_IMM_ONLY:
|
||||
return getattr(autogen, mnemonic)(simm16=values[0])
|
||||
elif mnemonic in SOPK_IMM_FIRST:
|
||||
return getattr(autogen, mnemonic)(simm16=values[0], sdst=values[1])
|
||||
elif mnemonic in SMEM_OPS and len(operands) >= 3 and re.match(r'^-?[0-9]|^-?0x', operands[2].strip().lower()):
|
||||
return getattr(autogen, mnemonic)(sdata=values[0], sbase=values[1], offset=values[2], soffset=RawImm(124))
|
||||
elif mnemonic.startswith('buffer_') and len(operands) >= 2 and operands[1].strip().lower() == 'off':
|
||||
return getattr(autogen, mnemonic)(vdata=values[0], vaddr=0, srsrc=values[2], soffset=RawImm(unwrap(values[3])) if len(values) > 3 else RawImm(0))
|
||||
elif (mnemonic.startswith('flat_load') or mnemonic.startswith('global_load') or mnemonic.startswith('scratch_load')) and len(values) >= 3:
|
||||
offset = int(m.group(1)) if (m := re.search(r'offset:(-?\d+)', op_str)) else 0
|
||||
return getattr(autogen, mnemonic)(vdst=values[0], addr=values[1], saddr=values[2], offset=offset)
|
||||
elif (mnemonic.startswith('flat_store') or mnemonic.startswith('global_store') or mnemonic.startswith('scratch_store')) and len(values) >= 3:
|
||||
offset = int(m.group(1)) if (m := re.search(r'offset:(-?\d+)', op_str)) else 0
|
||||
return getattr(autogen, mnemonic)(addr=values[0], data=values[1], saddr=values[2], offset=offset)
|
||||
for suffix in (['_e32', ''] if not (neg_bits or abs_bits or clamp) else ['', '_e32']):
|
||||
if hasattr(autogen, name := mnemonic.replace('.', '_') + suffix):
|
||||
use_opsel = 'opsel' in getattr(autogen, name).func._fields
|
||||
vals = [type(v)(v.idx, v.count, False) if isinstance(v, Reg) and v.hi and use_opsel else v for v in values]
|
||||
inst = getattr(autogen, name)(*vals, literal=lit, **modifiers)
|
||||
if neg_bits and 'neg' in inst._fields: inst._values['neg'] = neg_bits
|
||||
if opsel_bits and use_opsel: inst._values['opsel'] = opsel_bits
|
||||
if abs_bits and 'abs' in inst._fields: inst._values['abs'] = abs_bits
|
||||
if clamp and 'clmp' in inst._fields: inst._values['clmp'] = 1
|
||||
return inst
|
||||
raise ValueError(f"unknown instruction: {mnemonic}")
|
||||
if mnemonic in SOP1_SRC_ONLY: return f"{mnemonic}(ssrc0={dsl_args[0]})"
|
||||
if mnemonic in SOP1_MSG_IMM: return f"{mnemonic}(sdst={dsl_args[0]}, ssrc0=RawImm({_unwrap_dsl(dsl_args[1])}))"
|
||||
if mnemonic in SOPK_IMM_ONLY: return f"{mnemonic}(simm16={dsl_args[0]})"
|
||||
if mnemonic in SOPK_IMM_FIRST: return f"{mnemonic}(simm16={dsl_args[0]}, sdst={dsl_args[1]})"
|
||||
# SMEM with immediate offset (offset in operand[2] or offset: modifier)
|
||||
if mnemonic in SMEM_OPS:
|
||||
glc_str = ", glc=1" if glc_val else ""
|
||||
dlc_str = ", dlc=1" if dlc_val else ""
|
||||
# Pure immediate offset in operand[2]
|
||||
if len(operands) >= 3 and re.match(r'^-?[0-9]|^-?0x', operands[2].strip().lower()):
|
||||
return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, offset={dsl_args[2]}, soffset=RawImm(124){glc_str}{dlc_str})"
|
||||
# Register soffset with offset: modifier
|
||||
if offset_val and len(operands) >= 3:
|
||||
return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, offset={offset_val}, soffset={dsl_args[2]}{glc_str}{dlc_str})"
|
||||
# Register soffset only (no offset modifier)
|
||||
if len(operands) >= 3:
|
||||
return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, soffset={dsl_args[2]}{glc_str}{dlc_str})"
|
||||
# Buffer ops with 'off'
|
||||
if mnemonic.startswith('buffer_') and len(operands) >= 2 and operands[1].strip().lower() == 'off':
|
||||
soff = f"RawImm({_unwrap_dsl(dsl_args[3])})" if len(dsl_args) > 3 else "RawImm(0)"
|
||||
return f"{mnemonic}(vdata={dsl_args[0]}, vaddr=0, srsrc={dsl_args[2]}, soffset={soff})"
|
||||
# FLAT/GLOBAL/SCRATCH load
|
||||
if (mnemonic.startswith('flat_load') or mnemonic.startswith('global_load') or mnemonic.startswith('scratch_load')) and len(dsl_args) >= 3:
|
||||
off = f", offset={offset_val}" if offset_val else ""
|
||||
return f"{mnemonic}(vdst={dsl_args[0]}, addr={dsl_args[1]}, saddr={dsl_args[2]}{off})"
|
||||
# FLAT/GLOBAL/SCRATCH store
|
||||
if (mnemonic.startswith('flat_store') or mnemonic.startswith('global_store') or mnemonic.startswith('scratch_store')) and len(dsl_args) >= 3:
|
||||
off = f", offset={offset_val}" if offset_val else ""
|
||||
return f"{mnemonic}(addr={dsl_args[0]}, data={dsl_args[1]}, saddr={dsl_args[2]}{off})"
|
||||
# Handle v_fmaak/v_fmamk literals
|
||||
lit_str = ""
|
||||
if mnemonic in ('v_fmaak_f32', 'v_fmaak_f16') and len(dsl_args) == 4:
|
||||
lit_str, dsl_args = f", literal={_unwrap_dsl(dsl_args[3])}", dsl_args[:3]
|
||||
elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(dsl_args) == 4:
|
||||
lit_str, dsl_args = f", literal={_unwrap_dsl(dsl_args[2])}", [dsl_args[0], dsl_args[1], dsl_args[3]]
|
||||
# Handle v_add_co_ci_u32_e32 etc with vcc operands - strip implicit vcc sdst and carry_in, add _e32 suffix
|
||||
vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'}
|
||||
if mnemonic.replace('_e32', '') in vcc_ops and len(dsl_args) >= 5:
|
||||
mnemonic = mnemonic.replace('_e32', '') + '_e32' # Ensure _e32 suffix for VOP2 encoding
|
||||
dsl_args = [dsl_args[0], dsl_args[2], dsl_args[3]]
|
||||
# v_cmp_*_e32: strip implicit vcc_lo dest
|
||||
if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(dsl_args) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
|
||||
dsl_args = dsl_args[1:]
|
||||
# CMPX with _e64: prepend implicit EXEC_LO (vdst=126)
|
||||
if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(dsl_args) == 2:
|
||||
dsl_args = ['RawImm(126)'] + dsl_args
|
||||
# Build the function name - use mnemonic as-is, replacing . with _
|
||||
func_name = mnemonic.replace('.', '_')
|
||||
# When explicit opsel is given, strip .h/.l from register args (opsel overrides)
|
||||
if opsel_explicit is not None:
|
||||
dsl_args = [re.sub(r'\.[hl]$', '', a) for a in dsl_args]
|
||||
args_str = ', '.join(dsl_args)
|
||||
all_kwargs = list(kwargs)
|
||||
if lit_str: all_kwargs.append(lit_str.lstrip(', '))
|
||||
if opsel_explicit is not None: all_kwargs.append(f'opsel={opsel_explicit}')
|
||||
if neg_lo_val is not None: all_kwargs.append(f'neg={neg_lo_val}')
|
||||
if neg_hi_val is not None: all_kwargs.append(f'neg_hi={neg_hi_val}')
|
||||
kwargs_str = ', '.join(all_kwargs)
|
||||
if kwargs_str:
|
||||
return f"{func_name}({args_str}, {kwargs_str})" if args_str else f"{func_name}({kwargs_str})"
|
||||
return f"{func_name}({args_str})"
|
||||
|
||||
def asm(text: str) -> Inst:
|
||||
"""Assemble LLVM-style instruction text to Inst by transforming to DSL and eval."""
|
||||
from extra.assembly.amd.autogen import rdna3 as autogen
|
||||
dsl_expr = get_dsl(text)
|
||||
namespace = {name: getattr(autogen, name) for name in dir(autogen) if not name.startswith('_')}
|
||||
namespace.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP,
|
||||
'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC,
|
||||
'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF})
|
||||
try:
|
||||
return eval(dsl_expr, namespace)
|
||||
except NameError:
|
||||
# Try with _e32 suffix for VOP1/VOP2/VOPC (only for v_* instructions)
|
||||
if m := re.match(r'^(v_\w+)(\(.*\))$', dsl_expr):
|
||||
return eval(f"{m.group(1)}_e32{m.group(2)}", namespace)
|
||||
raise
|
||||
|
||||
@@ -39,11 +39,27 @@ class _Bits:
|
||||
def __getitem__(self, key) -> BitField: return BitField(key.start, key.stop) if isinstance(key, slice) else BitField(key, key)
|
||||
bits = _Bits()
|
||||
|
||||
# Source operand with modifiers - base class for anything that can be a src with neg/abs
|
||||
class SrcMod:
|
||||
__slots__ = ('val', 'neg', 'abs_')
|
||||
def __init__(self, val: int, neg: bool = False, abs_: bool = False): self.val, self.neg, self.abs_ = val, neg, abs_
|
||||
def __repr__(self): return f"{'-' if self.neg else ''}{'|' if self.abs_ else ''}{self.val}{'|' if self.abs_ else ''}"
|
||||
def __neg__(self): return SrcMod(self.val, not self.neg, self.abs_)
|
||||
def __abs__(self): return SrcMod(self.val, self.neg, True)
|
||||
|
||||
# Register types
|
||||
class Reg:
|
||||
def __init__(self, idx: int, count: int = 1, hi: bool = False, neg: bool = False): self.idx, self.count, self.hi, self.neg = idx, count, hi, neg
|
||||
class Reg(SrcMod):
|
||||
__slots__ = ('idx', 'count', 'hi')
|
||||
def __init__(self, idx: int, count: int = 1, hi: bool = False, neg: bool = False, abs_: bool = False):
|
||||
self.idx, self.count, self.hi = idx, count, hi
|
||||
super().__init__(idx, neg, abs_)
|
||||
def __repr__(self): return f"{self.__class__.__name__.lower()[0]}[{self.idx}]" if self.count == 1 else f"{self.__class__.__name__.lower()[0]}[{self.idx}:{self.idx + self.count}]"
|
||||
def __neg__(self): return self.__class__(self.idx, self.count, self.hi, neg=not self.neg)
|
||||
def __neg__(self): return self.__class__(self.idx, self.count, self.hi, not self.neg, self.abs_)
|
||||
def __abs__(self): return self.__class__(self.idx, self.count, self.hi, self.neg, True)
|
||||
@property
|
||||
def l(self): return self.__class__(self.idx, self.count, False, self.neg, self.abs_)
|
||||
@property
|
||||
def h(self): return self.__class__(self.idx, self.count, True, self.neg, self.abs_)
|
||||
|
||||
T = TypeVar('T', bound=Reg)
|
||||
class _RegFactory(Generic[T]):
|
||||
@@ -63,6 +79,11 @@ s: _RegFactory[SGPR] = _RegFactory(SGPR, "SGPR")
|
||||
v: _RegFactory[VGPR] = _RegFactory(VGPR, "VGPR")
|
||||
ttmp: _RegFactory[TTMP] = _RegFactory(TTMP, "TTMP")
|
||||
|
||||
# Special registers as SrcMod objects (support -VCC_LO, abs(EXEC_LO), etc.)
|
||||
VCC_LO, VCC_HI, VCC = SrcMod(106), SrcMod(107), SrcMod(106)
|
||||
EXEC_LO, EXEC_HI, EXEC = SrcMod(126), SrcMod(127), SrcMod(126)
|
||||
SCC, M0, NULL, OFF = SrcMod(253), SrcMod(125), SrcMod(124), SrcMod(124)
|
||||
|
||||
# Field type markers (runtime classes for validation)
|
||||
class _SSrc: pass
|
||||
class _Src: pass
|
||||
@@ -86,21 +107,35 @@ class RawImm:
|
||||
def __eq__(self, other): return isinstance(other, RawImm) and self.val == other.val
|
||||
|
||||
def unwrap(val) -> int:
|
||||
return val.val if isinstance(val, RawImm) else val.value if hasattr(val, 'value') else val.idx if hasattr(val, 'idx') else val
|
||||
if isinstance(val, RawImm): return val.val
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special registers like VCC_LO, NULL
|
||||
if hasattr(val, 'value'): return val.value # IntEnum
|
||||
if hasattr(val, 'idx'): return val.idx # Reg
|
||||
return val
|
||||
|
||||
# Encoding helpers
|
||||
FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247}
|
||||
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
|
||||
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata'}
|
||||
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'}
|
||||
|
||||
def _encode_reg(val) -> int:
|
||||
def _encode_reg(val: Reg) -> int:
|
||||
if isinstance(val, TTMP): return 108 + val.idx
|
||||
return val.idx | (0x80 if val.hi else 0)
|
||||
return val.idx # hi bit is handled via opsel, not in register encoding
|
||||
|
||||
def encode_src(val) -> int:
|
||||
if isinstance(val, VGPR): return 256 + _encode_reg(val)
|
||||
if isinstance(val, Reg): return _encode_reg(val)
|
||||
if hasattr(val, 'value'): return val.value
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg):
|
||||
# SrcMod wraps either special registers (VCC_LO=106, EXEC_LO=126, etc.) or literals
|
||||
# Special register values are in valid encoding ranges - return as-is
|
||||
# Literals (large integers) need 255 marker
|
||||
v = val.val
|
||||
# Valid source encoding ranges: 0-127 (SGPRs/special), 128-192 (inline const), 193-208 (neg inline), 240-247 (float), 251-253 (special)
|
||||
if 0 <= v <= 127 or 240 <= v <= 255: return v # SGPRs, special regs, float constants
|
||||
if 128 <= v <= 192: return v # Inline positive constants (0-64)
|
||||
if 193 <= v <= 208: return v # Inline negative constants (-1 to -16)
|
||||
return 255 # Literal marker - value stored separately
|
||||
if hasattr(val, 'value'): return val.value # IntEnum
|
||||
if isinstance(val, float): return 128 if val == 0.0 else FLOAT_ENC.get(val, 255)
|
||||
return 128 + val if isinstance(val, int) and 0 <= val <= 64 else 192 + (-val) if isinstance(val, int) and -16 <= val <= -1 else 255
|
||||
|
||||
@@ -156,34 +191,61 @@ class Inst:
|
||||
# Type validation
|
||||
if marker is _SGPRField:
|
||||
if isinstance(val, VGPR): raise TypeError(f"field '{name}' requires SGPR, got VGPR")
|
||||
if not isinstance(val, (SGPR, TTMP, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}")
|
||||
if not isinstance(val, (SGPR, TTMP, SrcMod, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}")
|
||||
if marker is _VGPRField:
|
||||
if not isinstance(val, VGPR): raise TypeError(f"field '{name}' requires VGPR, got {type(val).__name__}")
|
||||
if marker is _SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR")
|
||||
# Encode source fields as RawImm for consistent disassembly
|
||||
if name in SRC_FIELDS:
|
||||
encoded = encode_src(val)
|
||||
# For VOP1/VOP2/VOPC (no opsel field), encode hi bit in src value
|
||||
if isinstance(val, Reg) and val.hi and 'opsel' not in self._fields:
|
||||
encoded |= 0x80
|
||||
self._values[name] = RawImm(encoded)
|
||||
# Handle negation modifier for VOP3 instructions
|
||||
if isinstance(val, Reg) and val.neg and 'neg' in self._fields:
|
||||
# Handle neg/abs/opsel modifiers for VOP3 instructions
|
||||
if isinstance(val, SrcMod):
|
||||
if val.neg and 'neg' in self._fields:
|
||||
neg_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
|
||||
cur_neg = self._values.get('neg', 0)
|
||||
self._values['neg'] = (cur_neg.val if isinstance(cur_neg, RawImm) else cur_neg) | neg_bit
|
||||
if val.abs_ and 'abs' in self._fields:
|
||||
abs_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
|
||||
cur_abs = self._values.get('abs', 0)
|
||||
self._values['abs'] = (cur_abs.val if isinstance(cur_abs, RawImm) else cur_abs) | abs_bit
|
||||
# Handle hi (opsel) for 16-bit ops - only for formats with opsel field
|
||||
if isinstance(val, Reg) and val.hi and 'opsel' in self._fields:
|
||||
opsel_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0)
|
||||
cur_opsel = self._values.get('opsel', 0)
|
||||
self._values['opsel'] = (cur_opsel.val if isinstance(cur_opsel, RawImm) else cur_opsel) | opsel_bit
|
||||
# Track literal value if needed (encoded as 255)
|
||||
# For 64-bit ops, store literal in high 32 bits (to match from_bytes decoding and to_bytes encoding)
|
||||
if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum):
|
||||
if encoded == 255 and self._literal is None:
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg):
|
||||
# SrcMod wrapping a literal value
|
||||
self._literal = (val.val << 32) if self._is_64bit_op() else val.val
|
||||
elif isinstance(val, int) and not isinstance(val, IntEnum):
|
||||
self._literal = (val << 32) if self._is_64bit_op() else val
|
||||
elif encoded == 255 and self._literal is None and isinstance(val, float):
|
||||
elif isinstance(val, float):
|
||||
import struct
|
||||
lit32 = struct.unpack('<I', struct.pack('<f', val))[0]
|
||||
self._literal = (lit32 << 32) if self._is_64bit_op() else lit32
|
||||
# Encode raw register fields for consistent repr
|
||||
elif name in RAW_FIELDS:
|
||||
if isinstance(val, Reg): self._values[name] = _encode_reg(val)
|
||||
if isinstance(val, Reg):
|
||||
encoded = _encode_reg(val)
|
||||
# For VOP1/VOP2/VOPC (no opsel field), encode hi bit in register value
|
||||
if val.hi and 'opsel' not in self._fields:
|
||||
encoded |= 0x80
|
||||
self._values[name] = encoded
|
||||
# Handle vdst hi (opsel bit 3) for 16-bit ops - only for formats with opsel field
|
||||
if name == 'vdst' and val.hi and 'opsel' in self._fields:
|
||||
cur_opsel = self._values.get('opsel', 0)
|
||||
self._values['opsel'] = (cur_opsel.val if isinstance(cur_opsel, RawImm) else cur_opsel) | 8
|
||||
elif hasattr(val, 'value'): self._values[name] = val.value # IntEnum like SrcEnum.NULL
|
||||
# Encode sbase (divided by 2) and srsrc/ssamp (divided by 4)
|
||||
elif name == 'sbase' and isinstance(val, Reg):
|
||||
self._values[name] = val.idx // 2
|
||||
elif name == 'sbase':
|
||||
if isinstance(val, Reg): self._values[name] = val.idx // 2
|
||||
elif isinstance(val, SrcMod): self._values[name] = val.val // 2 # Special regs like VCC_LO
|
||||
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg):
|
||||
self._values[name] = val.idx // 4
|
||||
# VOPD vdsty: encode as actual >> 1 (constraint: vdsty parity must be opposite of vdstx)
|
||||
@@ -192,8 +254,9 @@ class Inst:
|
||||
|
||||
def _encode_field(self, name: str, val) -> int:
|
||||
if isinstance(val, RawImm): return val.val
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO
|
||||
if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
|
||||
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val
|
||||
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val
|
||||
if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
|
||||
if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
|
||||
return val.value if hasattr(val, 'value') else val
|
||||
|
||||
Reference in New Issue
Block a user