From 94bca91f3e4c5dde7acb44cdb2a004b930689042 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 29 Dec 2025 17:39:11 -0500 Subject: [PATCH] assembly/amd: have asm go through the dsl (#13886) * assembly/amd: have asm go through the dsl * lil --- extra/assembly/amd/asm.py | 325 ++++++++++++++++++++++++++++---------- extra/assembly/amd/dsl.py | 111 ++++++++++--- 2 files changed, 327 insertions(+), 109 deletions(-) diff --git a/extra/assembly/amd/asm.py b/extra/assembly/amd/asm.py index 7fcef5e64e..a8a5f55f54 100644 --- a/extra/assembly/amd/asm.py +++ b/extra/assembly/amd/asm.py @@ -1,7 +1,8 @@ # RDNA3 assembler and disassembler from __future__ import annotations import re -from extra.assembly.amd.dsl import Inst, RawImm, Reg, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, FLOAT_ENC, SRC_FIELDS, unwrap +from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, FLOAT_ENC, SRC_FIELDS, unwrap +from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF # Decoding helpers SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"} @@ -464,7 +465,8 @@ def disasm(inst: Inst) -> str: return f"{op_name} {', '.join(ops)}" if ops else op_name # Assembler -SPECIAL_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'scc': RawImm(253)} +SPECIAL_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'vcc': RawImm(106), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), + 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'exec': RawImm(126), 'scc': RawImm(253), 'src_scc': RawImm(253)} FLOAT_CONSTS = {'0.5': 0.5, '-0.5': -0.5, '1.0': 1.0, '-1.0': -1.0, '2.0': 2.0, '-2.0': -2.0, '4.0': 4.0, '-4.0': -4.0} REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp} @@ -504,102 +506,255 @@ SOPK_IMM_ONLY = {'s_version'} SOPK_IMM_FIRST = {'s_setreg_b32'} SOPK_UNSUPPORTED = {'s_setreg_imm32_b32'} -def asm(text: str) -> Inst: - from extra.assembly.amd.autogen import rdna3 as autogen +def _operand_to_dsl(op: str) -> str: + """Transform a single operand from LLVM assembly syntax to DSL expression string.""" + op = op.strip() + # Handle negation prefix + neg = False + if op.startswith('-') and not (op[1:2].isdigit() or (len(op) > 2 and op[1] == '0' and op[2] in 'xX')): + neg, op = True, op[1:] + # Handle abs modifier: |x| or abs(x) + abs_ = False + if op.startswith('|') and op.endswith('|'): + abs_, op = True, op[1:-1] + elif op.startswith('abs(') and op.endswith(')'): + abs_, op = True, op[4:-1] + # Handle .h/.l suffix for 16-bit ops + hi_suffix = "" + if op.endswith('.h'): hi_suffix, op = ".h", op[:-2] + elif op.endswith('.l'): hi_suffix, op = ".l", op[:-2] + op_lower = op.lower() + + # Helper to apply modifiers + def apply_mods(base: str) -> str: + if not neg and not abs_: return f"{base}{hi_suffix}" + if abs_: return f"{'-' if neg else ''}abs({base}){hi_suffix}" + return f"-{base}{hi_suffix}" + + # Special registers - vcc maps to VCC_LO (64-bit alias) + special_map = {'vcc_lo': 'VCC_LO', 'vcc_hi': 'VCC_HI', 'vcc': 'VCC_LO', 'null': 'NULL', 'off': 'OFF', + 'm0': 'M0', 'exec_lo': 'EXEC_LO', 'exec_hi': 'EXEC_HI', 'exec': 'EXEC_LO', 'scc': 'SCC', + 'src_scc': 'SCC'} + if op_lower in special_map: return apply_mods(special_map[op_lower]) + # Float constants + float_map = {'0.5': '0.5', '-0.5': '-0.5', '1.0': '1.0', '-1.0': '-1.0', '2.0': '2.0', '-2.0': '-2.0', '4.0': '4.0', '-4.0': '-4.0'} + if op in float_map: return apply_mods(float_map[op]) + # Register range: v[0:3], s[4:7] + if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op_lower): + prefix = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}[m.group(1)] + return apply_mods(f"{prefix}[{m.group(2)}:{m.group(3)}]") + # Single register: v0, s1, ttmp5 + if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op_lower): + prefix = {'s': 's', 'v': 'v', 't': 'ttmp', 'ttmp': 'ttmp'}[m.group(1)] + return apply_mods(f"{prefix}[{m.group(2)}]") + # Integer literals (decimal or hex) - use SrcMod wrapper when modifiers present + if re.match(r'^-?\d+$', op) or re.match(r'^-?0x([0-9a-fA-F]+)$', op): + if neg or abs_: + return f"SrcMod({op}, neg={neg}, abs_={abs_})" + return op + # hwreg(name, offset, size) -> pass through + if op_lower.startswith('hwreg('): return apply_mods(op) + # sendmsg(...) -> pass through + if op_lower.startswith('sendmsg('): return apply_mods(op) + # Fallback: return as-is + return apply_mods(op) + +def _parse_operands(op_str: str) -> list[str]: + """Parse comma-separated operands, respecting brackets and pipes.""" + operands, current, depth, in_pipe = [], "", 0, False + for ch in op_str: + if ch in '[(': depth += 1 + elif ch in '])': depth -= 1 + elif ch == '|': in_pipe = not in_pipe + if ch == ',' and depth == 0 and not in_pipe: + operands.append(current.strip()) + current = "" + else: + current += ch + if current.strip(): operands.append(current.strip()) + return operands + +def _unwrap_dsl(s: str) -> str: + """Unwrap a DSL expression to get the raw value for literals.""" + if re.match(r'^-?\d+$', s): return s + if re.match(r'^-?0x[0-9a-fA-F]+$', s): return s + return s + +def get_dsl(text: str) -> str: + """Transform LLVM-style assembly instruction to Python DSL expression string.""" text = text.strip() - clamp = 'clamp' in text.lower() - if clamp: text = re.sub(r'\s+clamp\s*$', '', text, flags=re.I) - modifiers = {} - if m := re.search(r'\s+wait_exp:(\d+)', text, re.I): modifiers['waitexp'] = int(m.group(1)); text = text[:m.start()] + text[m.end():] + # Extract and remove trailing modifiers (must happen before operand parsing) + kwargs = [] + # Extract mul:N and div:N modifiers (omod) + omod_val = 0 + if m := re.search(r'\s+mul:2(?:\s|$)', text, re.I): + omod_val = 1; text = text[:m.start()] + text[m.end():] + elif m := re.search(r'\s+mul:4(?:\s|$)', text, re.I): + omod_val = 2; text = text[:m.start()] + text[m.end():] + elif m := re.search(r'\s+div:2(?:\s|$)', text, re.I): + omod_val = 3; text = text[:m.start()] + text[m.end():] + if omod_val: kwargs.append(f'omod={omod_val}') + # Extract clamp modifier + if m := re.search(r'\s+clamp(?:\s|$)', text, re.I): + kwargs.append('clmp=1') + text = text[:m.start()] + text[m.end():] + # Extract op_sel:[...] modifier - interpretation depends on format: + # VOP3: [src0, src1, dst] or [src0, src1, src2, dst] -> bits 0, 1, (2), 3 + # VOP3P/WMMA: [src0, src1, src2] -> bits 0, 1, 2 (no dst bit, 3-source ops) + opsel_explicit = None + if m := re.search(r'\s+op_sel:\[([^\]]+)\]', text, re.I): + bits = [int(x.strip()) for x in m.group(1).split(',')] + # Check if this is a VOP3P instruction (v_pk_*, v_wmma_*, v_dot*) + mnemonic = text.split()[0].lower() + is_vop3p = mnemonic.startswith(('v_pk_', 'v_wmma_', 'v_dot')) + if len(bits) == 3: + if is_vop3p: + # VOP3P: [src0, src1, src2] -> bits 0, 1, 2 + opsel_explicit = bits[0] | (bits[1] << 1) | (bits[2] << 2) + else: + # VOP3: [src0, src1, dst] -> bits 0, 1, 3 + opsel_explicit = bits[0] | (bits[1] << 1) | (bits[2] << 3) + else: + opsel_explicit = sum(b << i for i, b in enumerate(bits)) + text = text[:m.start()] + text[m.end():] + if m := re.search(r'\s+wait_exp:(\d+)', text, re.I): + kwargs.append(f'waitexp={m.group(1)}') + text = text[:m.start()] + text[m.end():] + # Extract offset:N for FLAT/GLOBAL/SCRATCH/SMEM (can be hex or decimal) + offset_val = None + if m := re.search(r'\s+offset:(0x[0-9a-fA-F]+|-?\d+)', text, re.I): + offset_val = m.group(1) + text = text[:m.start()] + text[m.end():] + # Extract dlc modifier (before glc to avoid partial match issues) + dlc_val = None + if m := re.search(r'\s+dlc(?:\s|$)', text, re.I): + dlc_val = 1 + text = text[:m.start()] + text[m.end():] + # Extract glc modifier + glc_val = None + if m := re.search(r'\s+glc(?:\s|$)', text, re.I): + glc_val = 1 + text = text[:m.start()] + text[m.end():] + # Extract neg_lo:[...] and neg_hi:[...] for VOP3P + neg_lo_val = None + if m := re.search(r'\s+neg_lo:\[([^\]]+)\]', text, re.I): + bits = [int(x.strip()) for x in m.group(1).split(',')] + neg_lo_val = sum(b << i for i, b in enumerate(bits)) + text = text[:m.start()] + text[m.end():] + neg_hi_val = None + if m := re.search(r'\s+neg_hi:\[([^\]]+)\]', text, re.I): + bits = [int(x.strip()) for x in m.group(1).split(',')] + neg_hi_val = sum(b << i for i, b in enumerate(bits)) + text = text[:m.start()] + text[m.end():] parts = text.replace(',', ' ').split() if not parts: raise ValueError("empty instruction") mnemonic, op_str = parts[0].lower(), text[len(parts[0]):].strip() - # Handle s_waitcnt specially before operand parsing + # Handle s_waitcnt specially if mnemonic == 's_waitcnt': vmcnt, expcnt, lgkmcnt = 0x3f, 0x7, 0x3f for part in op_str.replace(',', ' ').split(): if m := re.match(r'vmcnt\((\d+)\)', part): vmcnt = int(m.group(1)) elif m := re.match(r'expcnt\((\d+)\)', part): expcnt = int(m.group(1)) elif m := re.match(r'lgkmcnt\((\d+)\)', part): lgkmcnt = int(m.group(1)) - elif re.match(r'^0x[0-9a-f]+$|^\d+$', part): return autogen.s_waitcnt(simm16=int(part, 0)) - return autogen.s_waitcnt(simm16=waitcnt(vmcnt, expcnt, lgkmcnt)) - # Handle VOPD dual-issue instructions: opx dst, src :: opy dst, src + elif re.match(r'^0x[0-9a-f]+$|^\d+$', part): return f"s_waitcnt(simm16={int(part, 0)})" + wc = waitcnt(vmcnt, expcnt, lgkmcnt) + return f"s_waitcnt(simm16={wc})" + # Handle VOPD dual-issue: opx dst, src :: opy dst, src if '::' in text: x_part, y_part = text.split('::') x_parts, y_parts = x_part.strip().replace(',', ' ').split(), y_part.strip().replace(',', ' ').split() opx_name, opy_name = x_parts[0].upper(), y_parts[0].upper() - opx, opy = autogen.VOPDOp[opx_name], autogen.VOPDOp[opy_name] - x_ops, y_ops = [parse_operand(p)[0] for p in x_parts[1:]], [parse_operand(p)[0] for p in y_parts[1:]] - vdstx, srcx0 = x_ops[0], x_ops[1] if len(x_ops) > 1 else 0 - vsrcx1 = x_ops[2] if len(x_ops) > 2 else VGPR(0) - vdsty, srcy0 = y_ops[0], y_ops[1] if len(y_ops) > 1 else 0 - vsrcy1 = y_ops[2] if len(y_ops) > 2 else VGPR(0) - # Handle fmaak/fmamk literals (4th operand on x or y side) + x_ops = [_operand_to_dsl(p) for p in x_parts[1:]] + y_ops = [_operand_to_dsl(p) for p in y_parts[1:]] + vdstx, srcx0 = x_ops[0], x_ops[1] if len(x_ops) > 1 else '0' + vsrcx1 = x_ops[2] if len(x_ops) > 2 else 'v[0]' + vdsty, srcy0 = y_ops[0], y_ops[1] if len(y_ops) > 1 else '0' + vsrcy1 = y_ops[2] if len(y_ops) > 2 else 'v[0]' lit = None - if 'fmaak' in opx_name.lower() and len(x_ops) > 3: lit = unwrap(x_ops[3]) - elif 'fmamk' in opx_name.lower() and len(x_ops) > 3: lit, vsrcx1 = unwrap(x_ops[2]), x_ops[3] - elif 'fmaak' in opy_name.lower() and len(y_ops) > 3: lit = unwrap(y_ops[3]) - elif 'fmamk' in opy_name.lower() and len(y_ops) > 3: lit, vsrcy1 = unwrap(y_ops[2]), y_ops[3] - return autogen.VOPD(opx, opy, vdstx=vdstx, vdsty=vdsty, srcx0=srcx0, vsrcx1=vsrcx1, srcy0=srcy0, vsrcy1=vsrcy1, literal=lit) - operands, current, depth, in_pipe = [], "", 0, False - for ch in op_str: - if ch in '[(': depth += 1 - elif ch in '])': depth -= 1 - elif ch == '|': in_pipe = not in_pipe - if ch == ',' and depth == 0 and not in_pipe: operands.append(current.strip()); current = "" - else: current += ch - if current.strip(): operands.append(current.strip()) - parsed = [parse_operand(op) for op in operands] - values = [p[0] for p in parsed] - neg_bits = sum((1 << (i-1)) for i, p in enumerate(parsed) if i > 0 and p[1]) - abs_bits = sum((1 << (i-1)) for i, p in enumerate(parsed) if i > 0 and p[2]) - opsel_bits = (8 if len(parsed) > 0 and parsed[0][3] else 0) | sum((1 << i) for i, p in enumerate(parsed[1:4]) if p[3]) - lit = None - if mnemonic in ('v_fmaak_f32', 'v_fmaak_f16') and len(values) == 4: lit, values = unwrap(values[3]), values[:3] - elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(values) == 4: lit, values = unwrap(values[2]), [values[0], values[1], values[3]] - vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32', 'v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'} - if mnemonic.replace('_e32', '') in vcc_ops and len(values) >= 5: values = [values[0], values[2], values[3]] - # v_cmp_*_e32: strip implicit vcc_lo dest. v_cmp_*_e64: keep vdst (vcc_lo encodes to 106) - if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): - values = values[1:] - # CMPX instructions with _e64 suffix: prepend implicit EXEC_LO destination (vdst=126) - if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(values) == 2: - values = [VGPR(126, 1)] + values - # Recalculate modifiers: parsed[0]=src0, parsed[1]=src1 (no vdst in user input) - neg_bits = sum((1 << i) for i, p in enumerate(parsed[:3]) if p[1]) - abs_bits = sum((1 << i) for i, p in enumerate(parsed[:3]) if p[2]) - opsel_bits = sum((1 << i) for i, p in enumerate(parsed[:2]) if p[3]) - vop3sd_ops = {'v_div_scale_f32', 'v_div_scale_f64'} - if mnemonic in vop3sd_ops and len(parsed) >= 5: - neg_bits = sum((1 << i) for i, p in enumerate(parsed[2:5]) if p[1]) - abs_bits = sum((1 << i) for i, p in enumerate(parsed[2:5]) if p[2]) + if 'fmaak' in opx_name.lower() and len(x_ops) > 3: lit = x_ops[3] + elif 'fmamk' in opx_name.lower() and len(x_ops) > 3: lit, vsrcx1 = x_ops[2], x_ops[3] + elif 'fmaak' in opy_name.lower() and len(y_ops) > 3: lit = y_ops[3] + elif 'fmamk' in opy_name.lower() and len(y_ops) > 3: lit, vsrcy1 = y_ops[2], y_ops[3] + lit_str = f", literal={lit}" if lit else "" + return f"VOPD(VOPDOp.{opx_name}, VOPDOp.{opy_name}, vdstx={vdstx}, vdsty={vdsty}, srcx0={srcx0}, vsrcx1={vsrcx1}, srcy0={srcy0}, vsrcy1={vsrcy1}{lit_str})" + operands = _parse_operands(op_str) + dsl_args = [_operand_to_dsl(op) for op in operands] + # Handle special instructions if mnemonic in SOPK_UNSUPPORTED: raise ValueError(f"unsupported instruction: {mnemonic}") - elif mnemonic in SOP1_SRC_ONLY: - return getattr(autogen, mnemonic)(ssrc0=values[0]) - elif mnemonic in SOP1_MSG_IMM: - return getattr(autogen, mnemonic)(sdst=values[0], ssrc0=RawImm(unwrap(values[1]))) - elif mnemonic in SOPK_IMM_ONLY: - return getattr(autogen, mnemonic)(simm16=values[0]) - elif mnemonic in SOPK_IMM_FIRST: - return getattr(autogen, mnemonic)(simm16=values[0], sdst=values[1]) - elif mnemonic in SMEM_OPS and len(operands) >= 3 and re.match(r'^-?[0-9]|^-?0x', operands[2].strip().lower()): - return getattr(autogen, mnemonic)(sdata=values[0], sbase=values[1], offset=values[2], soffset=RawImm(124)) - elif mnemonic.startswith('buffer_') and len(operands) >= 2 and operands[1].strip().lower() == 'off': - return getattr(autogen, mnemonic)(vdata=values[0], vaddr=0, srsrc=values[2], soffset=RawImm(unwrap(values[3])) if len(values) > 3 else RawImm(0)) - elif (mnemonic.startswith('flat_load') or mnemonic.startswith('global_load') or mnemonic.startswith('scratch_load')) and len(values) >= 3: - offset = int(m.group(1)) if (m := re.search(r'offset:(-?\d+)', op_str)) else 0 - return getattr(autogen, mnemonic)(vdst=values[0], addr=values[1], saddr=values[2], offset=offset) - elif (mnemonic.startswith('flat_store') or mnemonic.startswith('global_store') or mnemonic.startswith('scratch_store')) and len(values) >= 3: - offset = int(m.group(1)) if (m := re.search(r'offset:(-?\d+)', op_str)) else 0 - return getattr(autogen, mnemonic)(addr=values[0], data=values[1], saddr=values[2], offset=offset) - for suffix in (['_e32', ''] if not (neg_bits or abs_bits or clamp) else ['', '_e32']): - if hasattr(autogen, name := mnemonic.replace('.', '_') + suffix): - use_opsel = 'opsel' in getattr(autogen, name).func._fields - vals = [type(v)(v.idx, v.count, False) if isinstance(v, Reg) and v.hi and use_opsel else v for v in values] - inst = getattr(autogen, name)(*vals, literal=lit, **modifiers) - if neg_bits and 'neg' in inst._fields: inst._values['neg'] = neg_bits - if opsel_bits and use_opsel: inst._values['opsel'] = opsel_bits - if abs_bits and 'abs' in inst._fields: inst._values['abs'] = abs_bits - if clamp and 'clmp' in inst._fields: inst._values['clmp'] = 1 - return inst - raise ValueError(f"unknown instruction: {mnemonic}") + if mnemonic in SOP1_SRC_ONLY: return f"{mnemonic}(ssrc0={dsl_args[0]})" + if mnemonic in SOP1_MSG_IMM: return f"{mnemonic}(sdst={dsl_args[0]}, ssrc0=RawImm({_unwrap_dsl(dsl_args[1])}))" + if mnemonic in SOPK_IMM_ONLY: return f"{mnemonic}(simm16={dsl_args[0]})" + if mnemonic in SOPK_IMM_FIRST: return f"{mnemonic}(simm16={dsl_args[0]}, sdst={dsl_args[1]})" + # SMEM with immediate offset (offset in operand[2] or offset: modifier) + if mnemonic in SMEM_OPS: + glc_str = ", glc=1" if glc_val else "" + dlc_str = ", dlc=1" if dlc_val else "" + # Pure immediate offset in operand[2] + if len(operands) >= 3 and re.match(r'^-?[0-9]|^-?0x', operands[2].strip().lower()): + return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, offset={dsl_args[2]}, soffset=RawImm(124){glc_str}{dlc_str})" + # Register soffset with offset: modifier + if offset_val and len(operands) >= 3: + return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, offset={offset_val}, soffset={dsl_args[2]}{glc_str}{dlc_str})" + # Register soffset only (no offset modifier) + if len(operands) >= 3: + return f"{mnemonic}(sdata={dsl_args[0]}, sbase={dsl_args[1]}, soffset={dsl_args[2]}{glc_str}{dlc_str})" + # Buffer ops with 'off' + if mnemonic.startswith('buffer_') and len(operands) >= 2 and operands[1].strip().lower() == 'off': + soff = f"RawImm({_unwrap_dsl(dsl_args[3])})" if len(dsl_args) > 3 else "RawImm(0)" + return f"{mnemonic}(vdata={dsl_args[0]}, vaddr=0, srsrc={dsl_args[2]}, soffset={soff})" + # FLAT/GLOBAL/SCRATCH load + if (mnemonic.startswith('flat_load') or mnemonic.startswith('global_load') or mnemonic.startswith('scratch_load')) and len(dsl_args) >= 3: + off = f", offset={offset_val}" if offset_val else "" + return f"{mnemonic}(vdst={dsl_args[0]}, addr={dsl_args[1]}, saddr={dsl_args[2]}{off})" + # FLAT/GLOBAL/SCRATCH store + if (mnemonic.startswith('flat_store') or mnemonic.startswith('global_store') or mnemonic.startswith('scratch_store')) and len(dsl_args) >= 3: + off = f", offset={offset_val}" if offset_val else "" + return f"{mnemonic}(addr={dsl_args[0]}, data={dsl_args[1]}, saddr={dsl_args[2]}{off})" + # Handle v_fmaak/v_fmamk literals + lit_str = "" + if mnemonic in ('v_fmaak_f32', 'v_fmaak_f16') and len(dsl_args) == 4: + lit_str, dsl_args = f", literal={_unwrap_dsl(dsl_args[3])}", dsl_args[:3] + elif mnemonic in ('v_fmamk_f32', 'v_fmamk_f16') and len(dsl_args) == 4: + lit_str, dsl_args = f", literal={_unwrap_dsl(dsl_args[2])}", [dsl_args[0], dsl_args[1], dsl_args[3]] + # Handle v_add_co_ci_u32_e32 etc with vcc operands - strip implicit vcc sdst and carry_in, add _e32 suffix + vcc_ops = {'v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'} + if mnemonic.replace('_e32', '') in vcc_ops and len(dsl_args) >= 5: + mnemonic = mnemonic.replace('_e32', '') + '_e32' # Ensure _e32 suffix for VOP2 encoding + dsl_args = [dsl_args[0], dsl_args[2], dsl_args[3]] + # v_cmp_*_e32: strip implicit vcc_lo dest + if mnemonic.startswith('v_cmp') and not mnemonic.endswith('_e64') and len(dsl_args) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'): + dsl_args = dsl_args[1:] + # CMPX with _e64: prepend implicit EXEC_LO (vdst=126) + if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(dsl_args) == 2: + dsl_args = ['RawImm(126)'] + dsl_args + # Build the function name - use mnemonic as-is, replacing . with _ + func_name = mnemonic.replace('.', '_') + # When explicit opsel is given, strip .h/.l from register args (opsel overrides) + if opsel_explicit is not None: + dsl_args = [re.sub(r'\.[hl]$', '', a) for a in dsl_args] + args_str = ', '.join(dsl_args) + all_kwargs = list(kwargs) + if lit_str: all_kwargs.append(lit_str.lstrip(', ')) + if opsel_explicit is not None: all_kwargs.append(f'opsel={opsel_explicit}') + if neg_lo_val is not None: all_kwargs.append(f'neg={neg_lo_val}') + if neg_hi_val is not None: all_kwargs.append(f'neg_hi={neg_hi_val}') + kwargs_str = ', '.join(all_kwargs) + if kwargs_str: + return f"{func_name}({args_str}, {kwargs_str})" if args_str else f"{func_name}({kwargs_str})" + return f"{func_name}({args_str})" + +def asm(text: str) -> Inst: + """Assemble LLVM-style instruction text to Inst by transforming to DSL and eval.""" + from extra.assembly.amd.autogen import rdna3 as autogen + dsl_expr = get_dsl(text) + namespace = {name: getattr(autogen, name) for name in dir(autogen) if not name.startswith('_')} + namespace.update({'s': s, 'v': v, 'ttmp': ttmp, 'abs': abs, 'RawImm': RawImm, 'SrcMod': SrcMod, 'VGPR': VGPR, 'SGPR': SGPR, 'TTMP': TTMP, + 'VCC_LO': VCC_LO, 'VCC_HI': VCC_HI, 'VCC': VCC, 'EXEC_LO': EXEC_LO, 'EXEC_HI': EXEC_HI, 'EXEC': EXEC, + 'SCC': SCC, 'M0': M0, 'NULL': NULL, 'OFF': OFF}) + try: + return eval(dsl_expr, namespace) + except NameError: + # Try with _e32 suffix for VOP1/VOP2/VOPC (only for v_* instructions) + if m := re.match(r'^(v_\w+)(\(.*\))$', dsl_expr): + return eval(f"{m.group(1)}_e32{m.group(2)}", namespace) + raise diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 2988073f3e..ae62c2fcee 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -39,11 +39,27 @@ class _Bits: def __getitem__(self, key) -> BitField: return BitField(key.start, key.stop) if isinstance(key, slice) else BitField(key, key) bits = _Bits() +# Source operand with modifiers - base class for anything that can be a src with neg/abs +class SrcMod: + __slots__ = ('val', 'neg', 'abs_') + def __init__(self, val: int, neg: bool = False, abs_: bool = False): self.val, self.neg, self.abs_ = val, neg, abs_ + def __repr__(self): return f"{'-' if self.neg else ''}{'|' if self.abs_ else ''}{self.val}{'|' if self.abs_ else ''}" + def __neg__(self): return SrcMod(self.val, not self.neg, self.abs_) + def __abs__(self): return SrcMod(self.val, self.neg, True) + # Register types -class Reg: - def __init__(self, idx: int, count: int = 1, hi: bool = False, neg: bool = False): self.idx, self.count, self.hi, self.neg = idx, count, hi, neg +class Reg(SrcMod): + __slots__ = ('idx', 'count', 'hi') + def __init__(self, idx: int, count: int = 1, hi: bool = False, neg: bool = False, abs_: bool = False): + self.idx, self.count, self.hi = idx, count, hi + super().__init__(idx, neg, abs_) def __repr__(self): return f"{self.__class__.__name__.lower()[0]}[{self.idx}]" if self.count == 1 else f"{self.__class__.__name__.lower()[0]}[{self.idx}:{self.idx + self.count}]" - def __neg__(self): return self.__class__(self.idx, self.count, self.hi, neg=not self.neg) + def __neg__(self): return self.__class__(self.idx, self.count, self.hi, not self.neg, self.abs_) + def __abs__(self): return self.__class__(self.idx, self.count, self.hi, self.neg, True) + @property + def l(self): return self.__class__(self.idx, self.count, False, self.neg, self.abs_) + @property + def h(self): return self.__class__(self.idx, self.count, True, self.neg, self.abs_) T = TypeVar('T', bound=Reg) class _RegFactory(Generic[T]): @@ -63,6 +79,11 @@ s: _RegFactory[SGPR] = _RegFactory(SGPR, "SGPR") v: _RegFactory[VGPR] = _RegFactory(VGPR, "VGPR") ttmp: _RegFactory[TTMP] = _RegFactory(TTMP, "TTMP") +# Special registers as SrcMod objects (support -VCC_LO, abs(EXEC_LO), etc.) +VCC_LO, VCC_HI, VCC = SrcMod(106), SrcMod(107), SrcMod(106) +EXEC_LO, EXEC_HI, EXEC = SrcMod(126), SrcMod(127), SrcMod(126) +SCC, M0, NULL, OFF = SrcMod(253), SrcMod(125), SrcMod(124), SrcMod(124) + # Field type markers (runtime classes for validation) class _SSrc: pass class _Src: pass @@ -86,21 +107,35 @@ class RawImm: def __eq__(self, other): return isinstance(other, RawImm) and self.val == other.val def unwrap(val) -> int: - return val.val if isinstance(val, RawImm) else val.value if hasattr(val, 'value') else val.idx if hasattr(val, 'idx') else val + if isinstance(val, RawImm): return val.val + if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special registers like VCC_LO, NULL + if hasattr(val, 'value'): return val.value # IntEnum + if hasattr(val, 'idx'): return val.idx # Reg + return val # Encoding helpers FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0: 246, -4.0: 247} SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'} -RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata'} +RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata', 'vsrc1'} -def _encode_reg(val) -> int: +def _encode_reg(val: Reg) -> int: if isinstance(val, TTMP): return 108 + val.idx - return val.idx | (0x80 if val.hi else 0) + return val.idx # hi bit is handled via opsel, not in register encoding def encode_src(val) -> int: if isinstance(val, VGPR): return 256 + _encode_reg(val) if isinstance(val, Reg): return _encode_reg(val) - if hasattr(val, 'value'): return val.value + if isinstance(val, SrcMod) and not isinstance(val, Reg): + # SrcMod wraps either special registers (VCC_LO=106, EXEC_LO=126, etc.) or literals + # Special register values are in valid encoding ranges - return as-is + # Literals (large integers) need 255 marker + v = val.val + # Valid source encoding ranges: 0-127 (SGPRs/special), 128-192 (inline const), 193-208 (neg inline), 240-247 (float), 251-253 (special) + if 0 <= v <= 127 or 240 <= v <= 255: return v # SGPRs, special regs, float constants + if 128 <= v <= 192: return v # Inline positive constants (0-64) + if 193 <= v <= 208: return v # Inline negative constants (-1 to -16) + return 255 # Literal marker - value stored separately + if hasattr(val, 'value'): return val.value # IntEnum if isinstance(val, float): return 128 if val == 0.0 else FLOAT_ENC.get(val, 255) return 128 + val if isinstance(val, int) and 0 <= val <= 64 else 192 + (-val) if isinstance(val, int) and -16 <= val <= -1 else 255 @@ -156,34 +191,61 @@ class Inst: # Type validation if marker is _SGPRField: if isinstance(val, VGPR): raise TypeError(f"field '{name}' requires SGPR, got VGPR") - if not isinstance(val, (SGPR, TTMP, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}") + if not isinstance(val, (SGPR, TTMP, SrcMod, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}") if marker is _VGPRField: if not isinstance(val, VGPR): raise TypeError(f"field '{name}' requires VGPR, got {type(val).__name__}") if marker is _SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR") # Encode source fields as RawImm for consistent disassembly if name in SRC_FIELDS: encoded = encode_src(val) + # For VOP1/VOP2/VOPC (no opsel field), encode hi bit in src value + if isinstance(val, Reg) and val.hi and 'opsel' not in self._fields: + encoded |= 0x80 self._values[name] = RawImm(encoded) - # Handle negation modifier for VOP3 instructions - if isinstance(val, Reg) and val.neg and 'neg' in self._fields: - neg_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0) - cur_neg = self._values.get('neg', 0) - self._values['neg'] = (cur_neg.val if isinstance(cur_neg, RawImm) else cur_neg) | neg_bit + # Handle neg/abs/opsel modifiers for VOP3 instructions + if isinstance(val, SrcMod): + if val.neg and 'neg' in self._fields: + neg_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0) + cur_neg = self._values.get('neg', 0) + self._values['neg'] = (cur_neg.val if isinstance(cur_neg, RawImm) else cur_neg) | neg_bit + if val.abs_ and 'abs' in self._fields: + abs_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0) + cur_abs = self._values.get('abs', 0) + self._values['abs'] = (cur_abs.val if isinstance(cur_abs, RawImm) else cur_abs) | abs_bit + # Handle hi (opsel) for 16-bit ops - only for formats with opsel field + if isinstance(val, Reg) and val.hi and 'opsel' in self._fields: + opsel_bit = {'src0': 1, 'src1': 2, 'src2': 4}.get(name, 0) + cur_opsel = self._values.get('opsel', 0) + self._values['opsel'] = (cur_opsel.val if isinstance(cur_opsel, RawImm) else cur_opsel) | opsel_bit # Track literal value if needed (encoded as 255) # For 64-bit ops, store literal in high 32 bits (to match from_bytes decoding and to_bytes encoding) - if encoded == 255 and self._literal is None and isinstance(val, int) and not isinstance(val, IntEnum): - self._literal = (val << 32) if self._is_64bit_op() else val - elif encoded == 255 and self._literal is None and isinstance(val, float): - import struct - lit32 = struct.unpack('> 1 (constraint: vdsty parity must be opposite of vdstx) @@ -192,8 +254,9 @@ class Inst: def _encode_field(self, name: str, val) -> int: if isinstance(val, RawImm): return val.val + if isinstance(val, SrcMod) and not isinstance(val, Reg): return val.val # Special regs like VCC_LO if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val - if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val + if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val.val // 2 if isinstance(val, SrcMod) else val if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val) return val.value if hasattr(val, 'value') else val