mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
assembly/amd: cleanups to asm and emu (#13912)
* a bunch of cleanups * ops are back * bug fixes * cleanups * a lil simpler * more refactors * _disasm_vop1 * sops * more * continue * more * num_srcs * simpler * no _is16 * op cleanups * isinstnace
This commit is contained in:
@@ -1,15 +1,12 @@
|
||||
# RDNA3 assembler and disassembler
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, SRC_FIELDS, unwrap
|
||||
from extra.assembly.amd.dsl import Inst, RawImm, Reg, SrcMod, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory
|
||||
from extra.assembly.amd.dsl import VCC_LO, VCC_HI, VCC, EXEC_LO, EXEC_HI, EXEC, SCC, M0, NULL, OFF
|
||||
from extra.assembly.amd.dsl import SPECIAL_GPRS, SPECIAL_PAIRS, FLOAT_DEC, FLOAT_ENC, decode_src
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, VOPD, VINTERP, SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, DS, FLAT, MUBUF, MTBUF, MIMG, EXP,
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, VINTERPOp, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp)
|
||||
|
||||
# VOP3SD opcodes that share VOP3 encoding
|
||||
VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
|
||||
VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOPDOp, SOP1Op, SOPKOp, SOPPOp, SMEMOp, DSOp, MUBUFOp)
|
||||
|
||||
def _matches_encoding(word: int, cls: type[Inst]) -> bool:
|
||||
"""Check if word matches the encoding pattern of an instruction class."""
|
||||
@@ -29,7 +26,7 @@ def detect_format(data: bytes) -> type[Inst]:
|
||||
if (word >> 30) == 0b11:
|
||||
for cls in _FORMATS_64:
|
||||
if _matches_encoding(word, cls):
|
||||
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in VOP3SD_OPS else cls
|
||||
return VOP3SD if cls is VOP3 and ((word >> 16) & 0x3ff) in Inst._VOP3SD_OPS else cls
|
||||
raise ValueError(f"unknown 64-bit format word={word:#010x}")
|
||||
# 32-bit formats
|
||||
for cls in _FORMATS_32:
|
||||
@@ -79,8 +76,6 @@ def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int:
|
||||
return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
|
||||
|
||||
def _has(op: str, *subs) -> bool: return any(s in op for s in subs)
|
||||
def _is16(op: str) -> bool: return _has(op, 'f16', 'i16', 'u16', 'b16') and not _has(op, '_f32', '_i32')
|
||||
def _is64(op: str) -> bool: return _has(op, 'f64', 'i64', 'u64', 'b64')
|
||||
def _omod(v: int) -> str: return {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(v, "")
|
||||
def _src16(inst, v: int) -> str: return _fmt_v16(v) if v >= 256 else inst.lit(v) # format 16-bit src: vgpr.h/l or literal
|
||||
def _mods(*pairs) -> str: return " ".join(m for c, m in pairs if c)
|
||||
@@ -105,50 +100,43 @@ def _opsel_str(opsel: int, n: int, need: bool, is16_d: bool) -> str:
|
||||
# DISASSEMBLER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_VOP1_F64 = {VOP1Op.V_CEIL_F64, VOP1Op.V_FLOOR_F64, VOP1Op.V_FRACT_F64, VOP1Op.V_FREXP_MANT_F64, VOP1Op.V_RCP_F64, VOP1Op.V_RNDNE_F64, VOP1Op.V_RSQ_F64, VOP1Op.V_SQRT_F64, VOP1Op.V_TRUNC_F64}
|
||||
|
||||
def _disasm_vop1(inst: VOP1) -> str:
|
||||
op, name = VOP1Op(inst.op), VOP1Op(inst.op).name.lower()
|
||||
if op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
|
||||
if op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
parts, is_f64_d = name.split('_'), op in _VOP1_F64 or op in (VOP1Op.V_CVT_F64_F32, VOP1Op.V_CVT_F64_I32, VOP1Op.V_CVT_F64_U32)
|
||||
is_f64_s = op in _VOP1_F64 or op in (VOP1Op.V_CVT_F32_F64, VOP1Op.V_CVT_I32_F64, VOP1Op.V_CVT_U32_F64, VOP1Op.V_FREXP_EXP_I32_F64)
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (VOP1Op.V_NOP, VOP1Op.V_PIPEFLUSH): return name
|
||||
if inst.op == VOP1Op.V_READFIRSTLANE_B32: return f"v_readfirstlane_b32 {decode_src(inst.vdst)}, v{inst.src0 - 256 if inst.src0 >= 256 else inst.src0}"
|
||||
# 16-bit dst: uses .h/.l suffix (determined by name pattern, not dtype - e.g. sat_pk_u8_i16 outputs 8-bit but uses 16-bit encoding)
|
||||
parts = name.split('_')
|
||||
is_16d = any(p in ('f16','i16','u16','b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16','i16','u16','b16') and 'cvt' not in name)
|
||||
is_16s = parts[-1] in ('f16','i16','u16','b16') and 'sat_pk' not in name
|
||||
dst = _vreg(inst.vdst, 2) if is_f64_d else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = _fmt_src(inst.src0, 2) if is_f64_s else _src16(inst, inst.src0) if is_16s else inst.lit(inst.src0)
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else _fmt_v16(inst.vdst, 0, 128) if is_16d else f"v{inst.vdst}"
|
||||
src = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_src_16(0) and 'sat_pk' not in name else inst.lit(inst.src0)
|
||||
return f"{name}_e32 {dst}, {src}"
|
||||
|
||||
def _disasm_vop2(inst: VOP2) -> str:
|
||||
op, name = VOP2Op(inst.op), VOP2Op(inst.op).name.lower()
|
||||
suf, is16 = "" if op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32", _is16(name) and 'pk_' not in name
|
||||
name = inst.op_name.lower()
|
||||
suf = "" if inst.op == VOP2Op.V_DOT2ACC_F32_F16 else "_e32"
|
||||
# fmaak: dst = src0 * vsrc1 + K, fmamk: dst = src0 * K + vsrc1
|
||||
if op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
if is16: return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if op == VOP2Op.V_CNDMASK_B32 else "")
|
||||
|
||||
VOPC_CLASS = {VOPCOp.V_CMP_CLASS_F16, VOPCOp.V_CMP_CLASS_F32, VOPCOp.V_CMP_CLASS_F64,
|
||||
VOPCOp.V_CMPX_CLASS_F16, VOPCOp.V_CMPX_CLASS_F32, VOPCOp.V_CMPX_CLASS_F64}
|
||||
if inst.op in (VOP2Op.V_FMAAK_F32, VOP2Op.V_FMAAK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}, 0x{inst._literal:x}"
|
||||
if inst.op in (VOP2Op.V_FMAMK_F32, VOP2Op.V_FMAMK_F16): return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, 0x{inst._literal:x}, v{inst.vsrc1}"
|
||||
if inst.is_16bit(): return f"{name}{suf} {_fmt_v16(inst.vdst, 0, 128)}, {_src16(inst, inst.src0)}, {_fmt_v16(inst.vsrc1, 0, 128)}"
|
||||
return f"{name}{suf} v{inst.vdst}, {inst.lit(inst.src0)}, v{inst.vsrc1}" + (", vcc_lo" if inst.op == VOP2Op.V_CNDMASK_B32 else "")
|
||||
|
||||
def _disasm_vopc(inst: VOPC) -> str:
|
||||
op, name = VOPCOp(inst.op), VOPCOp(inst.op).name.lower()
|
||||
is64, is16 = _is64(name), _is16(name)
|
||||
s0 = _fmt_src(inst.src0, 2) if is64 else _src16(inst, inst.src0) if is16 else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, 2) if is64 and op not in VOPC_CLASS else _fmt_v16(inst.vsrc1, 0, 128) if is16 else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
name = inst.op_name.lower()
|
||||
s0 = _fmt_src(inst.src0, inst.src_regs(0)) if inst.src_regs(0) > 1 else _src16(inst, inst.src0) if inst.is_16bit() else inst.lit(inst.src0)
|
||||
s1 = _vreg(inst.vsrc1, inst.src_regs(1)) if inst.src_regs(1) > 1 else _fmt_v16(inst.vsrc1, 0, 128) if inst.is_16bit() else f"v{inst.vsrc1}"
|
||||
return f"{name}_e32 {s0}, {s1}" if inst.op.value >= 128 else f"{name}_e32 vcc_lo, {s0}, {s1}"
|
||||
|
||||
NO_ARG_SOPP = {SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER, SOPPOp.S_WAKEUP, SOPPOp.S_ICACHE_INV,
|
||||
SOPPOp.S_WAIT_IDLE, SOPPOp.S_ENDPGM_SAVED, SOPPOp.S_CODE_END, SOPPOp.S_ENDPGM_ORDERED_PS_DONE}
|
||||
|
||||
def _disasm_sopp(inst: SOPP) -> str:
|
||||
op, name = SOPPOp(inst.op), SOPPOp(inst.op).name.lower()
|
||||
if op in NO_ARG_SOPP: return name
|
||||
if op == SOPPOp.S_WAITCNT:
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in NO_ARG_SOPP: return name
|
||||
if inst.op == SOPPOp.S_WAITCNT:
|
||||
vm, exp, lgkm = (inst.simm16 >> 10) & 0x3f, inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x3f
|
||||
p = [f"vmcnt({vm})" if vm != 0x3f else "", f"expcnt({exp})" if exp != 7 else "", f"lgkmcnt({lgkm})" if lgkm != 0x3f else ""]
|
||||
return f"s_waitcnt {' '.join(x for x in p if x) or '0'}"
|
||||
if op == SOPPOp.S_DELAY_ALU:
|
||||
if inst.op == SOPPOp.S_DELAY_ALU:
|
||||
deps, skips = ['VALU_DEP_1','VALU_DEP_2','VALU_DEP_3','VALU_DEP_4','TRANS32_DEP_1','TRANS32_DEP_2','TRANS32_DEP_3','FMA_ACCUM_CYCLE_1','SALU_CYCLE_1','SALU_CYCLE_2','SALU_CYCLE_3'], ['SAME','NEXT','SKIP_1','SKIP_2','SKIP_3','SKIP_4']
|
||||
id0, skip, id1 = inst.simm16 & 0xf, (inst.simm16 >> 4) & 0x7, (inst.simm16 >> 7) & 0xf
|
||||
dep = lambda v: deps[v-1] if 0 < v <= len(deps) else str(v)
|
||||
@@ -157,25 +145,20 @@ def _disasm_sopp(inst: SOPP) -> str:
|
||||
return f"{name} {inst.simm16}" if name.startswith(('s_cbranch', 's_branch')) else f"{name} 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_smem(inst: SMEM) -> str:
|
||||
op = SMEMOp(inst.op)
|
||||
name = op.name.lower()
|
||||
if op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (SMEMOp.S_GL1_INV, SMEMOp.S_DCACHE_INV): return name
|
||||
off_s = f"{decode_src(inst.soffset)} offset:0x{inst.offset:x}" if inst.offset and inst.soffset != 124 else f"0x{inst.offset:x}" if inst.offset else decode_src(inst.soffset)
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
sbase_idx, sbase_count = inst.sbase * 2, 4 if (8 <= inst.op.value <= 12 or name == 's_atc_probe_buffer') else 2
|
||||
sbase_str = _fmt_src(sbase_idx, sbase_count) if sbase_count == 2 else _sreg(sbase_idx, sbase_count) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_count)
|
||||
if name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{name} {inst.sdata}, {sbase_str}, {off_s}"
|
||||
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(inst.op, 1)
|
||||
return f"{name} {_fmt_sdst(inst.sdata, width)}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
return f"{name} {_fmt_sdst(inst.sdata, inst.dst_regs())}, {sbase_str}, {off_s}" + _mods((inst.glc, " glc"), (inst.dlc, " dlc"))
|
||||
|
||||
def _disasm_flat(inst: FLAT) -> str:
|
||||
name = FLATOp(inst.op).name.lower()
|
||||
name = inst.op_name.lower()
|
||||
seg = ['flat', 'scratch', 'global'][inst.seg] if inst.seg < 3 else 'flat'
|
||||
instr = f"{seg}_{name.split('_', 1)[1] if '_' in name else name}"
|
||||
off_val = inst.offset if seg == 'flat' else (inst.offset if inst.offset < 4096 else inst.offset - 8192)
|
||||
suffix = name.split('_')[-1]
|
||||
w = {'b32':1,'b64':2,'b96':3,'b128':4,'u8':1,'i8':1,'u16':1,'i16':1,'u32':1,'i32':1,'u64':2,'i64':2,'f32':1,'f64':2}.get(suffix, 1)
|
||||
if 'cmpswap' in name: w *= 2
|
||||
if name.endswith('_x2') or 'x2' in suffix: w = max(w, 2)
|
||||
w = inst.dst_regs() * (2 if 'cmpswap' in name else 1)
|
||||
mods = f"{f' offset:{off_val}' if off_val else ''}{' glc' if inst.glc else ''}{' slc' if inst.slc else ''}{' dlc' if inst.dlc else ''}"
|
||||
# saddr
|
||||
if seg == 'flat' or inst.saddr == 0x7F: saddr_s = ""
|
||||
@@ -195,11 +178,11 @@ def _disasm_flat(inst: FLAT) -> str:
|
||||
return f"{instr} {_vreg(inst.vdst, w)}, {addr_s}{saddr_s}{mods}"
|
||||
|
||||
def _disasm_ds(inst: DS) -> str:
|
||||
op, name = DSOp(inst.op), DSOp(inst.op).name.lower()
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
gds = " gds" if inst.gds else ""
|
||||
off = f" offset:{inst.offset0 | (inst.offset1 << 8)}" if inst.offset0 or inst.offset1 else ""
|
||||
off2 = f" offset0:{inst.offset0} offset1:{inst.offset1}" if inst.offset0 or inst.offset1 else ""
|
||||
w = 4 if '128' in name else 3 if '96' in name else 2 if (name.endswith('64') or 'gs_reg' in name) else 1
|
||||
w = inst.dst_regs()
|
||||
d0, d1, dst, addr = _vreg(inst.data0, w), _vreg(inst.data1, w), _vreg(inst.vdst, w), f"v{inst.addr}"
|
||||
|
||||
if op == DSOp.DS_NOP: return name
|
||||
@@ -223,49 +206,34 @@ def _disasm_ds(inst: DS) -> str:
|
||||
return f"{name} {dst}, {addr}, {d0}{off}{gds}" if '_rtn' in name else f"{name} {addr}, {d0}{off}{gds}"
|
||||
|
||||
def _disasm_vop3(inst: VOP3) -> str:
|
||||
op = VOP3SDOp(inst.op) if inst.op in VOP3SD_OPS else VOP3Op(inst.op)
|
||||
name = op.name.lower()
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
|
||||
# VOP3SD (shared encoding)
|
||||
if inst.op in VOP3SD_OPS:
|
||||
if isinstance(op, VOP3SDOp):
|
||||
sdst = (inst.clmp << 7) | (inst.opsel << 3) | inst.abs
|
||||
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32')
|
||||
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
|
||||
dst = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}"
|
||||
if op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}"
|
||||
if op in (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32): return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}"
|
||||
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {s0}, {s1}, {s2}" + _omod(inst.omod)
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
return f"{name} {dst}, {_fmt_sdst(sdst, 1)}, {srcs}" + _omod(inst.omod)
|
||||
|
||||
# Detect operand sizes
|
||||
is64 = _is64(name)
|
||||
is64_src, is64_dst = False, False
|
||||
# Detect 16-bit operand sizes (for .h/.l suffix handling)
|
||||
is16_d = is16_s = is16_s2 = False
|
||||
if 'cvt_pk' in name: is16_s = name.endswith('16')
|
||||
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', name):
|
||||
is16_d, is16_s = _has(m.group(1), 'f16','i16','u16','b16'), _has(m.group(2), 'f16','i16','u16','b16')
|
||||
is64_src, is64_dst = '64' in m.group(2), '64' in m.group(1)
|
||||
is16_s2, is64 = is16_s, False
|
||||
is16_s2 = is16_s
|
||||
elif re.match(r'v_mad_[iu]32_[iu]16', name): is16_s = True
|
||||
elif 'pack_b32' in name: is16_s = is16_s2 = True
|
||||
else: is16_d = is16_s = is16_s2 = _is16(name) and not _has(name, 'dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad')
|
||||
|
||||
# Source counts
|
||||
shift64 = 'rev' in name and '64' in name and name.startswith('v_')
|
||||
ldexp64 = op == VOP3Op.V_LDEXP_F64
|
||||
trig = op == VOP3Op.V_TRIG_PREOP_F64
|
||||
sad64, mqsad = _has(name, 'qsad_pk', 'mqsad_pk'), 'mqsad_u32' in name
|
||||
s0n = 2 if ((is64 and not shift64) or sad64 or mqsad or is64_src) else 1
|
||||
s1n = 2 if (is64 and not _has(name, 'class') and not ldexp64 and not trig) else 1
|
||||
s2n = 4 if mqsad else 2 if (is64 or sad64) else 1
|
||||
else: is16_d = is16_s = is16_s2 = inst.is_16bit()
|
||||
|
||||
any_hi = inst.opsel != 0
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, s0n, is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, s1n, is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, s2n, is16_s2, any_hi)
|
||||
s0 = _vop3_src(inst, inst.src0, inst.neg&1, inst.abs&1, inst.opsel&1, inst.src_regs(0), is16_s, any_hi)
|
||||
s1 = _vop3_src(inst, inst.src1, inst.neg&2, inst.abs&2, inst.opsel&2, inst.src_regs(1), is16_s, any_hi)
|
||||
s2 = _vop3_src(inst, inst.src2, inst.neg&4, inst.abs&4, inst.opsel&4, inst.src_regs(2), is16_s2, any_hi)
|
||||
|
||||
# Destination
|
||||
dn = 4 if mqsad else 2 if (is64 or sad64 or is64_dst) else 1
|
||||
dn = inst.dst_regs()
|
||||
if op == VOP3Op.V_READLANE_B32: dst = _fmt_sdst(inst.vdst, 1)
|
||||
elif dn > 1: dst = _vreg(inst.vdst, dn)
|
||||
elif is16_d: dst = f"v{inst.vdst}.h" if (inst.opsel & 8) else f"v{inst.vdst}.l" if any_hi else f"v{inst.vdst}"
|
||||
@@ -278,24 +246,24 @@ def _disasm_vop3(inst: VOP3) -> str:
|
||||
if inst.op < 256: # VOPC
|
||||
return f"{name}_e64 {s0}, {s1}" if name.startswith('v_cmpx') else f"{name}_e64 {_fmt_sdst(inst.vdst, 1)}, {s0}, {s1}"
|
||||
if inst.op < 384: # VOP2
|
||||
os = _opsel_str(inst.opsel, 3, need_opsel, is16_d) if 'cndmask' in name else _opsel_str(inst.opsel, 2, need_opsel, is16_d)
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if 'cndmask' in name else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
return f"{name}_e64 {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name}_e64 {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
if inst.op < 512: # VOP1
|
||||
return f"{name}_e64" if op in (VOP3Op.V_NOP, VOP3Op.V_PIPEFLUSH) else f"{name}_e64 {dst}, {s0}{_opsel_str(inst.opsel, 1, need_opsel, is16_d)}{cl}{om}"
|
||||
# Native VOP3
|
||||
is3 = _has(name, 'fma', 'mad', 'min3', 'max3', 'med3', 'div_fix', 'div_fmas', 'sad', 'lerp', 'align', 'cube', 'bfe', 'bfi',
|
||||
'perm_b32', 'permlane', 'cndmask', 'xor3', 'or3', 'add3', 'lshl_or', 'and_or', 'lshl_add', 'add_lshl', 'xad', 'maxmin', 'minmax', 'dot2', 'cvt_pk_u8', 'mullit')
|
||||
os = _opsel_str(inst.opsel, 3 if is3 else 2, need_opsel, is16_d)
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if is3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
n = inst.num_srcs()
|
||||
os = _opsel_str(inst.opsel, n, need_opsel, is16_d)
|
||||
return f"{name} {dst}, {s0}, {s1}, {s2}{os}{cl}{om}" if n == 3 else f"{name} {dst}, {s0}, {s1}{os}{cl}{om}"
|
||||
|
||||
def _disasm_vop3sd(inst: VOP3SD) -> str:
|
||||
op, name = VOP3SDOp(inst.op), VOP3SDOp(inst.op).name.lower()
|
||||
is64, mad64 = 'f64' in name, _has(name, 'mad_i64_i32', 'mad_u64_u32')
|
||||
def src(v, neg, ext=False): s = _fmt_src(v, 2) if ext or is64 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1), src(inst.src1, inst.neg & 2), src(inst.src2, inst.neg & 4, mad64)
|
||||
dst, is2src = _vreg(inst.vdst, 2) if is64 or mad64 else f"v{inst.vdst}", op in (VOP3SDOp.V_ADD_CO_U32, VOP3SDOp.V_SUB_CO_U32, VOP3SDOp.V_SUBREV_CO_U32)
|
||||
name = inst.op_name.lower()
|
||||
def src(v, neg, n): s = _fmt_src(v, n) if n > 1 else inst.lit(v); return f"-{s}" if neg else s
|
||||
s0, s1, s2 = src(inst.src0, inst.neg & 1, inst.src_regs(0)), src(inst.src1, inst.neg & 2, inst.src_regs(1)), src(inst.src2, inst.neg & 4, inst.src_regs(2))
|
||||
dst = _vreg(inst.vdst, inst.dst_regs()) if inst.dst_regs() > 1 else f"v{inst.vdst}"
|
||||
srcs = f"{s0}, {s1}, {s2}" if inst.num_srcs() == 3 else f"{s0}, {s1}"
|
||||
suffix = "_e64" if name.startswith('v_') and 'co_' in name else ""
|
||||
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {s0}, {s1}{'' if is2src else f', {s2}'}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
|
||||
return f"{name}{suffix} {dst}, {_fmt_sdst(inst.sdst, 1)}, {srcs}{' clamp' if inst.clmp else ''}{_omod(inst.omod)}"
|
||||
|
||||
def _disasm_vopd(inst: VOPD) -> str:
|
||||
lit = inst._literal or inst.literal
|
||||
@@ -304,26 +272,25 @@ def _disasm_vopd(inst: VOPD) -> str:
|
||||
return f"{half(nx, inst.vdstx, inst.srcx0, inst.vsrcx1)} :: {half(ny, vdst_y, inst.srcy0, inst.vsrcy1)}"
|
||||
|
||||
def _disasm_vop3p(inst: VOP3P) -> str:
|
||||
name = VOP3POp(inst.op).name.lower()
|
||||
is_wmma, is_3src, is_fma_mix = 'wmma' in name, _has(name, 'fma', 'mad', 'dot', 'wmma'), 'fma_mix' in name
|
||||
name = inst.op_name.lower()
|
||||
is_wmma, n, is_fma_mix = 'wmma' in name, inst.num_srcs(), 'fma_mix' in name
|
||||
if is_wmma:
|
||||
sc = 2 if 'iu4' in name else 4 if 'iu8' in name else 8
|
||||
src0, src1, src2, dst = _fmt_src(inst.src0, sc), _fmt_src(inst.src1, sc), _fmt_src(inst.src2, 8), _vreg(inst.vdst, 8)
|
||||
else: src0, src1, src2, dst = _fmt_src(inst.src0, 1), _fmt_src(inst.src1, 1), _fmt_src(inst.src2, 1), f"v{inst.vdst}"
|
||||
n, opsel_hi = 3 if is_3src else 2, inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
opsel_hi = inst.opsel_hi | (inst.opsel_hi2 << 2)
|
||||
if is_fma_mix:
|
||||
def m(s, neg, abs_): return f"-{f'|{s}|' if abs_ else s}" if neg else (f"|{s}|" if abs_ else s)
|
||||
src0, src1, src2 = m(src0, inst.neg & 1, inst.neg_hi & 1), m(src1, inst.neg & 2, inst.neg_hi & 2), m(src2, inst.neg & 4, inst.neg_hi & 4)
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
else:
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if is_3src else 3) else []) + \
|
||||
mods = ([_fmt_bits("op_sel", inst.opsel, n)] if inst.opsel else []) + ([_fmt_bits("op_sel_hi", opsel_hi, n)] if opsel_hi != (7 if n == 3 else 3) else []) + \
|
||||
([_fmt_bits("neg_lo", inst.neg, n)] if inst.neg else []) + ([_fmt_bits("neg_hi", inst.neg_hi, n)] if inst.neg_hi else []) + (["clamp"] if inst.clmp else [])
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if is_3src else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
return f"{name} {dst}, {src0}, {src1}, {src2}{' ' + ' '.join(mods) if mods else ''}" if n == 3 else f"{name} {dst}, {src0}, {src1}{' ' + ' '.join(mods) if mods else ''}"
|
||||
|
||||
def _disasm_buf(inst: MUBUF | MTBUF) -> str:
|
||||
op = MTBUFOp(inst.op) if isinstance(inst, MTBUF) else MUBUFOp(inst.op)
|
||||
name = op.name.lower()
|
||||
if op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
name = inst.op_name.lower()
|
||||
if inst.op in (MUBUFOp.BUFFER_GL0_INV, MUBUFOp.BUFFER_GL1_INV): return name
|
||||
w = (2 if _has(name, 'xyz', 'xyzw') else 1) if 'd16' in name else \
|
||||
((2 if _has(name, 'b64', 'u64', 'i64') else 1) * (2 if 'cmpswap' in name else 1)) if 'atomic' in name else \
|
||||
{'b32':1,'b64':2,'b96':3,'b128':4,'b16':1,'x':1,'xy':2,'xyz':3,'xyzw':4}.get(name.split('_')[-1], 1)
|
||||
@@ -351,7 +318,7 @@ def _mimg_vaddr_width(name: str, dim: int, a16: bool) -> int:
|
||||
return (base + packed + 1) // 2 + unpacked if a16 else base + packed + unpacked
|
||||
|
||||
def _disasm_mimg(inst: MIMG) -> str:
|
||||
name = MIMGOp(inst.op).name.lower()
|
||||
name = inst.op_name.lower()
|
||||
srsrc_base = inst.srsrc * 4
|
||||
srsrc_str = _sreg_or_ttmp(srsrc_base, 8)
|
||||
# BVH intersect ray: special case with 4 SGPR srsrc
|
||||
@@ -379,66 +346,38 @@ def _disasm_mimg(inst: MIMG) -> str:
|
||||
ssamp_str = ", " + _sreg_or_ttmp(inst.ssamp * 4, 4)
|
||||
return f"{name} {_vreg(inst.vdata, vdata)}, {vaddr_str}, {srsrc_str}{ssamp_str} {' '.join(mods)}"
|
||||
|
||||
def _sop_widths(name: str) -> tuple[int, int, int]:
|
||||
"""Return (dst_width, src0_width, src1_width) in register count for SOP instructions."""
|
||||
if name in ('s_bitset0_b64', 's_bitset1_b64', 's_bfm_b64'): return 2, 1, 1
|
||||
if name in ('s_lshl_b64', 's_lshr_b64', 's_ashr_i64', 's_bfe_u64', 's_bfe_i64'): return 2, 2, 1
|
||||
if name in ('s_bitcmp0_b64', 's_bitcmp1_b64'): return 1, 2, 1
|
||||
if m := re.search(r'_(b|i|u)(32|64)_(b|i|u)(32|64)$', name): return 2 if m.group(2) == '64' else 1, 2 if m.group(4) == '64' else 1, 1
|
||||
if m := re.search(r'_(b|i|u)(32|64)$', name): sz = 2 if m.group(2) == '64' else 1; return sz, sz, sz
|
||||
return 1, 1, 1
|
||||
|
||||
def _disasm_sop1(inst: SOP1) -> str:
|
||||
op, name = SOP1Op(inst.op), SOP1Op(inst.op).name.lower()
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOP1Op.S_GETPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}"
|
||||
if op in (SOP1Op.S_SETPC_B64, SOP1Op.S_RFE_B64): return f"{name} {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op == SOP1Op.S_SWAPPC_B64: return f"{name} {_fmt_sdst(inst.sdst, 2)}, {_fmt_src(inst.ssrc0, 2)}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, 2 if 'b64' in name else 1)}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
dn, s0n, _ = _sop_widths(name)
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if s0n == 1 else _fmt_src(inst.ssrc0, s0n)}"
|
||||
if op in (SOP1Op.S_SENDMSG_RTN_B32, SOP1Op.S_SENDMSG_RTN_B64): return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, sendmsg({MSG.get(inst.ssrc0, str(inst.ssrc0))})"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.src_regs(0) == 1 else _fmt_src(inst.ssrc0, inst.src_regs(0))}"
|
||||
|
||||
def _disasm_sop2(inst: SOP2) -> str:
|
||||
name = SOP2Op(inst.op).name.lower()
|
||||
dn, s0n, s1n = _sop_widths(name)
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dn)}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, s0n)}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, s1n)}"
|
||||
return f"{inst.op_name.lower()} {_fmt_sdst(inst.sdst, inst.dst_regs())}, {inst.lit(inst.ssrc0) if inst.ssrc0 == 255 else _fmt_src(inst.ssrc0, inst.src_regs(0))}, {inst.lit(inst.ssrc1) if inst.ssrc1 == 255 else _fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
|
||||
def _disasm_sopc(inst: SOPC) -> str:
|
||||
name = SOPCOp(inst.op).name.lower()
|
||||
_, s0n, s1n = _sop_widths(name)
|
||||
return f"{name} {_fmt_src(inst.ssrc0, s0n)}, {_fmt_src(inst.ssrc1, s1n)}"
|
||||
return f"{inst.op_name.lower()} {_fmt_src(inst.ssrc0, inst.src_regs(0))}, {_fmt_src(inst.ssrc1, inst.src_regs(1))}"
|
||||
|
||||
def _disasm_sopk(inst: SOPK) -> str:
|
||||
op, name = SOPKOp(inst.op), SOPKOp(inst.op).name.lower()
|
||||
op, name = inst.op, inst.op_name.lower()
|
||||
if op == SOPKOp.S_VERSION: return f"{name} 0x{inst.simm16:x}"
|
||||
if op in (SOPKOp.S_SETREG_B32, SOPKOp.S_GETREG_B32):
|
||||
hid, hoff, hsz = inst.simm16 & 0x3f, (inst.simm16 >> 6) & 0x1f, ((inst.simm16 >> 11) & 0x1f) + 1
|
||||
hs = f"0x{inst.simm16:x}" if hid in (16, 17) else f"hwreg({HWREG.get(hid, str(hid))}, {hoff}, {hsz})"
|
||||
return f"{name} {hs}, {_fmt_sdst(inst.sdst, 1)}" if op == SOPKOp.S_SETREG_B32 else f"{name} {_fmt_sdst(inst.sdst, 1)}, {hs}"
|
||||
dn, _, _ = _sop_widths(name)
|
||||
return f"{name} {_fmt_sdst(inst.sdst, dn)}, 0x{inst.simm16:x}"
|
||||
return f"{name} {_fmt_sdst(inst.sdst, inst.dst_regs())}, 0x{inst.simm16:x}"
|
||||
|
||||
def _disasm_vinterp(inst: VINTERP) -> str:
|
||||
name = VINTERPOp(inst.op).name.lower()
|
||||
src0 = f"-{inst.lit(inst.src0)}" if inst.neg & 1 else inst.lit(inst.src0)
|
||||
src1 = f"-{inst.lit(inst.src1)}" if inst.neg & 2 else inst.lit(inst.src1)
|
||||
src2 = f"-{inst.lit(inst.src2)}" if inst.neg & 4 else inst.lit(inst.src2)
|
||||
mods = _mods((inst.waitexp, f"wait_exp:{inst.waitexp}"), (inst.clmp, "clamp"))
|
||||
return f"{name} v{inst.vdst}, {src0}, {src1}, {src2}" + (" " + mods if mods else "")
|
||||
|
||||
def _disasm_generic(inst: Inst) -> str:
|
||||
name = f"op_{inst.op}"
|
||||
def format_field(field_name, val):
|
||||
val = unwrap(val)
|
||||
if field_name in SRC_FIELDS: return inst.lit(val) if val != 255 else "0xff"
|
||||
return f"{'s' if field_name == 'sdst' else 'v'}{val}" if field_name in ('sdst', 'vdst') else f"v{val}" if field_name == 'vsrc1' else f"0x{val:x}" if field_name == 'simm16' else str(val)
|
||||
operands = [format_field(field_name, inst._values.get(field_name, 0)) for field_name in inst._fields if field_name not in ('encoding', 'op')]
|
||||
return f"{name} {', '.join(operands)}" if operands else name
|
||||
return f"{inst.op_name.lower()} v{inst.vdst}, {inst.lit(inst.src0, inst.neg & 1)}, {inst.lit(inst.src1, inst.neg & 2)}, {inst.lit(inst.src2, inst.neg & 4)}" + (" " + mods if mods else "")
|
||||
|
||||
DISASM_HANDLERS = {VOP1: _disasm_vop1, VOP2: _disasm_vop2, VOPC: _disasm_vopc, VOP3: _disasm_vop3, VOP3SD: _disasm_vop3sd, VOPD: _disasm_vopd, VOP3P: _disasm_vop3p,
|
||||
VINTERP: _disasm_vinterp, SOPP: _disasm_sopp, SMEM: _disasm_smem, DS: _disasm_ds, FLAT: _disasm_flat, MUBUF: _disasm_buf, MTBUF: _disasm_buf,
|
||||
MIMG: _disasm_mimg, SOP1: _disasm_sop1, SOP2: _disasm_sop2, SOPC: _disasm_sopc, SOPK: _disasm_sopk}
|
||||
|
||||
def disasm(inst: Inst) -> str: return DISASM_HANDLERS.get(type(inst), _disasm_generic)(inst)
|
||||
def disasm(inst: Inst) -> str: return DISASM_HANDLERS[type(inst)](inst)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ASSEMBLER
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
import struct, math
|
||||
from enum import IntEnum
|
||||
from typing import overload, Annotated, TypeVar, Generic
|
||||
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
|
||||
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
|
||||
|
||||
# Common masks and bit conversion functions
|
||||
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
|
||||
@@ -28,6 +30,79 @@ def _i64(f):
|
||||
try: return struct.unpack("<Q", struct.pack("<d", f))[0]
|
||||
except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
|
||||
|
||||
# Instruction spec - register counts and dtypes derived from instruction names
|
||||
import re
|
||||
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
|
||||
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
|
||||
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
|
||||
def _suffix(name: str) -> tuple[str | None, str | None]:
|
||||
name = name.upper()
|
||||
if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
|
||||
if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2)
|
||||
if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
|
||||
# Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc.
|
||||
if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
|
||||
if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1)
|
||||
return None, None
|
||||
_SPECIAL_REGS = {
|
||||
'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1),
|
||||
'S_LSHL_B64': (2, 2, 1, 1), 'S_LSHR_B64': (2, 2, 1, 1), 'S_ASHR_I64': (2, 2, 1, 1),
|
||||
'S_BFE_U64': (2, 2, 1, 1), 'S_BFE_I64': (2, 2, 1, 1), 'S_BFM_B64': (2, 1, 1, 1),
|
||||
'S_BITSET0_B64': (2, 1, 1, 1), 'S_BITSET1_B64': (2, 1, 1, 1),
|
||||
'S_BITCMP0_B64': (1, 2, 1, 1), 'S_BITCMP1_B64': (1, 2, 1, 1),
|
||||
'V_LDEXP_F64': (2, 2, 1, 1), 'V_TRIG_PREOP_F64': (2, 2, 1, 1),
|
||||
'V_CMP_CLASS_F64': (1, 2, 1, 1), 'V_CMPX_CLASS_F64': (1, 2, 1, 1),
|
||||
'V_CMP_CLASS_F32': (1, 1, 1, 1), 'V_CMPX_CLASS_F32': (1, 1, 1, 1),
|
||||
'V_CMP_CLASS_F16': (1, 1, 1, 1), 'V_CMPX_CLASS_F16': (1, 1, 1, 1),
|
||||
'V_MAD_U64_U32': (2, 1, 1, 2), 'V_MAD_I64_I32': (2, 1, 1, 2),
|
||||
'V_QSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_PK_U16_U8': (2, 2, 1, 2), 'V_MQSAD_U32_U8': (4, 2, 1, 4),
|
||||
}
|
||||
_SPECIAL_DTYPE = {
|
||||
'V_LSHLREV_B64': ('B64', 'U32', 'B64', None), 'V_LSHRREV_B64': ('B64', 'U32', 'B64', None), 'V_ASHRREV_I64': ('I64', 'U32', 'I64', None),
|
||||
'S_LSHL_B64': ('B64', 'B64', 'U32', None), 'S_LSHR_B64': ('B64', 'B64', 'U32', None), 'S_ASHR_I64': ('I64', 'I64', 'U32', None),
|
||||
'S_BFE_U64': ('U64', 'U64', 'U32', None), 'S_BFE_I64': ('I64', 'I64', 'U32', None),
|
||||
'S_BFM_B64': ('B64', 'U32', 'U32', None), 'S_BITSET0_B64': ('B64', 'U32', None, None), 'S_BITSET1_B64': ('B64', 'U32', None, None),
|
||||
'S_BITCMP0_B64': ('SCC', 'B64', 'U32', None), 'S_BITCMP1_B64': ('SCC', 'B64', 'U32', None),
|
||||
'V_LDEXP_F64': ('F64', 'F64', 'I32', None), 'V_TRIG_PREOP_F64': ('F64', 'F64', 'U32', None),
|
||||
'V_CMP_CLASS_F64': ('VCC', 'F64', 'U32', None), 'V_CMPX_CLASS_F64': ('EXEC', 'F64', 'U32', None),
|
||||
'V_CMP_CLASS_F32': ('VCC', 'F32', 'U32', None), 'V_CMPX_CLASS_F32': ('EXEC', 'F32', 'U32', None),
|
||||
'V_CMP_CLASS_F16': ('VCC', 'F16', 'U32', None), 'V_CMPX_CLASS_F16': ('EXEC', 'F16', 'U32', None),
|
||||
'V_MAD_U64_U32': ('U64', 'U32', 'U32', 'U64'), 'V_MAD_I64_I32': ('I64', 'I32', 'I32', 'I64'),
|
||||
'V_QSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'), 'V_MQSAD_PK_U16_U8': ('B64', 'B64', 'B64', 'B64'),
|
||||
'V_MQSAD_U32_U8': ('B128', 'B64', 'B64', 'B128'),
|
||||
}
|
||||
def spec_regs(name: str) -> tuple[int, int, int, int]:
  """Register counts as (dst, src0, src1, src2) for op `name` (case-insensitive)."""
  key = name.upper()
  # exact per-op overrides take precedence over suffix-derived counts
  if key in _SPECIAL_REGS: return _SPECIAL_REGS[key]
  # plain SAD/MSAD U8 ops use single 32-bit registers everywhere (QSAD/MQSAD are in the table)
  if 'SAD' in key and 'U8' in key and not ('QSAD' in key or 'MQSAD' in key): return 1, 1, 1, 1
  dst_suf, src_suf = _suffix(key)
  src_cnt = _REGS.get(src_suf, 1)
  return _REGS.get(dst_suf, 1), src_cnt, src_cnt, src_cnt
|
||||
def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]:
  """Operand dtypes as (dst, src0, src1, src2) for op `name` (case-insensitive).

  None means the operand does not exist; compares report 'VCC'/'EXEC' as the dst.
  """
  name = name.upper()
  if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name]
  # plain SAD ops operate on packed 32-bit values (QSAD/MQSAD variants are in the table)
  if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32'
  # single _suffix call (previously computed twice on the compare path, with dst_suf unused there)
  dst_suf, src_suf = _suffix(name)
  if '_CMP_' in name or '_CMPX_' in name:
    # compares write a lane mask: EXEC for the CMPX variants, VCC otherwise
    return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None
  return dst_suf, src_suf, src_suf, src_suf
|
||||
def spec_is_16bit(name: str) -> bool:
  """True when op `name` is a genuine 16-bit op (has an _F16/_I16/_U16/_B16 name component)."""
  key = name.upper()
  # SAD / pack / dot-product ops carry "16" in the name but work on 32-bit packed data
  for tag in ('SAD', 'PACK', '_PK_', 'SAT_PK', 'DOT2'):
    if tag in key: return False
  # mixed-width ops like V_DOT2ACC_F32_F16 are not 16-bit ops
  for wide in ('_F32', '_I32', '_U32', '_B32'):
    if wide in key: return False
  return re.search(r'_[FIUB]16(?:_|$)', key) is not None
|
||||
def spec_is_64bit(name: str) -> bool:
  """True when op `name` has a 64-bit dtype component (_F64/_I64/_U64/_B64)."""
  return re.search(r'_[FIUB]64(?:_|$)', name.upper()) is not None
|
||||
# Substrings identifying 3-source ops; matched against the upper-cased op name.
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
         'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
         'MINMAX', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
_2SRC = {'FMAC'} # FMAC uses dst as implicit accumulator, so only 2 explicit sources
def spec_num_srcs(name: str) -> int:
  """Number of explicit source operands (2 or 3) for op `name` (case-insensitive)."""
  key = name.upper()
  # 'FMAC' also contains 'FMA', so the 2-source exceptions must be checked first
  if any(tag in key for tag in _2SRC): return 2
  if any(tag in key for tag in _3SRC): return 3
  return 2
|
||||
def is_dtype_16(dt: str | None) -> bool: return dt is not None and '16' in dt
|
||||
def is_dtype_64(dt: str | None) -> bool: return dt is not None and '64' in dt
|
||||
|
||||
# Bit field DSL
|
||||
class BitField:
|
||||
def __init__(self, hi: int, lo: int, name: str | None = None):
  """Bit field spanning bits [hi:lo] of an encoded instruction word; `_marker` starts unset."""
  self.hi, self.lo, self.name, self._marker = hi, lo, name, None
|
||||
@@ -54,6 +129,10 @@ class BitField:
|
||||
val = unwrap(obj._values.get(self.name, 0))
|
||||
# Convert to IntEnum if marker is an IntEnum subclass
|
||||
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
|
||||
# VOP3 with VOPC opcodes (0-255) -> VOPCOp, VOP3SD opcodes -> VOP3SDOp
|
||||
if self.marker is VOP3Op:
|
||||
if val < 256: return VOPCOp(val)
|
||||
if val in Inst._VOP3SD_OPS: return VOP3SDOp(val)
|
||||
try: return self.marker(val)
|
||||
except ValueError: pass
|
||||
return val
|
||||
@@ -201,10 +280,13 @@ class Inst:
|
||||
# Track literal value if needed
|
||||
if encoded == 255 and self._literal is None:
|
||||
import struct
|
||||
is_64 = self._is_64bit_op()
|
||||
# Check if THIS source uses 64-bit encoding (not just src0)
|
||||
src_idx = {'src0': 0, 'src1': 1, 'src2': 2, 'ssrc0': 0, 'ssrc1': 1}.get(name, 0)
|
||||
src_regs = self.src_regs(src_idx)
|
||||
is_64 = src_regs == 2
|
||||
if isinstance(val, SrcMod) and not isinstance(val, Reg): lit32 = val.val & MASK32
|
||||
elif isinstance(val, int) and not isinstance(val, IntEnum): lit32 = val & MASK32
|
||||
elif isinstance(val, float): lit32 = _i32(val)
|
||||
elif isinstance(val, float): lit32 = (_i64(val) >> 32) if is_64 else _i32(val) # f64: high 32 bits of f64 repr
|
||||
else: return
|
||||
self._literal = (lit32 << 32) if is_64 else lit32
|
||||
|
||||
@@ -235,11 +317,21 @@ class Inst:
|
||||
raise ValueError(f"SOP1 {orig_args['op'].name} expects {expected} register(s) for {fld}, got {orig_args[fld].count}")
|
||||
|
||||
def __init__(self, *args, literal: int | None = None, **kwargs):
|
||||
self._values, self._literal = dict(self._defaults), literal
|
||||
self._values, self._literal = dict(self._defaults), None
|
||||
field_names = [n for n in self._fields if n != 'encoding']
|
||||
orig_args = dict(zip(field_names, args)) | kwargs
|
||||
self._values.update(orig_args)
|
||||
self._validate(orig_args)
|
||||
# Pre-shift literal for 64-bit sources (literal param is always raw 32-bit value from user)
|
||||
if literal is not None:
|
||||
# Find which source uses the literal (255) and check its register count
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
|
||||
v = orig_args.get(n)
|
||||
if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255):
|
||||
self._literal = (literal << 32) if self.src_regs(idx) == 2 else literal
|
||||
break
|
||||
else:
|
||||
self._literal = literal # fallback if no literal source found
|
||||
cls_name = self.__class__.__name__
|
||||
|
||||
# Format-specific setup
|
||||
@@ -297,11 +389,9 @@ class Inst:
|
||||
op_name = op.name if hasattr(op, 'name') else None
|
||||
# Look up op name from int if needed (happens in from_bytes path)
|
||||
if op_name is None and self.__class__.__name__ == 'VOP3':
|
||||
from extra.assembly.amd.autogen.rdna3.ins import VOP3Op
|
||||
try: op_name = VOP3Op(op).name
|
||||
except ValueError: pass
|
||||
if op_name is None and self.__class__.__name__ == 'VOPC':
|
||||
from extra.assembly.amd.autogen.rdna3.ins import VOPCOp
|
||||
try: op_name = VOPCOp(op).name
|
||||
except ValueError: pass
|
||||
if op_name is None: return False
|
||||
@@ -312,8 +402,16 @@ class Inst:
|
||||
result = self.to_int().to_bytes(self._size(), 'little')
|
||||
lit = self._get_literal() or getattr(self, '_literal', None)
|
||||
if lit is None: return result
|
||||
# For 64-bit ops, literal is stored in high 32 bits internally, but encoded as 4 bytes
|
||||
lit32 = (lit >> 32) if self._is_64bit_op() else lit
|
||||
# For 64-bit sources, literal is stored in high 32 bits internally, but encoded as 4 bytes
|
||||
# Find which source uses the literal (255) and check its register count
|
||||
lit_src_is_64 = False
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2), ('ssrc0', 0), ('ssrc1', 1)]:
|
||||
if n not in self._values: continue
|
||||
v = self._values[n]
|
||||
if (isinstance(v, RawImm) and v.val == 255) or (isinstance(v, int) and v == 255):
|
||||
lit_src_is_64 = self.is_src_64(idx)
|
||||
break
|
||||
lit32 = (lit >> 32) if lit_src_is_64 else lit
|
||||
return result + (lit32 & MASK32).to_bytes(4, 'little')
|
||||
|
||||
@classmethod
|
||||
@@ -343,9 +441,16 @@ class Inst:
|
||||
if has_literal:
|
||||
# For 64-bit ops, the literal is 32 bits placed in the HIGH 32 bits of the 64-bit value
|
||||
# (low 32 bits are zero). This is how AMD hardware interprets 32-bit literals for 64-bit ops.
|
||||
# Check which source uses the literal and whether THAT source is 64-bit
|
||||
if len(data) >= cls._size() + 4:
|
||||
lit32 = int.from_bytes(data[cls._size():cls._size()+4], 'little')
|
||||
inst._literal = (lit32 << 32) if inst._is_64bit_op() else lit32
|
||||
# Find which source has literal (255) and check its register count
|
||||
lit_src_is_64 = False
|
||||
for n, idx in [('src0', 0), ('src1', 1), ('src2', 2)]:
|
||||
if n in inst._values and isinstance(inst._values[n], RawImm) and inst._values[n].val == 255:
|
||||
lit_src_is_64 = inst.src_regs(idx) == 2
|
||||
break
|
||||
inst._literal = (lit32 << 32) if lit_src_is_64 else lit32
|
||||
return inst
|
||||
|
||||
def __repr__(self):
|
||||
@@ -360,7 +465,9 @@ class Inst:
|
||||
if name.startswith('_'): raise AttributeError(name)
|
||||
return unwrap(self._values.get(name, 0))
|
||||
|
||||
def lit(self, v: int) -> str: return f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
|
||||
def lit(self, v: int, neg: bool = False) -> str:
|
||||
s = f"0x{self._literal:x}" if v == 255 and self._literal else decode_src(v)
|
||||
return f"-{s}" if neg else s
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Inst): return NotImplemented
|
||||
@@ -372,5 +479,37 @@ class Inst:
|
||||
from extra.assembly.amd.asm import disasm
|
||||
return disasm(self)
|
||||
|
||||
# Map from instruction-format class name to its opcode enum (used by the `op` property).
_enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp,
             'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp,
             'SMEM': SMEMOp, 'DS': DSOp, 'FLAT': FLATOp, 'MUBUF': MUBUFOp, 'MTBUF': MTBUFOp, 'MIMG': MIMGOp,
             'VOPD': VOPDOp, 'VINTERP': VINTERPOp}
# VOP3SD opcodes that share the VOP3 encoding (consulted by detect_format and the op-field decode)
_VOP3SD_OPS = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
||||
|
||||
@property
def op(self):
  """Return the op as an enum (e.g., VOP1Op.V_MOV_B32). VOP3 returns VOPCOp/VOP3SDOp for those op ranges."""
  val = self._values.get('op')
  if val is None: return None
  if hasattr(val, 'name'): return val # already an enum
  # raw int: convert via the per-format enum map keyed by this class's name
  cls_name = self.__class__.__name__
  assert cls_name in self._enum_map, f"no enum map for {cls_name}"
  return self._enum_map[cls_name](val)
|
||||
|
||||
@property
def op_name(self) -> str:
  """Name of this instruction's op enum, or '' when the op has no name."""
  return getattr(self.op, 'name', '')
|
||||
|
||||
# Operand-shape helpers: thin delegates keyed off the op name via the spec_* tables above.
def dst_regs(self) -> int: return spec_regs(self.op_name)[0]               # register count of the destination
def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1]   # register count of source n (0-based)
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)              # number of explicit sources (2 or 3)
def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0]      # dtype suffix of the destination
def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1]  # dtype suffix of source n
def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n))
def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2
def is_16bit(self) -> bool: return spec_is_16bit(self.op_name)
def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype())
||||
|
||||
class Inst32(Inst): pass  # marker subclass — presumably the 32-bit-wide encodings; confirm against the format tables
|
||||
class Inst64(Inst): pass  # marker subclass — presumably the 64-bit-wide encodings; confirm against the format tables
|
||||
|
||||
@@ -12,34 +12,6 @@ Program = dict[int, Inst]
|
||||
WAVE_SIZE, SGPR_COUNT, VGPR_COUNT = 32, 128, 256
|
||||
VCC_LO, VCC_HI, NULL, EXEC_LO, EXEC_HI, SCC = SrcEnum.VCC_LO, SrcEnum.VCC_HI, SrcEnum.NULL, SrcEnum.EXEC_LO, SrcEnum.EXEC_HI, SrcEnum.SCC
|
||||
|
||||
# Op classification helpers - build sets from op name patterns
|
||||
def _ops_matching(enum, *patterns, exclude=()): return {op for op in enum if any(p in op.name for p in patterns) and not any(e in op.name for e in exclude)}
|
||||
def _ops_ending(enum, *suffixes): return {op for op in enum if op.name.endswith(suffixes)}
|
||||
|
||||
# 64-bit ops (for literal handling)
|
||||
_VOP3_64BIT_OPS = {op.value for op in _ops_ending(VOP3Op, '_F64', '_B64', '_I64', '_U64')}
|
||||
_VOPC_64BIT_OPS = {op.value for op in _ops_ending(VOPCOp, '_F64', '_B64', '_I64', '_U64')}
|
||||
_VOP3_64BIT_OPS_32BIT_SRC1 = {VOP3Op.V_LDEXP_F64.value} # src1 is 32-bit exponent
|
||||
|
||||
# 16-bit ops (SAD/MSAD excluded - they use 32-bit packed sources)
|
||||
_VOP3_16BIT_OPS = _ops_matching(VOP3Op, '_F16', '_B16', '_I16', '_U16', exclude=('SAD',))
|
||||
_VOP1_16BIT_OPS = _ops_matching(VOP1Op, '_F16', '_B16', '_I16', '_U16')
|
||||
_VOP2_16BIT_OPS = _ops_matching(VOP2Op, '_F16', '_B16', '_I16', '_U16')
|
||||
_VOPC_16BIT_OPS = _ops_matching(VOPCOp, '_F16', '_B16', '_I16', '_U16')
|
||||
|
||||
# CVT ops with 32/64-bit source (despite 16-bit in name) - must end with the type suffix
|
||||
_CVT_32_64_SRC_OPS = {op for op in _ops_ending(VOP3Op, '_F32', '_I32', '_U32', '_F64', '_I64', '_U64') if op.name.startswith('V_CVT_')} | \
|
||||
{op for op in _ops_ending(VOP1Op, '_F32', '_I32', '_U32', '_F64', '_I64', '_U64') if op.name.startswith('V_CVT_')}
|
||||
# CVT ops with 32-bit destination (FROM 16-bit TO 32-bit) - match patterns like F32_F16 in name
|
||||
_CVT_32_DST_OPS = _ops_matching(VOP3Op, 'F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16') | \
|
||||
_ops_matching(VOP1Op, 'F32_F16', 'I32_I16', 'U32_U16', 'I32_F16', 'U32_F16')
|
||||
|
||||
# 16-bit dst ops (PACK has 32-bit dst, CVT to 32-bit has 32-bit dst)
|
||||
_VOP3_16BIT_DST_OPS = {op for op in _VOP3_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
|
||||
_VOP1_16BIT_DST_OPS = {op for op in _VOP1_16BIT_OPS if 'PACK' not in op.name} - _CVT_32_DST_OPS
|
||||
_VOP1_16BIT_SRC_OPS = _VOP1_16BIT_OPS - _CVT_32_64_SRC_OPS
|
||||
|
||||
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
|
||||
# Inline constants for src operands 128-254. Build tables for f32, f16, and f64 formats.
|
||||
_FLOAT_CONSTS = {v: k for k, v in FLOAT_ENC.items()} | {248: 0.15915494309189535} # INV_2PI
|
||||
def _build_inline_consts(mask, to_bits):
|
||||
@@ -134,7 +106,7 @@ class WaveState:
|
||||
def rsrc_f16(self, v: int, lane: int) -> int: return self._rsrc_base(v, lane, _INLINE_CONSTS_F16)
|
||||
def rsrc64(self, v: int, lane: int) -> int:
|
||||
if 128 <= v < 255: return _INLINE_CONSTS_F64[v - 128]
|
||||
if v == 255: return self.literal
|
||||
if v == 255: return self.literal # literal is already shifted in from_bytes for 64-bit ops
|
||||
return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
|
||||
|
||||
def pend_sgpr_lane(self, reg: int, lane: int, val: int):
|
||||
@@ -155,20 +127,8 @@ def decode_program(data: bytes) -> Program:
|
||||
base_size = inst_class._size()
|
||||
# Pass enough data for potential 64-bit literal (base + 8 bytes max)
|
||||
inst = inst_class.from_bytes(data[i:i+base_size+8])
|
||||
for name, val in inst._values.items(): setattr(inst, name, unwrap(val))
|
||||
# from_bytes already handles literal reading - only need fallback for cases it doesn't handle
|
||||
if inst._literal is None:
|
||||
has_literal = any(getattr(inst, fld, None) == 255 for fld in ('src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'srcx0', 'srcy0')) or \
|
||||
(inst_class == VOP2 and inst.op in (44, 45, 55, 56)) or \
|
||||
(inst_class == VOPD and (inst.opx in (1, 2) or inst.opy in (1, 2))) or \
|
||||
(inst_class == SOP2 and inst.op in (69, 70))
|
||||
if has_literal:
|
||||
# For 64-bit ops, the 32-bit literal is placed in HIGH 32 bits (low 32 bits = 0)
|
||||
op_val = getattr(inst._values.get('op'), 'value', inst._values.get('op'))
|
||||
is_64bit = ((inst_class is VOP3 and op_val in _VOP3_64BIT_OPS) or (inst_class is VOPC and op_val in _VOPC_64BIT_OPS)) and \
|
||||
not (op_val in _VOP3_64BIT_OPS_32BIT_SRC1 and getattr(inst, 'src1', None) == 255)
|
||||
lit32 = int.from_bytes(data[i+base_size:i+base_size+4], 'little')
|
||||
inst._literal = (lit32 << 32) if is_64bit else lit32
|
||||
for name, val in inst._values.items():
|
||||
if name != 'op': setattr(inst, name, unwrap(val)) # skip op to preserve property access
|
||||
inst._words = inst.size() // 4
|
||||
result[i // 4] = inst
|
||||
i += inst._words * 4
|
||||
@@ -181,16 +141,14 @@ def decode_program(data: bytes) -> Program:
|
||||
def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
"""Execute scalar instruction. Returns PC delta or negative for special cases."""
|
||||
compiled = _get_compiled()
|
||||
inst_type = type(inst)
|
||||
|
||||
# SOPP: special cases for control flow that has no pseudocode
|
||||
if inst_type is SOPP:
|
||||
op = inst.op
|
||||
if op == SOPPOp.S_ENDPGM: return -1
|
||||
if op == SOPPOp.S_BARRIER: return -2
|
||||
if isinstance(inst, SOPP):
|
||||
if inst.op == SOPPOp.S_ENDPGM: return -1
|
||||
if inst.op == SOPPOp.S_BARRIER: return -2
|
||||
|
||||
# SMEM: memory loads (not ALU)
|
||||
if inst_type is SMEM:
|
||||
if isinstance(inst, SMEM):
|
||||
addr = st.rsgpr64(inst.sbase * 2) + _sext(inst.offset, 21)
|
||||
if inst.soffset not in (NULL, 0x7f): addr += st.rsrc(inst.soffset, 0)
|
||||
if (cnt := SMEM_LOAD.get(inst.op)) is None: raise NotImplementedError(f"SMEM op {inst.op}")
|
||||
@@ -198,34 +156,30 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
return 0
|
||||
|
||||
# Get op enum and lookup compiled function
|
||||
if inst_type is SOP1: op_cls, ssrc0, sdst = SOP1Op, inst.ssrc0, inst.sdst
|
||||
elif inst_type is SOP2: op_cls, ssrc0, sdst = SOP2Op, inst.ssrc0, inst.sdst
|
||||
elif inst_type is SOPC: op_cls, ssrc0, sdst = SOPCOp, inst.ssrc0, None
|
||||
elif inst_type is SOPK: op_cls, ssrc0, sdst = SOPKOp, inst.sdst, inst.sdst # sdst is both src and dst
|
||||
elif inst_type is SOPP: op_cls, ssrc0, sdst = SOPPOp, None, None
|
||||
else: raise NotImplementedError(f"Unknown scalar type {inst_type}")
|
||||
if isinstance(inst, SOP1): ssrc0, sdst = inst.ssrc0, inst.sdst
|
||||
elif isinstance(inst, SOP2): ssrc0, sdst = inst.ssrc0, inst.sdst
|
||||
elif isinstance(inst, SOPC): ssrc0, sdst = inst.ssrc0, None
|
||||
elif isinstance(inst, SOPK): ssrc0, sdst = inst.sdst, inst.sdst # sdst is both src and dst
|
||||
elif isinstance(inst, SOPP): ssrc0, sdst = None, None
|
||||
else: raise NotImplementedError(f"Unknown scalar type {type(inst)}")
|
||||
|
||||
# SOPP has gaps in the opcode enum - treat unknown opcodes as no-ops
|
||||
try: op = op_cls(inst.op)
|
||||
try: op = inst.op
|
||||
except ValueError:
|
||||
if inst_type is SOPP: return 0
|
||||
if isinstance(inst, SOPP): return 0
|
||||
raise
|
||||
fn = compiled.get(op_cls, {}).get(op)
|
||||
fn = compiled.get(type(op), {}).get(op)
|
||||
if fn is None:
|
||||
# SOPP instructions without pseudocode (waits, hints, nops) are no-ops
|
||||
if inst_type is SOPP: return 0
|
||||
if isinstance(inst, SOPP): return 0
|
||||
raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
|
||||
# Build context - handle 64-bit ops that need 64-bit source reads
|
||||
# 64-bit source ops: name ends with _B64, _I64, _U64 or contains _U64, _I64 before last underscore
|
||||
is_64bit_s0 = op.name.endswith(('_B64', '_I64', '_U64')) or '_U64_' in op.name or '_I64_' in op.name
|
||||
is_64bit_s0s1 = op_cls is SOPCOp and op in (SOPCOp.S_CMP_EQ_U64, SOPCOp.S_CMP_LG_U64)
|
||||
s0 = st.rsrc64(ssrc0, 0) if is_64bit_s0 or is_64bit_s0s1 else (st.rsrc(ssrc0, 0) if inst_type not in (SOPK, SOPP) else (st.rsgpr(inst.sdst) if inst_type is SOPK else 0))
|
||||
is_64bit_sop2 = is_64bit_s0 and inst_type is SOP2
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if (is_64bit_sop2 or is_64bit_s0s1) else (st.rsrc(inst.ssrc1, 0) if inst_type in (SOP2, SOPC) else inst.simm16 if inst_type is SOPK else 0)
|
||||
d0 = st.rsgpr64(sdst) if (is_64bit_s0 or is_64bit_s0s1) and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
# Build context - use inst methods to determine operand sizes
|
||||
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
|
||||
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
exec_mask = st.exec_mask
|
||||
literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal
|
||||
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
|
||||
|
||||
# Execute compiled function - pass PC in bytes for instructions that need it
|
||||
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
|
||||
@@ -248,10 +202,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = None) -> None:
|
||||
"""Execute vector instruction for one lane."""
|
||||
compiled = _get_compiled()
|
||||
inst_type, V = type(inst), st.vgpr[lane]
|
||||
V = st.vgpr[lane]
|
||||
|
||||
# Memory ops (not ALU pseudocode)
|
||||
if inst_type is FLAT:
|
||||
if isinstance(inst, FLAT):
|
||||
op, addr_reg, data_reg, vdst, offset, saddr = inst.op, inst.addr, inst.data, inst.vdst, _sext(inst.offset, 13), inst.saddr
|
||||
addr = V[addr_reg] | (V[addr_reg+1] << 32)
|
||||
addr = (st.rsgpr64(saddr) + V[addr_reg] + offset) & MASK64 if saddr not in (NULL, 0x7f) else (addr + offset) & MASK64
|
||||
@@ -272,7 +226,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
else: raise NotImplementedError(f"FLAT op {op}")
|
||||
return
|
||||
|
||||
if inst_type is DS:
|
||||
if isinstance(inst, DS):
|
||||
op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
|
||||
if op in DS_LOAD:
|
||||
cnt, sz, sign = DS_LOAD[op]
|
||||
@@ -302,7 +256,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
return
|
||||
|
||||
# VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)
|
||||
if inst_type is VOPD:
|
||||
if isinstance(inst, VOPD):
|
||||
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
|
||||
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
|
||||
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
|
||||
@@ -312,18 +266,15 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
return
|
||||
|
||||
# VOP3SD: has extra scalar dest for carry output
|
||||
if inst_type is VOP3SD:
|
||||
op = VOP3SDOp(inst.op)
|
||||
fn = compiled.get(VOP3SDOp, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
# Source sizes vary: DIV_SCALE=all 64-bit, MAD64=32/32/64, others=32-bit
|
||||
r64 = op == VOP3SDOp.V_DIV_SCALE_F64
|
||||
s0, s1 = (st.rsrc64 if r64 else st.rsrc)(inst.src0, lane), (st.rsrc64 if r64 else st.rsrc)(inst.src1, lane)
|
||||
mad64 = op in (VOP3SDOp.V_MAD_U64_U32, VOP3SDOp.V_MAD_I64_I32)
|
||||
s2 = st.rsrc64(inst.src2, lane) if r64 else ((V[inst.src2-256]|(V[inst.src2-255]<<32)) if inst.src2>=256 else st.rsgpr64(inst.src2)) if mad64 else st.rsrc(inst.src2, lane)
|
||||
if isinstance(inst, VOP3SD):
|
||||
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
# Read sources based on register counts from inst properties
|
||||
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
|
||||
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
|
||||
# Carry-in ops use src2 as carry bitmask instead of VCC
|
||||
carry_ops = (VOP3SDOp.V_ADD_CO_CI_U32, VOP3SDOp.V_SUB_CO_CI_U32, VOP3SDOp.V_SUBREV_CO_CI_U32)
|
||||
result = fn(s0, s1, s2, V[inst.vdst], st.scc, st.rsgpr64(inst.src2) if op in carry_ops else st.vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
|
||||
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
|
||||
V[inst.vdst] = result['d0'] & MASK32
|
||||
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
|
||||
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
|
||||
@@ -332,35 +283,31 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
# Get op enum and sources (None means "no source" for that operand)
|
||||
# dst_hi: for VOP1/VOP2 16-bit dst ops, bit 7 of vdst indicates .h (high 16-bit) destination
|
||||
dst_hi = False
|
||||
if inst_type is VOP1:
|
||||
if isinstance(inst, VOP1):
|
||||
if inst.op == VOP1Op.V_NOP: return
|
||||
op_cls, op, src0, src1, src2 = VOP1Op, VOP1Op(inst.op), inst.src0, None, None
|
||||
dst_hi, vdst = (inst.vdst & 0x80) != 0 and op in _VOP1_16BIT_DST_OPS, inst.vdst & 0x7f if op in _VOP1_16BIT_DST_OPS else inst.vdst
|
||||
elif inst_type is VOP2:
|
||||
op_cls, op, src0, src1, src2 = VOP2Op, VOP2Op(inst.op), inst.src0, inst.vsrc1 + 256, None
|
||||
dst_hi, vdst = (inst.vdst & 0x80) != 0 and op in _VOP2_16BIT_OPS, inst.vdst & 0x7f if op in _VOP2_16BIT_OPS else inst.vdst
|
||||
elif inst_type is VOP3:
|
||||
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 (use VOPCOp pseudocode)
|
||||
if inst.op < 256:
|
||||
op_cls, op, src0, src1, src2, vdst = VOPCOp, VOPCOp(inst.op), inst.src0, inst.src1, None, inst.vdst
|
||||
else:
|
||||
op_cls, op, src0, src1, src2, vdst = VOP3Op, VOP3Op(inst.op), inst.src0, inst.src1, inst.src2, inst.vdst
|
||||
elif inst_type is VOPC:
|
||||
op = VOPCOp(inst.op)
|
||||
src0, src1, src2 = inst.src0, None, None
|
||||
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16()
|
||||
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
|
||||
elif isinstance(inst, VOP2):
|
||||
src0, src1, src2 = inst.src0, inst.vsrc1 + 256, None
|
||||
dst_hi = (inst.vdst & 0x80) != 0 and inst.is_dst_16()
|
||||
vdst = inst.vdst & 0x7f if inst.is_dst_16() else inst.vdst
|
||||
elif isinstance(inst, VOP3):
|
||||
# VOP3 ops 0-255 are VOPC comparisons encoded as VOP3 - inst.op returns VOPCOp for these
|
||||
src0, src1, src2, vdst = inst.src0, inst.src1, (None if inst.op.value < 256 else inst.src2), inst.vdst
|
||||
elif isinstance(inst, VOPC):
|
||||
# For 16-bit VOPC, vsrc1 uses same encoding as VOP2 16-bit: bit 7 selects hi(1) or lo(0) half
|
||||
# vsrc1 field is 8 bits: [6:0] = VGPR index, [7] = hi flag
|
||||
src1 = inst.vsrc1 + 256 # convert to standard VGPR encoding (256 + vgpr_idx)
|
||||
op_cls, src0, src2, vdst = VOPCOp, inst.src0, None, VCC_LO
|
||||
elif inst_type is VOP3P:
|
||||
src0, src1, src2, vdst = inst.src0, inst.vsrc1 + 256, None, VCC_LO
|
||||
elif isinstance(inst, VOP3P):
|
||||
# VOP3P: Packed 16-bit operations using compiled functions
|
||||
op = VOP3POp(inst.op)
|
||||
# WMMA: wave-level matrix multiply-accumulate (special handling - needs cross-lane access)
|
||||
if op in (VOP3POp.V_WMMA_F32_16X16X16_F16, VOP3POp.V_WMMA_F32_16X16X16_BF16, VOP3POp.V_WMMA_F16_16X16X16_F16):
|
||||
if 'WMMA' in inst.op_name:
|
||||
if lane == 0: # Only execute once per wave, write results for all lanes
|
||||
exec_wmma(st, inst, op)
|
||||
exec_wmma(st, inst, inst.op)
|
||||
return
|
||||
# V_FMA_MIX: Mixed precision FMA - opsel_hi controls f32(0) vs f16(1), opsel selects which f16 half
|
||||
if op in (VOP3POp.V_FMA_MIX_F32, VOP3POp.V_FMA_MIXLO_F16, VOP3POp.V_FMA_MIXHI_F16):
|
||||
if 'FMA_MIX' in inst.op_name:
|
||||
opsel, opsel_hi, opsel_hi2 = getattr(inst, 'opsel', 0), getattr(inst, 'opsel_hi', 0), getattr(inst, 'opsel_hi2', 0)
|
||||
neg, abs_ = getattr(inst, 'neg', 0), getattr(inst, 'neg_hi', 0) # neg_hi reused as abs
|
||||
raws = [st.rsrc(inst.src0, lane), st.rsrc(inst.src1, lane), st.rsrc(inst.src2, lane) if inst.src2 is not None else 0]
|
||||
@@ -371,7 +318,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
if neg & (1<<i): srcs[i] = -srcs[i]
|
||||
result = srcs[0] * srcs[1] + srcs[2]
|
||||
V = st.vgpr[lane]
|
||||
V[inst.vdst] = _i32(result) if op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
|
||||
return
|
||||
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
|
||||
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
|
||||
@@ -380,80 +327,52 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
|
||||
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
|
||||
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
|
||||
fn = compiled.get(VOP3POp, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
fn = compiled.get(VOP3POp, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
|
||||
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
|
||||
return
|
||||
else: raise NotImplementedError(f"Unknown vector type {inst_type}")
|
||||
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
|
||||
|
||||
fn = compiled.get(op_cls, {}).get(op)
|
||||
if fn is None: raise NotImplementedError(f"{op.name} not in pseudocode")
|
||||
op_cls = type(inst.op)
|
||||
fn = compiled.get(op_cls, {}).get(inst.op)
|
||||
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
|
||||
|
||||
# Read sources (with VOP3 modifiers if applicable)
|
||||
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if inst_type is VOP3 else (0, 0)
|
||||
opsel = getattr(inst, 'opsel', 0) if inst_type is VOP3 else 0
|
||||
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
|
||||
opsel = getattr(inst, 'opsel', 0) if isinstance(inst, VOP3) else 0
|
||||
def mod_src(val: int, idx: int, is64=False) -> int:
|
||||
to_f, to_i = (_f64, _i64) if is64 else (_f32, _i32)
|
||||
if (abs_ >> idx) & 1: val = to_i(abs(to_f(val)))
|
||||
if (neg >> idx) & 1: val = to_i(-to_f(val))
|
||||
return val
|
||||
|
||||
# Determine if sources are 64-bit based on instruction type
|
||||
# For 64-bit shift ops: src0 is 32-bit (shift amount), src1 is 64-bit (value to shift)
|
||||
# For most other _B64/_I64/_U64/_F64 ops: all sources are 64-bit
|
||||
is_64bit_op = op.name.endswith(('_B64', '_I64', '_U64', '_F64'))
|
||||
# V_LDEXP_F64, V_TRIG_PREOP_F64, V_CMP_CLASS_F64, V_CMPX_CLASS_F64: src0 is 64-bit, src1 is 32-bit
|
||||
is_ldexp_64 = op in (VOP3Op.V_LDEXP_F64, VOP3Op.V_TRIG_PREOP_F64, VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64,
|
||||
VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
|
||||
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
|
||||
# 16-bit source ops: use precomputed sets instead of string checks
|
||||
# Note: must check op_cls to avoid cross-enum value collisions
|
||||
# VOP3-encoded VOPC 16-bit ops also use opsel (not VGPR bit 7 like non-VOP3 VOPC)
|
||||
is_16bit_src = (op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS) or \
|
||||
(inst_type is VOP3 and op_cls is VOPCOp and op in _VOPC_16BIT_OPS)
|
||||
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
|
||||
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
|
||||
# Use inst methods to determine operand sizes (inst.is_src_16, inst.is_src_64, etc.)
|
||||
is_vop2_16bit = isinstance(inst, VOP2) and inst.is_16bit()
|
||||
|
||||
if is_shift_64:
|
||||
s0, s1 = mod_src(st.rsrc(src0, lane), 0), st.rsrc64(src1, lane) if src1 else 0
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_ldexp_64:
|
||||
s0 = mod_src(st.rsrc64(src0, lane), 0, is64=True)
|
||||
s1_raw = st.rsrc(src1, lane) if src1 is not None else 0
|
||||
is_class_op = op in (VOP3Op.V_CMP_CLASS_F64, VOP3Op.V_CMPX_CLASS_F64, VOPCOp.V_CMP_CLASS_F64, VOPCOp.V_CMPX_CLASS_F64)
|
||||
s1, s2 = mod_src((s1_raw >> 32) if src1 == 255 and is_class_op else s1_raw, 1), mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
elif is_64bit_op:
|
||||
s0, s1 = mod_src(st.rsrc64(src0, lane), 0, is64=True), mod_src(st.rsrc64(src1, lane), 1, is64=True) if src1 is not None else 0
|
||||
s2 = mod_src(st.rsrc64(src2, lane), 2, is64=True) if src2 is not None else 0
|
||||
elif is_16bit_src:
|
||||
# VOP3 16-bit ops: opsel bits select which half, abs/neg as f16 bit ops
|
||||
def rsrc_16bit(src, idx):
|
||||
if src is None: return 0
|
||||
# Read sources based on register counts and dtypes from inst properties
|
||||
def read_src(src, idx, regs, is_src_16):
|
||||
if src is None: return 0
|
||||
if regs == 2: return mod_src(st.rsrc64(src, lane), idx, is64=True)
|
||||
if is_src_16 and isinstance(inst, VOP3):
|
||||
raw = st.rsrc_f16(src, lane) if 128 <= src < 255 else st.rsrc(src, lane)
|
||||
val = _src16(raw, bool(opsel & (1 << idx)))
|
||||
if abs_ & (1 << idx): val &= 0x7fff
|
||||
if neg & (1 << idx): val ^= 0x8000
|
||||
return val
|
||||
s0, s1, s2 = rsrc_16bit(src0, 0), rsrc_16bit(src1, 1), rsrc_16bit(src2, 2)
|
||||
elif is_vop2_16bit or (op_cls is VOP1Op and op in _VOP1_16BIT_SRC_OPS) or (op_cls is VOPCOp and op in _VOPC_16BIT_OPS):
|
||||
# VOP1/VOP2/VOPC 16-bit ops: VGPRs use bit 7 for hi/lo, non-VGPRs use f16 inline consts
|
||||
# Special case: VOPC V_CMP_CLASS uses full 32-bit mask for src1 when non-VGPR
|
||||
def rsrc16_vgpr(src, idx, full32=False):
|
||||
if src is None: return 0
|
||||
if is_src_16 and isinstance(inst, (VOP1, VOP2, VOPC)):
|
||||
if src >= 256: return _src16(mod_src(st.rsrc(_vgpr_masked(src), lane), idx), _vgpr_hi(src))
|
||||
return mod_src(st.rsrc(src, lane) if full32 else st.rsrc_f16(src, lane), idx) & (0xffffffff if full32 else 0xffff)
|
||||
s0, s1 = rsrc16_vgpr(src0, 0), rsrc16_vgpr(src1, 1, full32=op_cls is VOPCOp)
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
else:
|
||||
s0 = mod_src(st.rsrc(src0, lane), 0)
|
||||
s1 = mod_src(st.rsrc(src1, lane), 1) if src1 is not None else 0
|
||||
s2 = mod_src(st.rsrc(src2, lane), 2) if src2 is not None else 0
|
||||
return mod_src(st.rsrc_f16(src, lane), idx) & 0xffff
|
||||
return mod_src(st.rsrc(src, lane), idx)
|
||||
|
||||
s0 = read_src(src0, 0, inst.src_regs(0), inst.is_src_16(0))
|
||||
s1 = read_src(src1, 1, inst.src_regs(1), inst.is_src_16(1)) if src1 is not None else 0
|
||||
s2 = read_src(src2, 2, inst.src_regs(2), inst.is_src_16(2)) if src2 is not None else 0
|
||||
# Read destination (accumulator for VOP2 f16, 64-bit for 64-bit ops)
|
||||
d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if is_64bit_op else V[vdst]
|
||||
d0 = _src16(V[vdst], dst_hi) if is_vop2_16bit else (V[vdst] | (V[vdst + 1] << 32)) if inst.dst_regs() == 2 else V[vdst]
|
||||
|
||||
# V_CNDMASK_B32/B16: VOP3 encoding uses src2 as mask (not VCC); VOP2 uses VCC implicitly
|
||||
# Pass the correct mask as vcc to the function so pseudocode VCC.u64[laneId] works correctly
|
||||
vcc_for_fn = st.rsgpr64(src2) if op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and inst_type is VOP3 and src2 is not None and src2 < 256 else st.vcc
|
||||
vcc_for_fn = st.rsgpr64(src2) if inst.op in (VOP3Op.V_CNDMASK_B32, VOP3Op.V_CNDMASK_B16) and isinstance(inst, VOP3) and src2 is not None and src2 < 256 else st.vcc
|
||||
|
||||
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
|
||||
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
|
||||
@@ -467,17 +386,16 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
|
||||
st.vgpr[wr_lane][wr_idx] = wr_val
|
||||
if 'vcc_lane' in result:
|
||||
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
|
||||
st.pend_sgpr_lane(VCC_LO if op_cls is VOP2Op and 'CO_CI' in op.name else vdst, lane, result['vcc_lane'])
|
||||
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
|
||||
if 'exec_lane' in result:
|
||||
# V_CMPX instructions write to EXEC per-lane
|
||||
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
|
||||
if 'd0' in result and op_cls not in (VOPCOp,) and 'vgpr_write' not in result:
|
||||
writes_to_sgpr = op in (VOP1Op.V_READFIRSTLANE_B32,) or (op_cls is VOP3Op and op in (VOP3Op.V_READFIRSTLANE_B32, VOP3Op.V_READLANE_B32))
|
||||
is_16bit_dst = (op_cls is VOP3Op and op in _VOP3_16BIT_DST_OPS) or (op_cls is VOP1Op and op in _VOP1_16BIT_DST_OPS) or is_vop2_16bit
|
||||
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
|
||||
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
|
||||
d0_val = result['d0']
|
||||
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
|
||||
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
|
||||
elif is_16bit_dst: V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if inst_type is VOP3 else dst_hi)
|
||||
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
|
||||
else: V[vdst] = d0_val & MASK32
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
@@ -506,22 +424,19 @@ def exec_wmma(st: WaveState, inst, op: VOP3POp) -> None:
|
||||
# MAIN EXECUTION LOOP
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
SCALAR_TYPES = {SOP1, SOP2, SOPC, SOPK, SOPP, SMEM}
|
||||
VECTOR_TYPES = {VOP1, VOP2, VOP3, VOP3SD, VOPC, FLAT, DS, VOPD, VOP3P}
|
||||
|
||||
def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) -> int:
|
||||
inst = program.get(st.pc)
|
||||
if inst is None: return 1
|
||||
inst_words, st.literal, inst_type = inst._words, getattr(inst, '_literal', None) or 0, type(inst)
|
||||
inst_words, st.literal = inst._words, getattr(inst, '_literal', None) or 0
|
||||
|
||||
if inst_type in SCALAR_TYPES:
|
||||
if isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)):
|
||||
delta = exec_scalar(st, inst)
|
||||
if delta == -1: return -1 # endpgm
|
||||
if delta == -2: st.pc += inst_words; return -2 # barrier
|
||||
st.pc += inst_words + delta
|
||||
else:
|
||||
# V_READFIRSTLANE/V_READLANE write to SGPR, execute once; others execute per-lane with exec_mask
|
||||
is_readlane = inst_type in (VOP1, VOP3) and hasattr(inst.op, 'name') and 'READLANE' in inst.op.name
|
||||
is_readlane = isinstance(inst, (VOP1, VOP3)) and ('READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name)
|
||||
exec_mask = 1 if is_readlane else st.exec_mask
|
||||
for lane in range(1 if is_readlane else n_lanes):
|
||||
if exec_mask & (1 << lane): exec_vector(st, inst, lane, lds)
|
||||
|
||||
@@ -3982,6 +3982,179 @@ class TestVOP3VOPC16Bit(unittest.TestCase):
|
||||
self.assertEqual(st.sgpr[0] & 1, 1, "hi>hi should be true: 0x9999>0x1234")
|
||||
|
||||
|
||||
class Test64BitLiteralSources(unittest.TestCase):
|
||||
"""Regression tests for 64-bit instruction literal source handling.
|
||||
|
||||
For f64 operations, a 32-bit literal in the instruction stream represents the
|
||||
HIGH 32 bits of the 64-bit value (low 32 bits are implicitly 0).
|
||||
|
||||
Bug: rsrc64() was returning the 32-bit literal as-is instead of shifting it
|
||||
left by 32 bits. This caused V_FMA_F64 and V_LDEXP_F64 to use wrong values
|
||||
when their source is a literal, breaking the f64->i64 conversion sequence.
|
||||
|
||||
The f64->i64 conversion sequence is:
|
||||
v_trunc_f64 -> v_ldexp_f64 (by -32) -> v_floor_f64 -> v_fma_f64 (by -2^32)
|
||||
-> v_cvt_u32_f64 (low bits) -> v_cvt_i32_f64 (high bits)
|
||||
|
||||
The V_FMA_F64 uses literal 0xC1F00000 which is the high 32 bits of f64 -2^32.
|
||||
"""
|
||||
|
||||
def test_v_fma_f64_literal_neg_2pow32(self):
|
||||
"""V_FMA_F64 with literal encoding of -2^32.
|
||||
|
||||
The f64 value -2^32 (-4294967296.0) has bits 0xC1F0000000000000.
|
||||
The compiler encodes only the high 32 bits (0xC1F00000) as a literal.
|
||||
The emulator must interpret this as 0xC1F00000_00000000.
|
||||
"""
|
||||
# v[0:1] = -41.0 (trunc), v[2:3] = -1.0 (floor of -41/2^32)
|
||||
# FMA: result = (-2^32) * (-1.0) + (-41.0) = 4294967296 - 41 = 4294967255.0
|
||||
val_41 = f2i64(-41.0)
|
||||
val_m1 = f2i64(-1.0)
|
||||
# Literal 0xC1F00000 is high 32 bits of f64 -2^32
|
||||
lit = 0xC1F00000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val_41 & 0xffffffff),
|
||||
s_mov_b32(s[1], (val_41 >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
s_mov_b32(s[2], val_m1 & 0xffffffff),
|
||||
s_mov_b32(s[3], (val_m1 >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[2], s[2]),
|
||||
v_mov_b32_e32(v[3], s[3]),
|
||||
# V_FMA_F64 v[4:5], literal, v[2:3], v[0:1]
|
||||
# = (-2^32) * (-1.0) + (-41.0) = 4294967255.0
|
||||
VOP3(VOP3Op.V_FMA_F64, vdst=v[4], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = i642f(st.vgpr[0][4] | (st.vgpr[0][5] << 32))
|
||||
expected = 4294967255.0 # 2^32 - 41
|
||||
self.assertAlmostEqual(result, expected, places=0, msg=f"Expected {expected}, got {result}")
|
||||
|
||||
def test_v_ldexp_f64_literal_neg32(self):
|
||||
"""V_LDEXP_F64 with literal -32 for exponent.
|
||||
|
||||
V_LDEXP_F64 computes src0 * 2^src1 where src1 is an integer exponent.
|
||||
The literal 0xFFFFFFE0 represents -32 as a 32-bit signed integer.
|
||||
For V_LDEXP_F64, src1 is 32-bit (not 64-bit), so this is correct as-is.
|
||||
"""
|
||||
val = f2i64(-41.0)
|
||||
expected = -41.0 * (2.0 ** -32) # -9.5367431640625e-09
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val & 0xffffffff),
|
||||
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
# V_LDEXP_F64 v[2:3], v[0:1], -32
|
||||
v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
result = i642f(st.vgpr[0][2] | (st.vgpr[0][3] << 32))
|
||||
self.assertAlmostEqual(result, expected, places=15, msg=f"Expected {expected}, got {result}")
|
||||
|
||||
def test_f64_to_i64_full_sequence(self):
|
||||
"""Full f64->i64 conversion sequence with negative value.
|
||||
|
||||
This is the exact sequence generated by the compiler for (long)(-41.0):
|
||||
v_trunc_f64 v[0:1], v[0:1]
|
||||
v_ldexp_f64 v[2:3], v[0:1], -32
|
||||
v_floor_f64 v[2:3], v[2:3]
|
||||
v_fma_f64 v[0:1], 0xc1f00000, v[2:3], v[0:1] # -2^32
|
||||
v_cvt_u32_f64 v0, v[0:1]
|
||||
v_cvt_i32_f64 v1, v[2:3]
|
||||
|
||||
Result: v1:v0 = 0xFFFFFFFF:0xFFFFFFD7 = -41 as i64
|
||||
"""
|
||||
val = f2i64(-41.0)
|
||||
lit = 0xC1F00000 # high 32 bits of f64 -2^32
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val & 0xffffffff),
|
||||
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_trunc_f64_e32(v[0:2], v[0:2]),
|
||||
v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0), # -32
|
||||
v_floor_f64_e32(v[2:4], v[2:4]),
|
||||
VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit),
|
||||
v_cvt_u32_f64_e32(v[4], v[0:2]),
|
||||
v_cvt_i32_f64_e32(v[5], v[2:4]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = st.vgpr[0][4]
|
||||
hi = st.vgpr[0][5]
|
||||
result = struct.unpack('<q', struct.pack('<II', lo, hi))[0]
|
||||
self.assertEqual(result, -41, f"Expected -41, got {result} (lo=0x{lo:08x}, hi=0x{hi:08x})")
|
||||
|
||||
def test_f64_to_i64_large_negative(self):
|
||||
"""f64->i64 conversion with larger negative value (-1000000).
|
||||
|
||||
Tests that the conversion sequence works for values that span both
|
||||
high and low 32-bit parts of the result.
|
||||
"""
|
||||
val = f2i64(-1000000.0)
|
||||
lit = 0xC1F00000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val & 0xffffffff),
|
||||
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_trunc_f64_e32(v[0:2], v[0:2]),
|
||||
v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0),
|
||||
v_floor_f64_e32(v[2:4], v[2:4]),
|
||||
VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit),
|
||||
v_cvt_u32_f64_e32(v[4], v[0:2]),
|
||||
v_cvt_i32_f64_e32(v[5], v[2:4]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = st.vgpr[0][4]
|
||||
hi = st.vgpr[0][5]
|
||||
result = struct.unpack('<q', struct.pack('<II', lo, hi))[0]
|
||||
self.assertEqual(result, -1000000, f"Expected -1000000, got {result}")
|
||||
|
||||
def test_f64_to_i64_positive(self):
|
||||
"""f64->i64 conversion with positive value (1000000)."""
|
||||
val = f2i64(1000000.0)
|
||||
lit = 0xC1F00000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val & 0xffffffff),
|
||||
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_trunc_f64_e32(v[0:2], v[0:2]),
|
||||
v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0),
|
||||
v_floor_f64_e32(v[2:4], v[2:4]),
|
||||
VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit),
|
||||
v_cvt_u32_f64_e32(v[4], v[0:2]),
|
||||
v_cvt_i32_f64_e32(v[5], v[2:4]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = st.vgpr[0][4]
|
||||
hi = st.vgpr[0][5]
|
||||
result = struct.unpack('<q', struct.pack('<II', lo, hi))[0]
|
||||
self.assertEqual(result, 1000000, f"Expected 1000000, got {result}")
|
||||
|
||||
def test_f64_to_i64_large_positive(self):
|
||||
"""f64->i64 conversion with value > 2^32 (requires 64-bit result)."""
|
||||
val = f2i64(5000000000.0) # 5 billion, > 2^32
|
||||
lit = 0xC1F00000
|
||||
instructions = [
|
||||
s_mov_b32(s[0], val & 0xffffffff),
|
||||
s_mov_b32(s[1], (val >> 32) & 0xffffffff),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
v_mov_b32_e32(v[1], s[1]),
|
||||
v_trunc_f64_e32(v[0:2], v[0:2]),
|
||||
v_ldexp_f64(v[2:4], v[0:2], 0xFFFFFFE0),
|
||||
v_floor_f64_e32(v[2:4], v[2:4]),
|
||||
VOP3(VOP3Op.V_FMA_F64, vdst=v[0], src0=RawImm(255), src1=v[2], src2=v[0], literal=lit),
|
||||
v_cvt_u32_f64_e32(v[4], v[0:2]),
|
||||
v_cvt_i32_f64_e32(v[5], v[2:4]),
|
||||
]
|
||||
st = run_program(instructions, n_lanes=1)
|
||||
lo = st.vgpr[0][4]
|
||||
hi = st.vgpr[0][5]
|
||||
result = struct.unpack('<q', struct.pack('<II', lo, hi))[0]
|
||||
self.assertEqual(result, 5000000000, f"Expected 5000000000, got {result}")
|
||||
|
||||
|
||||
class TestDS2Addr(unittest.TestCase):
|
||||
"""Regression tests for DS_LOAD_2ADDR and DS_STORE_2ADDR instructions.
|
||||
These ops use offset scaling: offset * sizeof(data) for address calculation.
|
||||
|
||||
Reference in New Issue
Block a user