simplify rdna3 asm (#13835)

* simplify rdna3 asm

* cleanups

* fix names

* fix tests

* fixes

* more test fixes

* type fixes

* tests pass + mypy passes

* 3.11 syntax
This commit is contained in:
George Hotz
2025-12-26 11:21:03 -05:00
committed by GitHub
parent c44b4f9ae0
commit e9f2aaba2a
11 changed files with 921 additions and 1090 deletions

View File

@@ -1,6 +1,7 @@
# Pure combinational ALU functions for RDNA3 emulation
from __future__ import annotations
import struct, math
from typing import Callable
from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, VOP1Op, VOP2Op, VOP3Op
# Format base offsets for unified opcode space
@@ -8,26 +9,25 @@ SOP2_BASE, SOP1_BASE, SOPC_BASE, SOPK_BASE = 0x000, 0x100, 0x200, 0x300
VOP2_BASE, VOP1_BASE = 0x100, 0x180
# Float conversion helpers
FLOAT_BITS = {240: 0x3f000000, 241: 0xbf000000, 242: 0x3f800000, 243: 0xbf800000, 244: 0x40000000, 245: 0xc0000000, 246: 0x40800000, 247: 0xc0800000, 248: 0x3e22f983}
_struct_I, _struct_f = struct.Struct('<I'), struct.Struct('<f')
def f32(i: int) -> float: return _struct_f.unpack(_struct_I.pack(i & 0xffffffff))[0]
_I, _f, _H, _e = struct.Struct('<I'), struct.Struct('<f'), struct.Struct('<H'), struct.Struct('<e')
def f32(i: int) -> float: return _f.unpack(_I.pack(i & 0xffffffff))[0]
def i32(f: float) -> int:
if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
try: return _struct_I.unpack(_struct_f.pack(f))[0]
try: return _I.unpack(_f.pack(f))[0]
except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def f16(i: int) -> float: return struct.unpack('<e', struct.pack('<H', i & 0xffff))[0]
def f16(i: int) -> float: return _e.unpack(_H.pack(i & 0xffff))[0]
def i16(f: float) -> int:
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
try: return struct.unpack('<H', struct.pack('<e', f))[0]
try: return _H.unpack(_e.pack(f))[0]
except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
def sext(v: int, b: int) -> int: return v - (1 << b) if v & (1 << (b-1)) else v
def clz(x: int) -> int: return 32 - x.bit_length() if x else 32
def cls(x: int) -> int: x &= 0xffffffff; return 31 if x in (0, 0xffffffff) else clz(~x & 0xffffffff if x >> 31 else x) - 1
def _cvt_i32_f32(v): return (0x7fffffff if v > 0 else (-0x80000000 & 0xffffffff)) if math.isinf(v) else (0 if math.isnan(v) else max(-0x80000000, min(0x7fffffff, int(v))) & 0xffffffff)
def _cvt_u32_f32(v): return 0xffffffff if math.isinf(v) and v > 0 else (0 if math.isinf(v) or math.isnan(v) or v < 0 else min(0xffffffff, int(v)))
def _cvt_i32_f32(v): return (0x7fffffff if v > 0 else 0x80000000) if math.isinf(v) else (0 if math.isnan(v) else max(-0x80000000, min(0x7fffffff, int(v))) & 0xffffffff)
def _cvt_u32_f32(v): return (0xffffffff if v > 0 else 0) if math.isinf(v) else (0 if math.isnan(v) or v < 0 else min(0xffffffff, int(v)))
# SALU: op -> fn(s0, s1, scc_in) -> (result, scc_out)
SALU: dict[int, callable] = {
SALU: dict[int, Callable] = {
# SOP2
SOP2_BASE + SOP2Op.S_ADD_U32: lambda a, b, scc: ((a + b) & 0xffffffff, int((a + b) >= 0x100000000)),
SOP2_BASE + SOP2Op.S_SUB_U32: lambda a, b, scc: ((a - b) & 0xffffffff, int(b > a)),
@@ -114,7 +114,7 @@ SALU: dict[int, callable] = {
}
# VALU: op -> fn(s0, s1, s2) -> result
VALU: dict[int, callable] = {
VALU: dict[int, Callable] = {
# VOP2
VOP2_BASE + VOP2Op.V_ADD_F32: lambda a, b, c: i32(f32(a) + f32(b)),
VOP2_BASE + VOP2Op.V_SUB_F32: lambda a, b, c: i32(f32(a) - f32(b)),
@@ -228,28 +228,22 @@ VALU: dict[int, callable] = {
VOP3Op.V_MED3_I16: lambda a, b, c: sorted([sext(a & 0xffff, 16), sext(b & 0xffff, 16), sext(c & 0xffff, 16)])[1] & 0xffff,
}
def _cmp8(a, b): return [False, a < b, a == b, a <= b, a > b, a != b, a >= b, True]
def _cmp6(a, b): return [a < b, a == b, a <= b, a > b, a != b, a >= b]
def vopc(op: int, s0: int, s1: int, s0_hi: int = 0, s1_hi: int = 0) -> int:
base = op & 0x7f
if 16 <= base <= 31: # F32
f0, f1, cmp, nan = f32(s0), f32(s1), base - 16, math.isnan(f32(s0)) or math.isnan(f32(s1))
return int([False, f0<f1, f0==f1, f0<=f1, f0>f1, f0!=f1, f0>=f1, not nan, nan, f0<f1 or nan, f0==f1 or nan, f0<=f1 or nan, f0>f1 or nan, f0!=f1 or nan, f0>=f1 or nan, True][cmp])
if 49 <= base <= 54: # I16
cmp, s0s, s1s = base - 49, sext(s0 & 0xffff, 16), sext(s1 & 0xffff, 16)
return int([s0s<s1s, s0s==s1s, s0s<=s1s, s0s>s1s, s0s!=s1s, s0s>=s1s][cmp])
if 57 <= base <= 62: # U16
cmp, s0u, s1u = base - 57, s0 & 0xffff, s1 & 0xffff
return int([s0u<s1u, s0u==s1u, s0u<=s1u, s0u>s1u, s0u!=s1u, s0u>=s1u][cmp])
if 49 <= base <= 54: return int(_cmp6(sext(s0 & 0xffff, 16), sext(s1 & 0xffff, 16))[base - 49]) # I16
if 57 <= base <= 62: return int(_cmp6(s0 & 0xffff, s1 & 0xffff)[base - 57]) # U16
if 64 <= base <= 79: # I32/U32
cmp, s0s, s1s = (base - 64) % 8, sext(s0, 32), sext(s1, 32)
return int([False, s0s<s1s, s0s==s1s, s0s<=s1s, s0s>s1s, s0s!=s1s, s0s>=s1s, True][cmp]) if base < 72 else int([False, s0<s1, s0==s1, s0<=s1, s0>s1, s0!=s1, s0>=s1, True][cmp])
cmp = (base - 64) % 8
return int(_cmp8(sext(s0, 32), sext(s1, 32))[cmp] if base < 72 else _cmp8(s0, s1)[cmp])
if 80 <= base <= 95: # I64/U64
cmp = (base - 80) % 8
s0_64, s1_64 = s0 | (s0_hi << 32), s1 | (s1_hi << 32)
if base < 88: # I64
s0s, s1s = sext(s0_64, 64), sext(s1_64, 64)
return int([False, s0s<s1s, s0s==s1s, s0s<=s1s, s0s>s1s, s0s!=s1s, s0s>=s1s, True][cmp])
else: # U64
return int([False, s0_64<s1_64, s0_64==s1_64, s0_64<=s1_64, s0_64>s1_64, s0_64!=s1_64, s0_64>=s1_64, True][cmp])
return int(_cmp8(sext(s0_64, 64), sext(s1_64, 64))[(base - 80) % 8] if base < 88 else _cmp8(s0_64, s1_64)[(base - 80) % 8])
if base == 126: # CLASS_F32
f, mask = f32(s0), s1
if math.isnan(f): return int(bool(mask & 0x3))

View File

@@ -1,11 +1,22 @@
# RDNA3 assembler and disassembler
from __future__ import annotations
import re
from extra.assembly.rdna3.lib import Inst, RawImm, Reg, SGPR, VGPR, TTMP, FLOAT_ENC, SRC_FIELDS, unwrap
from extra.assembly.rdna3.lib import Inst, RawImm, Reg, SGPR, VGPR, TTMP, s, v, ttmp, _RegFactory, FLOAT_ENC, SRC_FIELDS, unwrap
# Decoding helpers
SPECIAL_GPRS = {106: "vcc_lo", 107: "vcc_hi", 124: "null", 125: "m0", 126: "exec_lo", 127: "exec_hi", 253: "scc"}
SPECIAL_DEC = {**SPECIAL_GPRS, **{v: str(k) for k, v in FLOAT_ENC.items()}}
SPECIAL_PAIRS = {106: "vcc", 126: "exec"} # Special register pairs (for 64-bit ops)
# GFX11 hwreg names (IDs 16-17 are TBA - not supported, IDs 18-19 are PERF_SNAPSHOT)
HWREG_NAMES = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID', 5: 'HW_REG_GPR_ALLOC',
6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS', 15: 'HW_REG_SH_MEM_BASES', 18: 'HW_REG_PERF_SNAPSHOT_PC_LO',
19: 'HW_REG_PERF_SNAPSHOT_PC_HI', 20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI',
22: 'HW_REG_XNACK_MASK', 23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER', 28: 'HW_REG_IB_STS2'}
HWREG_IDS = {v.lower(): k for k, v in HWREG_NAMES.items()} # Reverse map for assembler
MSG_NAMES = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA',
131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
_16BIT_TYPES = ('f16', 'i16', 'u16', 'b16')
def _is_16bit(s: str) -> bool: return any(s.endswith(x) for x in _16BIT_TYPES)
def decode_src(val: int) -> str:
if val <= 105: return f"s{val}"
@@ -16,28 +27,39 @@ def decode_src(val: int) -> str:
if 256 <= val <= 511: return f"v{val - 256}"
return "lit" if val == 255 else f"?{val}"
def _sreg(base: int, cnt: int = 1) -> str: return f"s{base}" if cnt == 1 else f"s[{base}:{base+cnt-1}]"
def _vreg(base: int, cnt: int = 1) -> str: return f"v{base}" if cnt == 1 else f"v[{base}:{base+cnt-1}]"
def _reg(prefix: str, base: int, cnt: int = 1) -> str: return f"{prefix}{base}" if cnt == 1 else f"{prefix}[{base}:{base+cnt-1}]"
def _sreg(base: int, cnt: int = 1) -> str: return _reg("s", base, cnt)
def _vreg(base: int, cnt: int = 1) -> str: return _reg("v", base, cnt)
def _fmt_sdst(v: int, cnt: int = 1) -> str:
"""Format SGPR destination with special register names."""
if v == 124: return "null"
if 108 <= v <= 123: return f"ttmp[{v-108}:{v-108+cnt-1}]" if cnt > 1 else f"ttmp{v-108}"
if cnt > 1:
if v == 126 and cnt == 2: return "exec"
if v == 106 and cnt == 2: return "vcc"
return _sreg(v, cnt)
if 108 <= v <= 123: return _reg("ttmp", v - 108, cnt)
if cnt > 1 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
if cnt > 1: return _sreg(v, cnt)
return {126: "exec_lo", 127: "exec_hi", 106: "vcc_lo", 107: "vcc_hi", 125: "m0"}.get(v, f"s{v}")
def _fmt_ssrc(v: int, cnt: int = 1) -> str:
"""Format SGPR source with special register names and pairs."""
if cnt == 2:
if v == 126: return "exec"
if v == 106: return "vcc"
if v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
if v <= 105: return _sreg(v, 2)
if 108 <= v <= 123: return f"ttmp[{v-108}:{v-108+1}]"
if 108 <= v <= 123: return _reg("ttmp", v - 108, 2)
return decode_src(v)
def _fmt_src_n(v: int, cnt: int) -> str:
"""Format source with given register count (1, 2, or 4)."""
if cnt == 1: return decode_src(v)
if v >= 256: return _vreg(v - 256, cnt)
if v <= 105: return _sreg(v, cnt)
if cnt == 2 and v in SPECIAL_PAIRS: return SPECIAL_PAIRS[v]
if 108 <= v <= 123: return _reg("ttmp", v - 108, cnt)
return decode_src(v)
def _fmt_src64(v: int) -> str:
"""Format 64-bit source (VGPR pair, SGPR pair, or special pair)."""
return _fmt_src_n(v, 2)
def _parse_sop_sizes(op_name: str) -> tuple[int, ...]:
"""Parse dst and src sizes from SOP instruction name. Returns (dst_cnt, src0_cnt) or (dst_cnt, src0_cnt, src1_cnt)."""
if op_name in ('s_bitset0_b64', 's_bitset1_b64'): return (2, 1)
@@ -83,19 +105,15 @@ def disasm(inst: Inst) -> str:
if op_name == 'v_nop': return 'v_nop'
if op_name == 'v_pipeflush': return 'v_pipeflush'
parts = op_name.split('_')
is_16bit_dst = any(p in ('f16', 'i16', 'u16', 'b16') for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in ('f16', 'i16', 'u16', 'b16') and 'cvt' not in op_name)
is_16bit_src = parts[-1] in ('f16', 'i16', 'u16', 'b16') and 'sat_pk' not in op_name
is_f64_dst = op_name in ('v_ceil_f64', 'v_floor_f64', 'v_fract_f64', 'v_frexp_mant_f64', 'v_rcp_f64', 'v_rndne_f64', 'v_rsq_f64', 'v_sqrt_f64', 'v_trunc_f64', 'v_cvt_f64_f32', 'v_cvt_f64_i32', 'v_cvt_f64_u32')
is_f64_src = op_name in ('v_ceil_f64', 'v_floor_f64', 'v_fract_f64', 'v_frexp_mant_f64', 'v_rcp_f64', 'v_rndne_f64', 'v_rsq_f64', 'v_sqrt_f64', 'v_trunc_f64', 'v_cvt_f32_f64', 'v_cvt_i32_f64', 'v_cvt_u32_f64', 'v_frexp_exp_i32_f64')
is_16bit_dst = any(p in _16BIT_TYPES for p in parts[-2:-1]) or (len(parts) >= 2 and parts[-1] in _16BIT_TYPES and 'cvt' not in op_name)
is_16bit_src = parts[-1] in _16BIT_TYPES and 'sat_pk' not in op_name
_F64_OPS = ('v_ceil_f64', 'v_floor_f64', 'v_fract_f64', 'v_frexp_mant_f64', 'v_rcp_f64', 'v_rndne_f64', 'v_rsq_f64', 'v_sqrt_f64', 'v_trunc_f64')
is_f64_dst = op_name in _F64_OPS or op_name in ('v_cvt_f64_f32', 'v_cvt_f64_i32', 'v_cvt_f64_u32')
is_f64_src = op_name in _F64_OPS or op_name in ('v_cvt_f32_f64', 'v_cvt_i32_f64', 'v_cvt_u32_f64', 'v_frexp_exp_i32_f64')
if op_name == 'v_readfirstlane_b32':
return f"v_readfirstlane_b32 {decode_src(vdst)}, v{src0 - 256 if src0 >= 256 else src0}"
dst_str = _vreg(vdst, 2) if is_f64_dst else f"v{vdst & 0x7f}.{'h' if vdst >= 128 else 'l'}" if is_16bit_dst else f"v{vdst}"
if is_f64_src:
src_str = _vreg(src0 - 256, 2) if src0 >= 256 else _sreg(src0, 2) if src0 <= 105 else "vcc" if src0 == 106 else "exec" if src0 == 126 else f"ttmp[{src0-108}:{src0-108+1}]" if 108 <= src0 <= 123 else fmt_src(src0)
elif is_16bit_src and src0 >= 256:
src_str = f"v{(src0 - 256) & 0x7f}.{'h' if src0 >= 384 else 'l'}"
else:
src_str = fmt_src(src0)
src_str = _fmt_src64(src0) if is_f64_src else f"v{(src0 - 256) & 0x7f}.{'h' if src0 >= 384 else 'l'}" if is_16bit_src and src0 >= 256 else fmt_src(src0)
return f"{op_name}_e32 {dst_str}, {src_str}"
# VOP2
@@ -118,16 +136,9 @@ def disasm(inst: Inst) -> str:
is_64bit_vsrc1 = is_64bit and 'class' not in op_name
is_16bit = any(x in op_name for x in ('_f16', '_i16', '_u16')) and 'f32' not in op_name
is_cmpx = op_name.startswith('v_cmpx') # VOPCX writes to exec, no vcc destination
if is_64bit:
src0_str = _vreg(src0 - 256, 2) if src0 >= 256 else _sreg(src0, 2) if src0 <= 105 else "vcc" if src0 == 106 else "exec" if src0 == 126 else f"ttmp[{src0-108}:{src0-108+1}]" if 108 <= src0 <= 123 else fmt_src(src0)
elif is_16bit and src0 >= 256:
src0_str = f"v{(src0 - 256) & 0x7f}.{'h' if src0 >= 384 else 'l'}"
else:
src0_str = fmt_src(src0)
src0_str = _fmt_src64(src0) if is_64bit else f"v{(src0 - 256) & 0x7f}.{'h' if src0 >= 384 else 'l'}" if is_16bit and src0 >= 256 else fmt_src(src0)
vsrc1_str = _vreg(vsrc1, 2) if is_64bit_vsrc1 else f"v{vsrc1 & 0x7f}.{'h' if vsrc1 >= 128 else 'l'}" if is_16bit else f"v{vsrc1}"
if is_cmpx:
return f"{op_name}_e32 {src0_str}, {vsrc1_str}"
return f"{op_name}_e32 vcc_lo, {src0_str}, {vsrc1_str}"
return f"{op_name}_e32 {src0_str}, {vsrc1_str}" if is_cmpx else f"{op_name}_e32 vcc_lo, {src0_str}, {vsrc1_str}"
# SOPP
if cls_name == 'SOPP':
@@ -159,64 +170,28 @@ def disasm(inst: Inst) -> str:
# SMEM
if cls_name == 'SMEM':
# No-operand instructions
if op_name in ('s_gl1_inv', 's_dcache_inv'): return op_name
sdata, sbase, soffset, offset = unwrap(inst._values['sdata']), unwrap(inst._values['sbase']), unwrap(inst._values['soffset']), unwrap(inst._values.get('offset', 0))
glc, dlc = unwrap(inst._values.get('glc', 0)), unwrap(inst._values.get('dlc', 0))
# s_atc_probe/s_atc_probe_buffer: sdata is the probe mode (0-7), not a register
if op_name in ('s_atc_probe', 's_atc_probe_buffer'):
sbase_idx = sbase * 2
sbase_cnt = 4 if op_name == 's_atc_probe_buffer' else 2
sbase_str = _sreg(sbase_idx, sbase_cnt)
if offset and soffset != 124:
off_str = f"{decode_src(soffset)} offset:0x{offset:x}"
elif offset:
off_str = f"0x{offset:x}"
else:
off_str = decode_src(soffset)
return f"{op_name} {sdata}, {sbase_str}, {off_str}"
# Format offset: "soffset offset:X" if both, "0x{offset:x}" if only imm, or decode_src(soffset)
off_str = f"{decode_src(soffset)} offset:0x{offset:x}" if offset and soffset != 124 else f"0x{offset:x}" if offset else decode_src(soffset)
sbase_idx, sbase_cnt = sbase * 2, 4 if (8 <= op_val <= 12 or op_name == 's_atc_probe_buffer') else 2
sbase_str = _fmt_ssrc(sbase_idx, sbase_cnt) if sbase_cnt == 2 else _sreg(sbase_idx, sbase_cnt) if sbase_idx <= 105 else _reg("ttmp", sbase_idx - 108, sbase_cnt)
if op_name in ('s_atc_probe', 's_atc_probe_buffer'): return f"{op_name} {sdata}, {sbase_str}, {off_str}"
width = {0:1, 1:2, 2:4, 3:8, 4:16, 8:1, 9:2, 10:4, 11:8, 12:16}.get(op_val, 1)
# Offset handling: if offset is set, we need "soffset offset:X" format, otherwise just soffset or imm
if offset and soffset != 124: # both soffset register and offset immediate
off_str = f"{decode_src(soffset)} offset:0x{offset:x}"
elif offset: # only offset immediate (soffset=null)
off_str = f"0x{offset:x}"
elif soffset == 124: # null
off_str = "null"
else: # only soffset register
off_str = decode_src(soffset)
# sbase is stored as register pair index, multiply by 2 for actual register number
# s_buffer_load_* (op 8-12) use 4-reg sbase (buffer descriptor), s_load_* (op 0-4) use 2-reg sbase
sbase_idx = sbase * 2
sbase_cnt = 4 if 8 <= op_val <= 12 else 2
# Format sbase with special register names
if sbase_idx == 106 and sbase_cnt == 2: sbase_str = "vcc"
elif sbase_idx == 126 and sbase_cnt == 2: sbase_str = "exec"
elif 108 <= sbase_idx <= 123: sbase_str = f"ttmp[{sbase_idx-108}:{sbase_idx-108+sbase_cnt-1}]"
else: sbase_str = _sreg(sbase_idx, sbase_cnt)
# Build modifiers
mods = []
if glc: mods.append("glc")
if dlc: mods.append("dlc")
mod_str = " " + " ".join(mods) if mods else ""
return f"{op_name} {_fmt_sdst(sdata, width)}, {sbase_str}, {off_str}{mod_str}"
mods = [m for m in ["glc" if glc else "", "dlc" if dlc else ""] if m]
return f"{op_name} {_fmt_sdst(sdata, width)}, {sbase_str}, {off_str}" + (" " + " ".join(mods) if mods else "")
# FLAT
if cls_name == 'FLAT':
vdst, addr, data, saddr, offset, seg = [unwrap(inst._values.get(f, 0)) for f in ['vdst', 'addr', 'data', 'saddr', 'offset', 'seg']]
prefix = {0: 'flat', 1: 'scratch', 2: 'global'}.get(seg, 'flat')
op_suffix = op_name.split('_', 1)[1] if '_' in op_name else op_name
instr = f"{prefix}_{op_suffix}"
is_store = 'store' in op_name
instr = f"{['flat', 'scratch', 'global'][seg] if seg < 3 else 'flat'}_{op_name.split('_', 1)[1] if '_' in op_name else op_name}"
width = {'b32':1, 'b64':2, 'b96':3, 'b128':4, 'u8':1, 'i8':1, 'u16':1, 'i16':1}.get(op_name.split('_')[-1], 1)
if saddr == 0x7F:
addr_str, saddr_str = _vreg(addr, 2), ""
else:
addr_str = _vreg(addr)
saddr_str = f", {_sreg(saddr, 2)}" if saddr < 106 else f", off" if saddr == 124 else f", {decode_src(saddr)}"
addr_str = _vreg(addr, 2) if saddr == 0x7F else _vreg(addr)
saddr_str = "" if saddr == 0x7F else f", {_sreg(saddr, 2)}" if saddr < 106 else ", off" if saddr == 124 else f", {decode_src(saddr)}"
off_str = f" offset:{offset}" if offset else ""
if is_store: return f"{instr} {addr_str}, {_vreg(data, width)}{saddr_str}{off_str}"
return f"{instr} {_vreg(vdst, width)}, {addr_str}{saddr_str}{off_str}"
vdata_str = _vreg(data if 'store' in op_name else vdst, width)
return f"{instr} {addr_str}, {vdata_str}{saddr_str}{off_str}" if 'store' in op_name else f"{instr} {vdata_str}, {addr_str}{saddr_str}{off_str}"
# VOP3: vector ops with modifiers (can be 1, 2, or 3 sources depending on opcode range)
if cls_name == 'VOP3':
@@ -237,18 +212,10 @@ def disasm(inst: Inst) -> str:
# v_mad_i64_i32/v_mad_u64_u32: 64-bit dst and src2, 32-bit src0/src1
is_mad64 = 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name
def fmt_sd_src(v, neg_bit, is_64bit=False):
s = fmt_src(v)
if is_64bit or is_f64:
if v >= 256: s = _vreg(v - 256, 2)
elif v <= 105: s = _sreg(v, 2)
elif v == 106: s = "vcc"
elif v == 126: s = "exec"
elif 108 <= v <= 123: s = f"ttmp[{v-108}:{v-108+1}]"
if neg_bit: s = f"-{s}"
return s
src0_str = fmt_sd_src(src0, neg & 1, False) # 32-bit for mad64
src1_str = fmt_sd_src(src1, neg & 2, False) # 32-bit for mad64
src2_str = fmt_sd_src(src2, neg & 4, is_mad64) # 64-bit for mad64
s = _fmt_src64(v) if (is_64bit or is_f64) else fmt_src(v)
return f"-{s}" if neg_bit else s
src0_str, src1_str = fmt_sd_src(src0, neg & 1), fmt_sd_src(src1, neg & 2)
src2_str = fmt_sd_src(src2, neg & 4, is_mad64)
dst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}"
sdst_str = _fmt_sdst(sdst, 1)
# v_add_co_u32, v_sub_co_u32, v_subrev_co_u32, v_add_co_ci_u32, etc. only use 2 sources
@@ -278,56 +245,24 @@ def disasm(inst: Inst) -> str:
# v_mqsad_u32_u8: 128-bit (4 reg) dst/src2, 64-bit src0, 32-bit src1
is_sad64 = any(x in op_name for x in ('qsad_pk', 'mqsad_pk'))
is_mqsad_u32 = 'mqsad_u32' in op_name
# Detect conversion ops: v_cvt_{dst_type}_{src_type} - each side may have different size
# Also handle v_cvt_pk_* which packs two values into one
# Detect 16-bit and 64-bit operand sizes for various instruction patterns
if 'cvt_pk' in op_name:
# Pack ops: dst is packed 16-bit, src is determined by last type in name
# e.g., v_cvt_pk_i16_f32, v_cvt_pk_norm_i16_f32
is_f16_dst = is_f16_src = is_f16_src2 = False # dst is 32-bit, srcs depend on op
is_f16_src = op_name.endswith('16') # only if final type is 16-bit
elif m := re.match(r'v_cvt_([a-z0-9_]+)_([a-z0-9]+)', op_name):
is_f16_dst, is_f16_src, is_f16_src2 = False, op_name.endswith('16'), False
elif m := re.match(r'v_(?:cvt|frexp_exp)_([a-z0-9_]+)_([a-z0-9]+)', op_name):
dst_type, src_type = m.group(1), m.group(2)
# Check if dst/src ends with a 16-bit type suffix
is_f16_dst = any(dst_type.endswith(x) for x in ('f16', 'i16', 'u16', 'b16'))
is_f16_src = is_f16_src2 = any(src_type.endswith(x) for x in ('f16', 'i16', 'u16', 'b16'))
# Override is_f64 for conversion ops - check if dst or src is 64-bit
is_f64_dst = '64' in dst_type
is_f64_src = '64' in src_type
is_f64 = False # Don't use default is_f64 detection for cvt ops
elif m := re.match(r'v_frexp_exp_([a-z0-9]+)_([a-z0-9]+)', op_name):
# v_frexp_exp_i32_f64: 32-bit dst (exponent), 64-bit src
# v_frexp_exp_i16_f16: 16-bit dst, 16-bit src
dst_type, src_type = m.group(1), m.group(2)
is_f16_dst = any(dst_type.endswith(x) for x in ('f16', 'i16', 'u16', 'b16'))
is_f16_src = is_f16_src2 = any(src_type.endswith(x) for x in ('f16', 'i16', 'u16', 'b16'))
is_f64_dst = '64' in dst_type
is_f64_src = '64' in src_type
is_f64 = False
elif m := re.match(r'v_mad_([iu])32_([iu])16', op_name):
# v_mad_i32_i16, v_mad_u32_u16: 32-bit dst, 16-bit src0/src1, 32-bit src2
is_f16_dst = False
is_f16_src = True # src0 and src1 are 16-bit
is_f16_src2 = False # src2 is 32-bit
is_f16_dst, is_f16_src, is_f16_src2 = _is_16bit(dst_type), _is_16bit(src_type), _is_16bit(src_type)
is_f64_dst, is_f64_src, is_f64 = '64' in dst_type, '64' in src_type, False
elif re.match(r'v_mad_[iu]32_[iu]16', op_name):
is_f16_dst, is_f16_src, is_f16_src2 = False, True, False # 32-bit dst, 16-bit src0/src1, 32-bit src2
elif 'pack_b32' in op_name:
# v_pack_b32_f16: 32-bit dst, 16-bit sources
is_f16_dst = False
is_f16_src = is_f16_src2 = True
is_f16_dst, is_f16_src, is_f16_src2 = False, True, True # 32-bit dst, 16-bit sources
else:
# 16-bit ops need .h/.l suffix, but packed ops (dot2, pk_, sad, msad, qsad, mqsad) don't
is_16bit_op = ('f16' in op_name or 'i16' in op_name or 'u16' in op_name or 'b16' in op_name) and not any(x in op_name for x in ('dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad'))
is_16bit_op = any(x in op_name for x in _16BIT_TYPES) and not any(x in op_name for x in ('dot2', 'pk_', 'sad', 'msad', 'qsad', 'mqsad'))
is_f16_dst = is_f16_src = is_f16_src2 = is_16bit_op
def fmt_vop3_src(v, neg_bit, abs_bit, hi_bit=False, reg_cnt=1, is_16=False):
s = fmt_src(v)
# Add register pair/quad for 64/128-bit, or .h suffix for f16 VGPRs with opsel
if reg_cnt > 1 and v >= 256: s = _vreg(v - 256, reg_cnt)
elif reg_cnt > 1 and v <= 105: s = _sreg(v, reg_cnt)
elif reg_cnt == 2 and v == 106: s = "vcc"
elif reg_cnt == 2 and v == 126: s = "exec"
elif reg_cnt > 1 and 108 <= v <= 123: s = f"ttmp[{v-108}:{v-108+reg_cnt-1}]"
elif is_16 and v >= 256: s = f"v{v - 256}.h" if hi_bit else f"v{v - 256}.l"
s = _fmt_src_n(v, reg_cnt) if reg_cnt > 1 else f"v{v - 256}.h" if is_16 and v >= 256 and hi_bit else f"v{v - 256}.l" if is_16 and v >= 256 else fmt_src(v)
if abs_bit: s = f"|{s}|"
if neg_bit: s = f"-{s}"
return s
return f"-{s}" if neg_bit else s
# Determine register count for each source (check for cvt-specific 64-bit flags first)
is_src0_64 = locals().get('is_f64_src', is_f64 and not is_shift64) or is_sad64 or is_mqsad_u32
is_src1_64 = is_f64 and not is_class and not is_ldexp64 and not is_trig_preop
@@ -398,228 +333,94 @@ def disasm(inst: Inst) -> str:
if cls_name == 'VOP3SD':
vdst, sdst = unwrap(inst._values.get('vdst', 0)), unwrap(inst._values.get('sdst', 0))
src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')]
neg = unwrap(inst._values.get('neg', 0))
omod = unwrap(inst._values.get('omod', 0))
clmp = unwrap(inst._values.get('clmp', 0))
is_f64 = 'f64' in op_name
is_mad64 = 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name
def fmt_sd_src(v, neg_bit, is_64bit=False):
s = fmt_src(v)
if is_64bit or is_f64:
if v >= 256: s = _vreg(v - 256, 2)
elif v <= 105: s = _sreg(v, 2)
elif v == 106: s = "vcc"
elif v == 126: s = "exec"
elif 108 <= v <= 123: s = f"ttmp[{v-108}:{v-108+1}]"
if neg_bit: s = f"-{s}"
return s
src0_str = fmt_sd_src(src0, neg & 1, False)
src1_str = fmt_sd_src(src1, neg & 2, False)
src2_str = fmt_sd_src(src2, neg & 4, is_mad64)
dst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}"
sdst_str = _fmt_sdst(sdst, 1)
clamp_str = " clamp" if clmp else ""
omod_str = {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "")
# v_add_co_u32, v_sub_co_u32, v_subrev_co_u32 only use 2 sources
if op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32'):
return f"{op_name}_e64 {dst_str}, {sdst_str}, {src0_str}, {src1_str}" + clamp_str
# v_add_co_ci_u32, v_sub_co_ci_u32, v_subrev_co_ci_u32 use 3 sources (src2 is carry-in)
if op_name in ('v_add_co_ci_u32', 'v_sub_co_ci_u32', 'v_subrev_co_ci_u32'):
return f"{op_name}_e64 {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" + clamp_str
# v_div_scale, v_mad_*64_*32 use 3 sources
return f"{op_name} {dst_str}, {sdst_str}, {src0_str}, {src1_str}, {src2_str}" + clamp_str + omod_str
neg, omod, clmp = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('omod', 0)), unwrap(inst._values.get('clmp', 0))
is_f64, is_mad64 = 'f64' in op_name, 'mad_i64_i32' in op_name or 'mad_u64_u32' in op_name
def fmt_neg(v, neg_bit, is_64=False): return f"-{_fmt_src64(v) if (is_64 or is_f64) else fmt_src(v)}" if neg_bit else _fmt_src64(v) if (is_64 or is_f64) else fmt_src(v)
srcs = [fmt_neg(src0, neg & 1), fmt_neg(src1, neg & 2), fmt_neg(src2, neg & 4, is_mad64)]
dst_str, sdst_str = _vreg(vdst, 2) if (is_f64 or is_mad64) else f"v{vdst}", _fmt_sdst(sdst, 1)
clamp_str, omod_str = " clamp" if clmp else "", {1: " mul:2", 2: " mul:4", 3: " div:2"}.get(omod, "")
is_2src = op_name in ('v_add_co_u32', 'v_sub_co_u32', 'v_subrev_co_u32')
suffix = "_e64" if op_name.startswith('v_') and 'co_' in op_name else ""
return f"{op_name}{suffix} {dst_str}, {sdst_str}, {', '.join(srcs[:2] if is_2src else srcs)}" + clamp_str + omod_str
# VOPD: dual-issue instructions
if cls_name == 'VOPD':
from extra.assembly.rdna3 import autogen
opx, opy = unwrap(inst._values.get('opx', 0)), unwrap(inst._values.get('opy', 0))
vdstx, vdsty_enc = unwrap(inst._values.get('vdstx', 0)), unwrap(inst._values.get('vdsty', 0))
srcx0, vsrcx1 = unwrap(inst._values.get('srcx0', 0)), unwrap(inst._values.get('vsrcx1', 0))
srcy0, vsrcy1 = unwrap(inst._values.get('srcy0', 0)), unwrap(inst._values.get('vsrcy1', 0))
# Decode vdsty: actual = (encoded << 1) | ((vdstx & 1) ^ 1)
vdsty = (vdsty_enc << 1) | ((vdstx & 1) ^ 1)
try:
opx_name = autogen.VOPDOp(opx).name.lower()
opy_name = autogen.VOPDOp(opy).name.lower()
except (ValueError, KeyError):
opx_name, opy_name = f"opx_{opx}", f"opy_{opy}"
# v_dual_mov_b32 only has 1 source
opx_str = f"{opx_name} v{vdstx}, {fmt_src(srcx0)}" if 'mov' in opx_name else f"{opx_name} v{vdstx}, {fmt_src(srcx0)}, v{vsrcx1}"
opy_str = f"{opy_name} v{vdsty}, {fmt_src(srcy0)}" if 'mov' in opy_name else f"{opy_name} v{vdsty}, {fmt_src(srcy0)}, v{vsrcy1}"
return f"{opx_str} :: {opy_str}"
opx, opy, vdstx, vdsty_enc = [unwrap(inst._values.get(f, 0)) for f in ('opx', 'opy', 'vdstx', 'vdsty')]
srcx0, vsrcx1, srcy0, vsrcy1 = [unwrap(inst._values.get(f, 0)) for f in ('srcx0', 'vsrcx1', 'srcy0', 'vsrcy1')]
vdsty = (vdsty_enc << 1) | ((vdstx & 1) ^ 1) # Decode vdsty
def fmt_vopd(op, vdst, src0, vsrc1):
try: name = autogen.VOPDOp(op).name.lower()
except (ValueError, KeyError): name = f"op_{op}"
return f"{name} v{vdst}, {fmt_src(src0)}" if 'mov' in name else f"{name} v{vdst}, {fmt_src(src0)}, v{vsrc1}"
return f"{fmt_vopd(opx, vdstx, srcx0, vsrcx1)} :: {fmt_vopd(opy, vdsty, srcy0, vsrcy1)}"
# VOP3P: packed vector ops
if cls_name == 'VOP3P':
vdst = unwrap(inst._values.get('vdst', 0))
vdst, clmp = unwrap(inst._values.get('vdst', 0)), unwrap(inst._values.get('clmp', 0))
src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')]
neg = unwrap(inst._values.get('neg', 0)) # neg_lo
neg_hi = unwrap(inst._values.get('neg_hi', 0))
opsel = unwrap(inst._values.get('opsel', 0))
opsel_hi = unwrap(inst._values.get('opsel_hi', 0))
opsel_hi2 = unwrap(inst._values.get('opsel_hi2', 0))
clmp = unwrap(inst._values.get('clmp', 0))
# WMMA ops have special register widths
is_wmma = 'wmma' in op_name
# Determine number of sources (dot ops are 3-src, most are 2-src)
is_3src = any(x in op_name for x in ('fma', 'mad', 'dot', 'wmma'))
# Format source operands
def fmt_vop3p_src(v, reg_cnt=1):
if v >= 256: return _vreg(v - 256, reg_cnt)
if v <= 105: return _sreg(v, reg_cnt) if reg_cnt > 1 else f"s{v}"
if v == 106 and reg_cnt == 2: return "vcc"
if v == 126 and reg_cnt == 2: return "exec"
return fmt_src(v)
neg, neg_hi = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('neg_hi', 0))
opsel, opsel_hi, opsel_hi2 = unwrap(inst._values.get('opsel', 0)), unwrap(inst._values.get('opsel_hi', 0)), unwrap(inst._values.get('opsel_hi2', 0))
is_wmma, is_3src = 'wmma' in op_name, any(x in op_name for x in ('fma', 'mad', 'dot', 'wmma'))
def fmt_bits(name, val, n): return f"{name}:[{','.join(str((val >> i) & 1) for i in range(n))}]"
# WMMA: f16/bf16 use 8-reg sources, iu8 uses 4-reg, iu4 uses 2-reg; all have 8-reg dst
if is_wmma:
src_cnt = 2 if 'iu4' in op_name else 4 if 'iu8' in op_name else 8
src0_str = _vreg(src0 - 256, src_cnt) if src0 >= 256 else fmt_vop3p_src(src0, src_cnt)
src1_str = _vreg(src1 - 256, src_cnt) if src1 >= 256 else fmt_vop3p_src(src1, src_cnt)
src2_str = _vreg(src2 - 256, 8) if src2 >= 256 else fmt_vop3p_src(src2, 8)
src0_str, src1_str, src2_str = _fmt_src_n(src0, src_cnt), _fmt_src_n(src1, src_cnt), _fmt_src_n(src2, 8)
dst_str = _vreg(vdst, 8)
else:
src0_str = fmt_vop3p_src(src0)
src1_str = fmt_vop3p_src(src1)
src2_str = fmt_vop3p_src(src2)
src0_str, src1_str, src2_str = _fmt_src_n(src0, 1), _fmt_src_n(src1, 1), _fmt_src_n(src2, 1)
dst_str = f"v{vdst}"
# Build modifiers - VOP3P uses op_sel, op_sel_hi, neg_lo, neg_hi
mods = []
# op_sel: selects high/low half of each source
if opsel:
if is_3src:
mods.append(f"op_sel:[{opsel & 1},{(opsel >> 1) & 1},{(opsel >> 2) & 1}]")
else:
mods.append(f"op_sel:[{opsel & 1},{(opsel >> 1) & 1}]")
# op_sel_hi: selects high half for upper result lane (default [1,1] or [1,1,1])
# opsel_hi is bits 0-1, opsel_hi2 is bit 2 (for src2)
n = 3 if is_3src else 2
full_opsel_hi = opsel_hi | (opsel_hi2 << 2)
default_opsel_hi = 0b111 if is_3src else 0b11
if full_opsel_hi != default_opsel_hi:
if is_3src:
mods.append(f"op_sel_hi:[{full_opsel_hi & 1},{(full_opsel_hi >> 1) & 1},{(full_opsel_hi >> 2) & 1}]")
else:
mods.append(f"op_sel_hi:[{full_opsel_hi & 1},{(full_opsel_hi >> 1) & 1}]")
# neg_lo: negate lower half of source
if neg:
if is_3src:
mods.append(f"neg_lo:[{neg & 1},{(neg >> 1) & 1},{(neg >> 2) & 1}]")
else:
mods.append(f"neg_lo:[{neg & 1},{(neg >> 1) & 1}]")
# neg_hi: negate upper half of source
if neg_hi:
if is_3src:
mods.append(f"neg_hi:[{neg_hi & 1},{(neg_hi >> 1) & 1},{(neg_hi >> 2) & 1}]")
else:
mods.append(f"neg_hi:[{neg_hi & 1},{(neg_hi >> 1) & 1}]")
mods = [fmt_bits("op_sel", opsel, n)] if opsel else []
if full_opsel_hi != (0b111 if is_3src else 0b11): mods.append(fmt_bits("op_sel_hi", full_opsel_hi, n))
if neg: mods.append(fmt_bits("neg_lo", neg, n))
if neg_hi: mods.append(fmt_bits("neg_hi", neg_hi, n))
if clmp: mods.append("clamp")
mod_str = " " + " ".join(mods) if mods else ""
if is_3src:
return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}{mod_str}"
return f"{op_name} {dst_str}, {src0_str}, {src1_str}{mod_str}"
return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}{mod_str}" if is_3src else f"{op_name} {dst_str}, {src0_str}, {src1_str}{mod_str}"
# VINTERP: interpolation instructions
if cls_name == 'VINTERP':
vdst = unwrap(inst._values.get('vdst', 0))
src0, src1, src2 = [unwrap(inst._values.get(f, 0)) for f in ('src0', 'src1', 'src2')]
waitexp = unwrap(inst._values.get('waitexp', 0))
neg = unwrap(inst._values.get('neg', 0))
clmp = unwrap(inst._values.get('clmp', 0))
opsel = unwrap(inst._values.get('opsel', 0))
def fmt_vi_src(v, neg_bit):
s = f"v{v - 256}" if v >= 256 else fmt_src(v)
if neg_bit: s = f"-{s}"
return s
src0_str = fmt_vi_src(src0, neg & 1)
src1_str = fmt_vi_src(src1, neg & 2)
src2_str = fmt_vi_src(src2, neg & 4)
# LLVM doesn't use .l/.h suffix for vinterp dst
dst_str = f"v{vdst}"
mods = []
if waitexp: mods.append(f"wait_exp:{waitexp}")
if clmp: mods.append("clamp")
mod_str = " " + " ".join(mods) if mods else ""
return f"{op_name} {dst_str}, {src0_str}, {src1_str}, {src2_str}{mod_str}"
neg, waitexp, clmp = unwrap(inst._values.get('neg', 0)), unwrap(inst._values.get('waitexp', 0)), unwrap(inst._values.get('clmp', 0))
def fmt_neg_vi(v, neg_bit): return f"-{v}" if neg_bit else v
srcs = [fmt_neg_vi(f"v{s - 256}" if s >= 256 else fmt_src(s), neg & (1 << i)) for i, s in enumerate([src0, src1, src2])]
mods = [m for m in [f"wait_exp:{waitexp}" if waitexp else "", "clamp" if clmp else ""] if m]
return f"{op_name} v{vdst}, {', '.join(srcs)}" + (" " + " ".join(mods) if mods else "")
# MUBUF/MTBUF helpers
def _buf_vaddr(vaddr, offen, idxen): return _vreg(vaddr, 2) if offen and idxen else f"v{vaddr}" if offen or idxen else "off"
def _buf_srsrc(srsrc): srsrc_base = srsrc * 4; return _reg("ttmp", srsrc_base - 108, 4) if 108 <= srsrc_base <= 123 else _sreg(srsrc_base, 4)
# MUBUF: buffer load/store
if cls_name == 'MUBUF':
vdata, vaddr = unwrap(inst._values.get('vdata', 0)), unwrap(inst._values.get('vaddr', 0))
srsrc, soffset = unwrap(inst._values.get('srsrc', 0)), unwrap(inst._values.get('soffset', 0))
offset = unwrap(inst._values.get('offset', 0))
offen, idxen = unwrap(inst._values.get('offen', 0)), unwrap(inst._values.get('idxen', 0))
glc, dlc, slc = unwrap(inst._values.get('glc', 0)), unwrap(inst._values.get('dlc', 0)), unwrap(inst._values.get('slc', 0))
tfe = unwrap(inst._values.get('tfe', 0))
# Special ops with no operands
vdata, vaddr, srsrc, soffset = [unwrap(inst._values.get(f, 0)) for f in ('vdata', 'vaddr', 'srsrc', 'soffset')]
offset, offen, idxen = unwrap(inst._values.get('offset', 0)), unwrap(inst._values.get('offen', 0)), unwrap(inst._values.get('idxen', 0))
glc, dlc, slc, tfe = [unwrap(inst._values.get(f, 0)) for f in ('glc', 'dlc', 'slc', 'tfe')]
if op_name in ('buffer_gl0_inv', 'buffer_gl1_inv'): return op_name
# Determine data width from op name
# d16 formats: _x and _xy use 1 reg, _xyz and _xyzw use 2 regs
# regular formats: _x=1, _xy=2, _xyz=3, _xyzw=4
# atomic u64 uses 2 regs, cmpswap doubles width (compare + swap)
if 'd16' in op_name:
width = 2 if any(x in op_name for x in ('xyz', 'xyzw')) else 1
if 'd16' in op_name: width = 2 if any(x in op_name for x in ('xyz', 'xyzw')) else 1
elif 'atomic' in op_name:
# cmpswap uses 2 regs for b32, 4 for b64; other atomics use 1 for b32, 2 for b64/u64/i64
base_width = 2 if any(x in op_name for x in ('b64', 'u64', 'i64')) else 1
width = base_width * 2 if 'cmpswap' in op_name else base_width
else:
width = {'b32':1, 'b64':2, 'b96':3, 'b128':4, 'b16':1, 'x':1, 'xy':2, 'xyz':3, 'xyzw':4}.get(op_name.split('_')[-1], 1)
# tfe adds 1 extra VGPR for texture fault status
else: width = {'b32':1, 'b64':2, 'b96':3, 'b128':4, 'b16':1, 'x':1, 'xy':2, 'xyz':3, 'xyzw':4}.get(op_name.split('_')[-1], 1)
if tfe: width += 1
is_store = 'store' in op_name
# Format vaddr
if offen and idxen: vaddr_str = f"v[{vaddr}:{vaddr+1}]"
elif offen or idxen: vaddr_str = f"v{vaddr}"
else: vaddr_str = "off"
# Format srsrc (4-aligned SGPR quad)
srsrc_base = srsrc * 4
srsrc_str = f"s[{srsrc_base}:{srsrc_base+3}]"
# Format soffset - use decode_src for proper constant handling
soff_str = decode_src(soffset)
# Build modifiers
mods = []
if offen: mods.append("offen")
if idxen: mods.append("idxen")
if offset: mods.append(f"offset:{offset}")
if glc: mods.append("glc")
if dlc: mods.append("dlc")
if slc: mods.append("slc")
if tfe: mods.append("tfe")
mod_str = " " + " ".join(mods) if mods else ""
if is_store:
return f"{op_name} {_vreg(vdata, width)}, {vaddr_str}, {srsrc_str}, {soff_str}{mod_str}"
return f"{op_name} {_vreg(vdata, width)}, {vaddr_str}, {srsrc_str}, {soff_str}{mod_str}"
mods = [m for m in ["offen" if offen else "", "idxen" if idxen else "", f"offset:{offset}" if offset else "",
"glc" if glc else "", "dlc" if dlc else "", "slc" if slc else "", "tfe" if tfe else ""] if m]
return f"{op_name} {_vreg(vdata, width)}, {_buf_vaddr(vaddr, offen, idxen)}, {_buf_srsrc(srsrc)}, {decode_src(soffset)}" + (" " + " ".join(mods) if mods else "")
# MTBUF: typed buffer load/store
if cls_name == 'MTBUF':
vdata, vaddr = unwrap(inst._values.get('vdata', 0)), unwrap(inst._values.get('vaddr', 0))
srsrc, soffset = unwrap(inst._values.get('srsrc', 0)), unwrap(inst._values.get('soffset', 0))
offset, fmt = unwrap(inst._values.get('offset', 0)), unwrap(inst._values.get('format', 0))
offen, idxen = unwrap(inst._values.get('offen', 0)), unwrap(inst._values.get('idxen', 0))
glc, dlc, slc = unwrap(inst._values.get('glc', 0)), unwrap(inst._values.get('dlc', 0)), unwrap(inst._values.get('slc', 0))
# Format vaddr
if offen and idxen: vaddr_str = f"v[{vaddr}:{vaddr+1}]"
elif offen or idxen: vaddr_str = f"v{vaddr}"
else: vaddr_str = "off"
# Format srsrc (4-aligned SGPR quad, or ttmp)
srsrc_base = srsrc * 4
if 108 <= srsrc_base <= 123:
srsrc_str = f"ttmp[{srsrc_base-108}:{srsrc_base-108+3}]"
else:
srsrc_str = f"s[{srsrc_base}:{srsrc_base+3}]"
# Format soffset - use decode_src for proper special register handling
soff_str = decode_src(soffset)
# Build modifiers - idxen must come before offen for LLVM
mods = [f"format:{fmt}"]
if idxen: mods.append("idxen")
if offen: mods.append("offen")
if offset: mods.append(f"offset:{offset}")
if glc: mods.append("glc")
if dlc: mods.append("dlc")
if slc: mods.append("slc")
# Determine vdata width: d16 xyz/xyzw use 2 regs, d16 x/xy use 1 reg
if 'd16' in op_name:
width = 2 if any(x in op_name for x in ('xyz', 'xyzw')) else 1
else:
width = {'x':1, 'xy':2, 'xyz':3, 'xyzw':4}.get(op_name.split('_')[-1], 1)
return f"{op_name} {_vreg(vdata, width)}, {vaddr_str}, {srsrc_str}, {soff_str} {' '.join(mods)}"
vdata, vaddr, srsrc, soffset = [unwrap(inst._values.get(f, 0)) for f in ('vdata', 'vaddr', 'srsrc', 'soffset')]
offset, tbuf_fmt, offen, idxen = [unwrap(inst._values.get(f, 0)) for f in ('offset', 'format', 'offen', 'idxen')]
glc, dlc, slc = [unwrap(inst._values.get(f, 0)) for f in ('glc', 'dlc', 'slc')]
mods = [f"format:{tbuf_fmt}"] + [m for m in ["idxen" if idxen else "", "offen" if offen else "", f"offset:{offset}" if offset else "",
"glc" if glc else "", "dlc" if dlc else "", "slc" if slc else ""] if m]
width = 2 if 'd16' in op_name and any(x in op_name for x in ('xyz', 'xyzw')) else 1 if 'd16' in op_name else {'x':1, 'xy':2, 'xyz':3, 'xyzw':4}.get(op_name.split('_')[-1], 1)
return f"{op_name} {_vreg(vdata, width)}, {_buf_vaddr(vaddr, offen, idxen)}, {_buf_srsrc(srsrc)}, {decode_src(soffset)} {' '.join(mods)}"
# SOP1/SOP2/SOPC/SOPK
if cls_name in ('SOP1', 'SOP2', 'SOPC', 'SOPK'):
@@ -627,15 +428,13 @@ def disasm(inst: Inst) -> str:
dst_cnt, src0_cnt = sizes[0], sizes[1]
src1_cnt = sizes[2] if len(sizes) > 2 else src0_cnt
if cls_name == 'SOP1':
if op_name == 's_getpc_b64': return f"{op_name} {_fmt_sdst(unwrap(inst._values.get('sdst', 0)), 2)}"
if op_name in ('s_setpc_b64', 's_rfe_b64'): return f"{op_name} {_fmt_ssrc(unwrap(inst._values.get('ssrc0', 0)), 2)}"
if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(unwrap(inst._values.get('sdst', 0)), 2)}, {_fmt_ssrc(unwrap(inst._values.get('ssrc0', 0)), 2)}"
sdst, ssrc0 = unwrap(inst._values.get('sdst', 0)), unwrap(inst._values.get('ssrc0', 0))
if op_name == 's_getpc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}"
if op_name in ('s_setpc_b64', 's_rfe_b64'): return f"{op_name} {_fmt_ssrc(ssrc0, 2)}"
if op_name == 's_swappc_b64': return f"{op_name} {_fmt_sdst(sdst, 2)}, {_fmt_ssrc(ssrc0, 2)}"
if op_name in ('s_sendmsg_rtn_b32', 's_sendmsg_rtn_b64'):
msg_id = unwrap(inst._values.get('ssrc0', 0))
msg_names = {128: 'MSG_RTN_GET_DOORBELL', 129: 'MSG_RTN_GET_DDID', 130: 'MSG_RTN_GET_TMA', 131: 'MSG_RTN_GET_REALTIME', 132: 'MSG_RTN_SAVE_WAVE', 133: 'MSG_RTN_GET_TBA'}
msg = msg_names.get(msg_id, str(msg_id))
return f"{op_name} {_fmt_sdst(unwrap(inst._values.get('sdst', 0)), 2 if 'b64' in op_name else 1)}, sendmsg({msg})"
return f"{op_name} {_fmt_sdst(unwrap(inst._values.get('sdst', 0)), dst_cnt)}, {_fmt_ssrc(unwrap(inst._values.get('ssrc0', 0)), src0_cnt)}"
return f"{op_name} {_fmt_sdst(sdst, 2 if 'b64' in op_name else 1)}, sendmsg({MSG_NAMES.get(ssrc0, str(ssrc0))})"
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}"
if cls_name == 'SOP2':
sdst, ssrc0, ssrc1 = [unwrap(inst._values.get(f, 0)) for f in ('sdst', 'ssrc0', 'ssrc1')]
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, {_fmt_ssrc(ssrc0, src0_cnt)}, {_fmt_ssrc(ssrc1, src1_cnt)}"
@@ -645,41 +444,24 @@ def disasm(inst: Inst) -> str:
sdst, simm16 = unwrap(inst._values.get('sdst', 0)), unwrap(inst._values.get('simm16', 0))
if op_name == 's_version': return f"{op_name} 0x{simm16:x}"
if op_name in ('s_setreg_b32', 's_getreg_b32'):
# Decode hwreg: (size-1) << 11 | offset << 6 | id
hwreg_id, hwreg_offset, hwreg_size = simm16 & 0x3f, (simm16 >> 6) & 0x1f, ((simm16 >> 11) & 0x1f) + 1
# GFX11+ hwreg names (IDs 16-17 are TBA which are not supported on GFX11, IDs 18-19 are PERF_SNAPSHOT)
hwreg_names = {1: 'HW_REG_MODE', 2: 'HW_REG_STATUS', 3: 'HW_REG_TRAPSTS', 4: 'HW_REG_HW_ID',
5: 'HW_REG_GPR_ALLOC', 6: 'HW_REG_LDS_ALLOC', 7: 'HW_REG_IB_STS',
15: 'HW_REG_SH_MEM_BASES',
18: 'HW_REG_PERF_SNAPSHOT_PC_LO', 19: 'HW_REG_PERF_SNAPSHOT_PC_HI',
20: 'HW_REG_FLAT_SCR_LO', 21: 'HW_REG_FLAT_SCR_HI', 22: 'HW_REG_XNACK_MASK',
23: 'HW_REG_HW_ID1', 24: 'HW_REG_HW_ID2', 25: 'HW_REG_POPS_PACKER',
28: 'HW_REG_IB_STS2'}
# For unsupported registers (TBA_LO/HI, TMA_LO/HI on GFX11), output raw simm16 value
if hwreg_id in (16, 17, 18, 19) and hwreg_id not in hwreg_names:
# Unsupported on GFX11 - use raw encoding
hwreg_str = f"0x{simm16:x}"
else:
hwreg_name = hwreg_names.get(hwreg_id, str(hwreg_id))
hwreg_str = f"hwreg({hwreg_name}, {hwreg_offset}, {hwreg_size})"
if op_name == 's_setreg_b32':
return f"{op_name} {hwreg_str}, {_fmt_sdst(sdst, 1)}"
return f"{op_name} {_fmt_sdst(sdst, 1)}, {hwreg_str}"
hwreg_str = f"0x{simm16:x}" if hwreg_id in (16, 17) else f"hwreg({HWREG_NAMES.get(hwreg_id, str(hwreg_id))}, {hwreg_offset}, {hwreg_size})"
return f"{op_name} {hwreg_str}, {_fmt_sdst(sdst, 1)}" if op_name == 's_setreg_b32' else f"{op_name} {_fmt_sdst(sdst, 1)}, {hwreg_str}"
return f"{op_name} {_fmt_sdst(sdst, dst_cnt)}, 0x{simm16:x}"
# Generic fallback
def fmt(n, v):
def fmt_field(n, v):
v = unwrap(v)
if n in SRC_FIELDS: return fmt_src(v) if v != 255 else "0xff"
if n in ('sdst', 'vdst'): return f"{'s' if n == 'sdst' else 'v'}{v}"
return f"v{v}" if n == 'vsrc1' else f"0x{v:x}" if n == 'simm16' else str(v)
ops = [fmt(n, inst._values.get(n, 0)) for n in inst._fields if n not in ('encoding', 'op')]
ops = [fmt_field(n, inst._values.get(n, 0)) for n in inst._fields if n not in ('encoding', 'op')]
return f"{op_name} {', '.join(ops)}" if ops else op_name
# Assembler
SPECIAL_REGS = {'vcc_lo': RawImm(106), 'vcc_hi': RawImm(107), 'null': RawImm(124), 'off': RawImm(124), 'm0': RawImm(125), 'exec_lo': RawImm(126), 'exec_hi': RawImm(127), 'scc': RawImm(253)}
FLOAT_CONSTS = {'0.5': 0.5, '-0.5': -0.5, '1.0': 1.0, '-1.0': -1.0, '2.0': 2.0, '-2.0': -2.0, '4.0': 4.0, '-4.0': -4.0}
REG_MAP = {'s': SGPR, 'v': VGPR, 't': TTMP, 'ttmp': TTMP}
REG_MAP: dict[str, _RegFactory] = {'s': s, 'v': v, 't': ttmp, 'ttmp': ttmp}
def parse_operand(op: str) -> tuple:
op = op.strip().lower()
@@ -696,23 +478,16 @@ def parse_operand(op: str) -> tuple:
if op in SPECIAL_REGS: return (SPECIAL_REGS[op], neg, abs_, hi_half)
if m := re.match(r'^([svt](?:tmp)?)\[(\d+):(\d+)\]$', op): return (REG_MAP[m.group(1)][int(m.group(2)):int(m.group(3))+1], neg, abs_, hi_half)
if m := re.match(r'^([svt](?:tmp)?)(\d+)$', op):
return (REG_MAP[m.group(1)](int(m.group(2)), 1, hi_half), neg, abs_, hi_half)
reg = REG_MAP[m.group(1)][int(m.group(2))]
reg.hi = hi_half
return (reg, neg, abs_, hi_half)
# hwreg(name, offset, size) or hwreg(name) -> simm16 encoding
if m := re.match(r'^hwreg\((\w+)(?:,\s*(\d+),\s*(\d+))?\)$', op):
# GFX11 hwreg names - note IDs 18-19 are PERF_SNAPSHOT on GFX11, not TMA
hwreg_names = {'hw_reg_mode': 1, 'hw_reg_status': 2, 'hw_reg_trapsts': 3, 'hw_reg_hw_id': 4,
'hw_reg_gpr_alloc': 5, 'hw_reg_lds_alloc': 6, 'hw_reg_ib_sts': 7,
'hw_reg_sh_mem_bases': 15,
'hw_reg_perf_snapshot_pc_lo': 18, 'hw_reg_perf_snapshot_pc_hi': 19,
'hw_reg_flat_scr_lo': 20, 'hw_reg_flat_scr_hi': 21, 'hw_reg_xnack_mask': 22,
'hw_reg_hw_id1': 23, 'hw_reg_hw_id2': 24, 'hw_reg_pops_packer': 25, 'hw_reg_ib_sts2': 28}
name_str = m.group(1).lower()
hwreg_id = hwreg_names.get(name_str, int(name_str) if name_str.isdigit() else None)
hwreg_id = HWREG_IDS.get(name_str, int(name_str) if name_str.isdigit() else None)
if hwreg_id is None: raise ValueError(f"unknown hwreg name: {name_str}")
offset = int(m.group(2)) if m.group(2) else 0
size = int(m.group(3)) if m.group(3) else 32
simm16 = ((size - 1) << 11) | (offset << 6) | hwreg_id
return (simm16, neg, abs_, hi_half)
offset, size = int(m.group(2)) if m.group(2) else 0, int(m.group(3)) if m.group(3) else 32
return (((size - 1) << 11) | (offset << 6) | hwreg_id, neg, abs_, hi_half)
raise ValueError(f"cannot parse operand: {op}")
SMEM_OPS = {'s_load_b32', 's_load_b64', 's_load_b128', 's_load_b256', 's_load_b512',
@@ -780,6 +555,13 @@ def asm(text: str) -> Inst:
if mnemonic.replace('_e32', '') in vcc_ops and len(values) >= 5: values = [values[0], values[2], values[3]]
if mnemonic.startswith('v_cmp') and len(values) >= 3 and operands[0].strip().lower() in ('vcc_lo', 'vcc_hi', 'vcc'):
values = values[1:]
# CMPX instructions with _e64 suffix: prepend implicit EXEC_LO destination (vdst=126)
if 'cmpx' in mnemonic and mnemonic.endswith('_e64') and len(values) == 2:
values = [VGPR(126, 1)] + values
# Recalculate modifiers: parsed[0]=src0, parsed[1]=src1 (no vdst in user input)
neg_bits = sum((1 << i) for i, p in enumerate(parsed[:3]) if p[1])
abs_bits = sum((1 << i) for i, p in enumerate(parsed[:3]) if p[2])
opsel_bits = sum((1 << i) for i, p in enumerate(parsed[:2]) if p[3])
vop3sd_ops = {'v_div_scale_f32', 'v_div_scale_f64'}
if mnemonic in vop3sd_ops and len(parsed) >= 5:
neg_bits = sum((1 << i) for i, p in enumerate(parsed[2:5]) if p[1])

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,10 @@
# RDNA3 emulator - pure Python implementation for testing
from __future__ import annotations
import ctypes, struct, math
from extra.assembly.rdna3.lib import Inst32, Inst64, RawImm
from typing import Callable
from extra.assembly.rdna3.lib import Inst, Inst32, Inst64, RawImm
Program = dict[int, Inst] # pc (word offset) -> instruction
from extra.assembly.rdna3.autogen import (
SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD, SrcEnum,
SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, VOPDOp
@@ -32,23 +35,19 @@ def mem_read(addr: int, size: int) -> int: return _ctypes_at(addr, size).value i
def mem_write(addr: int, size: int, val: int) -> None:
if _mem_valid(addr, size): _ctypes_at(addr, size).value = val
# Memory op tables
FLAT_LOAD = {GLOBALOp.GLOBAL_LOAD_B32: (1,4,0), FLATOp.FLAT_LOAD_B32: (1,4,0), GLOBALOp.GLOBAL_LOAD_B64: (2,4,0), FLATOp.FLAT_LOAD_B64: (2,4,0),
GLOBALOp.GLOBAL_LOAD_B96: (3,4,0), FLATOp.FLAT_LOAD_B96: (3,4,0), GLOBALOp.GLOBAL_LOAD_B128: (4,4,0), FLATOp.FLAT_LOAD_B128: (4,4,0),
GLOBALOp.GLOBAL_LOAD_U8: (1,1,0), FLATOp.FLAT_LOAD_U8: (1,1,0), GLOBALOp.GLOBAL_LOAD_I8: (1,1,1), FLATOp.FLAT_LOAD_I8: (1,1,1),
GLOBALOp.GLOBAL_LOAD_U16: (1,2,0), FLATOp.FLAT_LOAD_U16: (1,2,0), GLOBALOp.GLOBAL_LOAD_I16: (1,2,1), FLATOp.FLAT_LOAD_I16: (1,2,1)}
FLAT_STORE = {GLOBALOp.GLOBAL_STORE_B32: (1,4), FLATOp.FLAT_STORE_B32: (1,4), GLOBALOp.GLOBAL_STORE_B64: (2,4), FLATOp.FLAT_STORE_B64: (2,4),
GLOBALOp.GLOBAL_STORE_B96: (3,4), FLATOp.FLAT_STORE_B96: (3,4), GLOBALOp.GLOBAL_STORE_B128: (4,4), FLATOp.FLAT_STORE_B128: (4,4),
GLOBALOp.GLOBAL_STORE_B8: (1,1), FLATOp.FLAT_STORE_B8: (1,1), GLOBALOp.GLOBAL_STORE_B16: (1,2), FLATOp.FLAT_STORE_B16: (1,2)}
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0),
DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
FLAT_D16_LO = {FLATOp.FLAT_LOAD_D16_U8: (1, 0), FLATOp.FLAT_LOAD_D16_I8: (1, 1), FLATOp.FLAT_LOAD_D16_B16: (2, 0),
GLOBALOp.GLOBAL_LOAD_D16_U8: (1, 0), GLOBALOp.GLOBAL_LOAD_D16_I8: (1, 1), GLOBALOp.GLOBAL_LOAD_D16_B16: (2, 0)}
FLAT_D16_HI = {FLATOp.FLAT_LOAD_D16_HI_U8: (1, 0), FLATOp.FLAT_LOAD_D16_HI_I8: (1, 1), FLATOp.FLAT_LOAD_D16_HI_B16: (2, 0),
GLOBALOp.GLOBAL_LOAD_D16_HI_U8: (1, 0), GLOBALOp.GLOBAL_LOAD_D16_HI_I8: (1, 1), GLOBALOp.GLOBAL_LOAD_D16_HI_B16: (2, 0)}
FLAT_D16_STORE = {FLATOp.FLAT_STORE_D16_HI_B8: 1, FLATOp.FLAT_STORE_D16_HI_B16: 2, GLOBALOp.GLOBAL_STORE_D16_HI_B8: 1, GLOBALOp.GLOBAL_STORE_D16_HI_B16: 2}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# Memory op tables - (cnt, sz, sign) for loads, (cnt, sz) for stores
def _mem_ops(ops, suffix_map):
return {getattr(e, f"{p}_{s}"): v for e in ops for s, v in suffix_map.items() for p in [e.__name__.replace("Op", "")]}
_LOAD_MAP = {'LOAD_B32': (1,4,0), 'LOAD_B64': (2,4,0), 'LOAD_B96': (3,4,0), 'LOAD_B128': (4,4,0), 'LOAD_U8': (1,1,0), 'LOAD_I8': (1,1,1), 'LOAD_U16': (1,2,0), 'LOAD_I16': (1,2,1)}
_STORE_MAP = {'STORE_B32': (1,4), 'STORE_B64': (2,4), 'STORE_B96': (3,4), 'STORE_B128': (4,4), 'STORE_B8': (1,1), 'STORE_B16': (1,2)}
FLAT_LOAD = _mem_ops([GLOBALOp, FLATOp], _LOAD_MAP)
FLAT_STORE = _mem_ops([GLOBALOp, FLATOp], _STORE_MAP)
DS_LOAD: dict[int, tuple[int,int,int]] = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE: dict[int, tuple[int,int]] = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
FLAT_D16_LO = {getattr(e, f"{e.__name__.replace('Op', '')}_{s}"): v for e in [FLATOp, GLOBALOp] for s, v in [('LOAD_D16_U8', (1, 0)), ('LOAD_D16_I8', (1, 1)), ('LOAD_D16_B16', (2, 0))]}
FLAT_D16_HI = {getattr(e, f"{e.__name__.replace('Op', '')}_{s}"): v for e in [FLATOp, GLOBALOp] for s, v in [('LOAD_D16_HI_U8', (1, 0)), ('LOAD_D16_HI_I8', (1, 1)), ('LOAD_D16_HI_B16', (2, 0))]}
FLAT_D16_STORE = {getattr(e, f"{e.__name__.replace('Op', '')}_{s}"): v for e in [FLATOp, GLOBALOp] for s, v in [('STORE_D16_HI_B8', 1), ('STORE_D16_HI_B16', 2)]}
SMEM_LOAD: dict[int, int] = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
SOPK_WAIT = {SOPKOp.S_WAITCNT_VSCNT, SOPKOp.S_WAITCNT_VMCNT, SOPKOp.S_WAITCNT_EXPCNT, SOPKOp.S_WAITCNT_LGKMCNT}
class WaveState:
@@ -78,7 +77,6 @@ class WaveState:
def wsgpr64(self, i: int, v: int) -> None: self.wsgpr(i, v & 0xffffffff); self.wsgpr(i+1, (v >> 32) & 0xffffffff)
def rsrc(self, v: int, lane: int) -> int:
if v < VCC_LO: return self.sgpr[v]
if v < SGPR_COUNT: return self.sgpr[v]
if v == SCC: return self.scc
if v < 255: return _INLINE_CONSTS[v - 128]
@@ -88,7 +86,7 @@ class WaveState:
def rsrc64(self, v: int, lane: int) -> int:
return self.rsrc(v, lane) | ((self.rsrc(v+1, lane) if v < VCC_LO or 256 <= v <= 511 else 0) << 32)
def pend_sgpr_lane(self, reg: int, lane: int, val: bool) -> None:
def pend_sgpr_lane(self, reg: int, lane: int, val: int) -> None:
if reg not in self._pend_sgpr: self._pend_sgpr[reg] = 0
if val: self._pend_sgpr[reg] |= (1 << lane)
@@ -152,15 +150,14 @@ def exec_sop1(st: WaveState, inst: SOP1) -> int:
if (fn := SALU.get(SOP1_BASE + op)) is None: raise NotImplementedError(f"SOP1 op {op}")
r, scc = fn(s0, 0, st.scc); st.wsgpr(inst.sdst, r); st.scc = scc; return 0
_SOP2_64: dict[int, Callable[[int, int], int]] = {SOP2Op.S_AND_B64: lambda a, b: a & b, SOP2Op.S_OR_B64: lambda a, b: a | b, SOP2Op.S_XOR_B64: lambda a, b: a ^ b}
def exec_sop2(st: WaveState, inst: SOP2) -> int:
s0, s1, op = st.rsrc(inst.ssrc0, 0), st.rsrc(inst.ssrc1, 0), inst.op
# 64-bit ops handled inline
if op == SOP2Op.S_LSHL_B64: r = (st.rsrc64(inst.ssrc0, 0) << (s1 & 0x3f)) & 0xffffffffffffffff; st.wsgpr64(inst.sdst, r); st.scc = int(r != 0); return 0
if op == SOP2Op.S_LSHR_B64: r = st.rsrc64(inst.ssrc0, 0) >> (s1 & 0x3f); st.wsgpr64(inst.sdst, r); st.scc = int(r != 0); return 0
if op == SOP2Op.S_ASHR_I64: r = sext(st.rsrc64(inst.ssrc0, 0), 64) >> (s1 & 0x3f); st.wsgpr64(inst.sdst, r & 0xffffffffffffffff); st.scc = int(r != 0); return 0
if op == SOP2Op.S_AND_B64: r = st.rsrc64(inst.ssrc0, 0) & st.rsrc64(inst.ssrc1, 0); st.wsgpr64(inst.sdst, r); st.scc = int(r != 0); return 0
if op == SOP2Op.S_OR_B64: r = st.rsrc64(inst.ssrc0, 0) | st.rsrc64(inst.ssrc1, 0); st.wsgpr64(inst.sdst, r); st.scc = int(r != 0); return 0
if op == SOP2Op.S_XOR_B64: r = st.rsrc64(inst.ssrc0, 0) ^ st.rsrc64(inst.ssrc1, 0); st.wsgpr64(inst.sdst, r); st.scc = int(r != 0); return 0
if (fn := _SOP2_64.get(op)): r = fn(st.rsrc64(inst.ssrc0, 0), st.rsrc64(inst.ssrc1, 0)); st.wsgpr64(inst.sdst, r); st.scc = int(r != 0); return 0
if op == SOP2Op.S_CSELECT_B64: st.wsgpr64(inst.sdst, st.rsrc64(inst.ssrc0, 0) if st.scc else st.rsrc64(inst.ssrc1, 0)); return 0
if op == SOP2Op.S_FMAC_F32: st.wsgpr(inst.sdst, i32(f32(st.rsgpr(inst.sdst)) + f32(s0) * f32(s1))); return 0
if op == SOP2Op.S_FMAAK_F32: st.wsgpr(inst.sdst, i32(f32(s0) * f32(s1) + f32(inst._literal or 0))); return 0
@@ -175,15 +172,15 @@ def exec_sopc(st: WaveState, inst: SOPC) -> int:
if (fn := SALU.get(SOPC_BASE + op)) is None: raise NotImplementedError(f"SOPC op {op}")
st.scc = fn(s0, s1, st.scc)[1]; return 0
_SOPK_CMP = frozenset((SOPKOp.S_CMPK_EQ_I32, SOPKOp.S_CMPK_LG_I32, SOPKOp.S_CMPK_GT_I32, SOPKOp.S_CMPK_GE_I32,
SOPKOp.S_CMPK_LT_I32, SOPKOp.S_CMPK_LE_I32, SOPKOp.S_CMPK_EQ_U32, SOPKOp.S_CMPK_LG_U32,
SOPKOp.S_CMPK_GT_U32, SOPKOp.S_CMPK_GE_U32, SOPKOp.S_CMPK_LT_U32, SOPKOp.S_CMPK_LE_U32))
def exec_sopk(st: WaveState, inst: SOPK) -> int:
simm, s0, op = inst.simm16, st.rsgpr(inst.sdst), inst.op
if op in SOPK_WAIT: return 0
if (fn := SALU.get(SOPK_BASE + op)) is None: raise NotImplementedError(f"SOPK op {op}")
r, scc = fn(s0, simm, st.scc)
if op not in (SOPKOp.S_CMPK_EQ_I32, SOPKOp.S_CMPK_LG_I32, SOPKOp.S_CMPK_GT_I32, SOPKOp.S_CMPK_GE_I32,
SOPKOp.S_CMPK_LT_I32, SOPKOp.S_CMPK_LE_I32, SOPKOp.S_CMPK_EQ_U32, SOPKOp.S_CMPK_LG_U32,
SOPKOp.S_CMPK_GT_U32, SOPKOp.S_CMPK_GE_U32, SOPKOp.S_CMPK_LT_U32, SOPKOp.S_CMPK_LE_U32):
st.wsgpr(inst.sdst, r)
if op not in _SOPK_CMP: st.wsgpr(inst.sdst, r)
st.scc = scc; return 0
def exec_sopp(st: WaveState, inst: SOPP) -> int:
@@ -297,12 +294,8 @@ def exec_vop3(st: WaveState, inst: VOP3, lane: int) -> None:
if op in (VOP3Op.V_ADD_F64, VOP3Op.V_MUL_F64, VOP3Op.V_FMA_F64, VOP3Op.V_MAX_F64, VOP3Op.V_MIN_F64):
a, b = f64(st.rsrc(src0+1, lane), s0), f64(st.rsrc(src1+1, lane), s1)
c = f64(st.rsrc(src2+1, lane), s2) if op == VOP3Op.V_FMA_F64 else 0.0
if op == VOP3Op.V_ADD_F64: r = a + b
elif op == VOP3Op.V_MUL_F64: r = a * b
elif op == VOP3Op.V_FMA_F64: r = a * b + c
elif op == VOP3Op.V_MAX_F64: r = max(a, b)
else: r = min(a, b)
V[vdst], V[vdst+1] = i64_parts(r); return
rf = a + b if op == VOP3Op.V_ADD_F64 else a * b if op == VOP3Op.V_MUL_F64 else a * b + c if op == VOP3Op.V_FMA_F64 else max(a, b) if op == VOP3Op.V_MAX_F64 else min(a, b)
V[vdst], V[vdst+1] = i64_parts(rf); return
if (fn := VALU.get(op)): V[vdst] = fn(s0, s1, s2); return
raise NotImplementedError(f"VOP3 op {op}")
@@ -362,30 +355,23 @@ def exec_ds(st: WaveState, inst: DS, lane: int, lds: bytearray) -> None:
for i in range(cnt): lds[addr+i*sz:addr+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
else: raise NotImplementedError(f"DS op {op}")
VOPD_OPS = {
VOPDOp.V_DUAL_MUL_F32: lambda a, b: i32(f32(a)*f32(b)), VOPDOp.V_DUAL_ADD_F32: lambda a, b: i32(f32(a)+f32(b)),
VOPDOp.V_DUAL_SUB_F32: lambda a, b: i32(f32(a)-f32(b)), VOPDOp.V_DUAL_SUBREV_F32: lambda a, b: i32(f32(b)-f32(a)),
VOPDOp.V_DUAL_MAX_F32: lambda a, b: i32(max(f32(a), f32(b))), VOPDOp.V_DUAL_MIN_F32: lambda a, b: i32(min(f32(a), f32(b))),
VOPDOp.V_DUAL_MUL_DX9_ZERO_F32: lambda a, b: i32(0.0 if f32(a) == 0.0 or f32(b) == 0.0 else f32(a)*f32(b)),
VOPDOp.V_DUAL_MOV_B32: lambda a, b: a, VOPDOp.V_DUAL_ADD_NC_U32: lambda a, b: (a + b) & 0xffffffff,
VOPDOp.V_DUAL_LSHLREV_B32: lambda a, b: (b << (a & 0x1f)) & 0xffffffff, VOPDOp.V_DUAL_AND_B32: lambda a, b: a & b,
VOPD_OPS: dict[int, Callable[[int, int, int, int, int], int]] = {
VOPDOp.V_DUAL_MUL_F32: lambda a, b, d, l, lit: i32(f32(a)*f32(b)), VOPDOp.V_DUAL_ADD_F32: lambda a, b, d, l, lit: i32(f32(a)+f32(b)),
VOPDOp.V_DUAL_SUB_F32: lambda a, b, d, l, lit: i32(f32(a)-f32(b)), VOPDOp.V_DUAL_SUBREV_F32: lambda a, b, d, l, lit: i32(f32(b)-f32(a)),
VOPDOp.V_DUAL_MAX_F32: lambda a, b, d, l, lit: i32(max(f32(a), f32(b))), VOPDOp.V_DUAL_MIN_F32: lambda a, b, d, l, lit: i32(min(f32(a), f32(b))),
VOPDOp.V_DUAL_MUL_DX9_ZERO_F32: lambda a, b, d, l, lit: i32(0.0 if f32(a) == 0.0 or f32(b) == 0.0 else f32(a)*f32(b)),
VOPDOp.V_DUAL_MOV_B32: lambda a, b, d, l, lit: a, VOPDOp.V_DUAL_ADD_NC_U32: lambda a, b, d, l, lit: (a + b) & 0xffffffff,
VOPDOp.V_DUAL_LSHLREV_B32: lambda a, b, d, l, lit: (b << (a & 0x1f)) & 0xffffffff, VOPDOp.V_DUAL_AND_B32: lambda a, b, d, l, lit: a & b,
VOPDOp.V_DUAL_FMAC_F32: lambda a, b, d, l, lit: i32(f32(a)*f32(b)+f32(d)), VOPDOp.V_DUAL_FMAAK_F32: lambda a, b, d, l, lit: i32(f32(a)*f32(b)+f32(lit)),
VOPDOp.V_DUAL_FMAMK_F32: lambda a, b, d, l, lit: i32(f32(a)*f32(lit)+f32(b)), VOPDOp.V_DUAL_CNDMASK_B32: lambda a, b, d, l, lit: b if l else a,
}
def exec_vopd(st: WaveState, inst: VOPD, lane: int) -> None:
V, vdsty = st.vgpr[lane], (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
sx0, sx1, sy0, sy1 = st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], st.rsrc(inst.srcy0, lane), V[inst.vsrcy1]
opx, opy, dstx = inst.opx, inst.opy, inst.vdstx
if (fn := VOPD_OPS.get(opx)): V[dstx] = fn(sx0, sx1)
elif opx == VOPDOp.V_DUAL_FMAC_F32: V[dstx] = i32(f32(sx0)*f32(sx1)+f32(V[dstx]))
elif opx == VOPDOp.V_DUAL_FMAAK_F32: V[dstx] = i32(f32(sx0)*f32(sx1)+f32(st.literal))
elif opx == VOPDOp.V_DUAL_FMAMK_F32: V[dstx] = i32(f32(sx0)*f32(st.literal)+f32(sx1))
elif opx == VOPDOp.V_DUAL_CNDMASK_B32: V[dstx] = sx1 if (st.vcc >> lane) & 1 else sx0
else: raise NotImplementedError(f"VOPD opx {opx}")
if (fn := VOPD_OPS.get(opy)): V[vdsty] = fn(sy0, sy1)
elif opy == VOPDOp.V_DUAL_FMAC_F32: V[vdsty] = i32(f32(sy0)*f32(sy1)+f32(V[vdsty]))
elif opy == VOPDOp.V_DUAL_FMAAK_F32: V[vdsty] = i32(f32(sy0)*f32(sy1)+f32(st.literal))
elif opy == VOPDOp.V_DUAL_FMAMK_F32: V[vdsty] = i32(f32(sy0)*f32(st.literal)+f32(sy1))
elif opy == VOPDOp.V_DUAL_CNDMASK_B32: V[vdsty] = sy1 if (st.vcc >> lane) & 1 else sy0
else: raise NotImplementedError(f"VOPD opy {opy}")
V, vdsty, vcc_lane = st.vgpr[lane], (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1), (st.vcc >> lane) & 1
sx0, sx1, sy0, sy1, dstx = st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], inst.vdstx
if (fn := VOPD_OPS.get(inst.opx)): V[dstx] = fn(sx0, sx1, V[dstx], vcc_lane, st.literal)
else: raise NotImplementedError(f"VOPD opx {inst.opx}")
if (fn := VOPD_OPS.get(inst.opy)): V[vdsty] = fn(sy0, sy1, V[vdsty], vcc_lane, st.literal)
else: raise NotImplementedError(f"VOPD opy {inst.opy}")
def exec_vop3p(st: WaveState, inst: VOP3P, lane: int) -> None:
op, vdst, V = inst.op, inst.vdst, st.vgpr[lane]
@@ -435,8 +421,8 @@ def exec_wmma_f32_16x16x16_f16(st: WaveState, inst: VOP3P, n_lanes: int) -> None
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION LOOP
# ═══════════════════════════════════════════════════════════════════════════════
SCALAR = {SOP1: exec_sop1, SOP2: exec_sop2, SOPC: exec_sopc, SOPK: exec_sopk, SOPP: exec_sopp, SMEM: exec_smem}
VECTOR = {VOP1: exec_vop1, VOP2: exec_vop2, VOP3: exec_vop3, VOP3SD: exec_vop3sd, VOPC: exec_vopc, FLAT: exec_flat, DS: exec_ds, VOPD: exec_vopd, VOP3P: exec_vop3p}
SCALAR: dict[type, Callable[..., int]] = {SOP1: exec_sop1, SOP2: exec_sop2, SOPC: exec_sopc, SOPK: exec_sopk, SOPP: exec_sopp, SMEM: exec_smem}
VECTOR: dict[type, Callable[..., None]] = {VOP1: exec_vop1, VOP2: exec_vop2, VOP3: exec_vop3, VOP3SD: exec_vop3sd, VOPC: exec_vopc, FLAT: exec_flat, DS: exec_ds, VOPD: exec_vopd, VOP3P: exec_vop3p}
_WMMA_OPS = frozenset((VOP3POp.V_WMMA_F32_16X16X16_F16, VOP3POp.V_WMMA_F32_16X16X16_BF16, VOP3POp.V_WMMA_F16_16X16X16_F16,
VOP3POp.V_WMMA_BF16_16X16X16_BF16, VOP3POp.V_WMMA_I32_16X16X16_IU8, VOP3POp.V_WMMA_I32_16X16X16_IU4))
@@ -449,20 +435,35 @@ def step_wave(program: Program, st: WaveState, lds: bytearray, n_lanes: int) ->
delta = handler(st, inst)
if delta == -1: return -1
if delta == -2: st.pc += inst_words; return -2
if delta == -3: next_pc = (st.pc + inst_words) * 4; st.wsgpr(inst.sdst, next_pc & 0xffffffff); st.wsgpr(inst.sdst + 1, (next_pc >> 32) & 0xffffffff); st.pc += inst_words; return 0
if delta == -4: st.pc = st.rsrc64(inst.ssrc0, 0) // 4; return 0
if delta == -5: next_pc = (st.pc + inst_words) * 4; st.wsgpr(inst.sdst, next_pc & 0xffffffff); st.wsgpr(inst.sdst + 1, (next_pc >> 32) & 0xffffffff); st.pc = st.rsrc64(inst.ssrc0, 0) // 4; return 0
if delta == -3: # S_GETPC_B64
sop1 = inst if isinstance(inst, SOP1) else None
assert sop1 is not None
next_pc = (st.pc + inst_words) * 4; st.wsgpr(sop1.sdst, next_pc & 0xffffffff); st.wsgpr(sop1.sdst + 1, (next_pc >> 32) & 0xffffffff); st.pc += inst_words; return 0
if delta == -4: # S_SETPC_B64
sop1 = inst if isinstance(inst, SOP1) else None
assert sop1 is not None
st.pc = st.rsrc64(sop1.ssrc0, 0) // 4; return 0
if delta == -5: # S_SWAPPC_B64
sop1 = inst if isinstance(inst, SOP1) else None
assert sop1 is not None
next_pc = (st.pc + inst_words) * 4; st.wsgpr(sop1.sdst, next_pc & 0xffffffff); st.wsgpr(sop1.sdst + 1, (next_pc >> 32) & 0xffffffff); st.pc = st.rsrc64(sop1.ssrc0, 0) // 4; return 0
st.pc += inst_words + delta
else:
handler, exec_mask = VECTOR[inst_type], st.exec_mask
vec_handler, exec_mask = VECTOR[inst_type], st.exec_mask
if inst_type is DS:
for lane in range(n_lanes):
if exec_mask & (1 << lane): handler(st, inst, lane, lds)
elif inst_type is VOP3P and inst.op in _WMMA_OPS:
exec_wmma_f32_16x16x16_f16(st, inst, n_lanes)
if exec_mask & (1 << lane): vec_handler(st, inst, lane, lds)
elif inst_type is VOP3P:
vop3p = inst if isinstance(inst, VOP3P) else None
assert vop3p is not None
if vop3p.op in _WMMA_OPS:
exec_wmma_f32_16x16x16_f16(st, vop3p, n_lanes)
else:
for lane in range(n_lanes):
if exec_mask & (1 << lane): vec_handler(st, vop3p, lane)
else:
for lane in range(n_lanes):
if exec_mask & (1 << lane): handler(st, inst, lane)
if exec_mask & (1 << lane): vec_handler(st, inst, lane)
st.commit_pends(); st.pc += inst_words
return 0

View File

@@ -5,9 +5,9 @@ from tinygrad.helpers import fetch
PDF_URL = "https://docs.amd.com/api/khub/documents/UVVZM22UN7tMUeiW_4ShTQ/content"
FIELD_TYPES = {'SSRC0': 'SSrc', 'SSRC1': 'SSrc', 'SOFFSET': 'SSrc', 'SADDR': 'SSrc', 'SRC0': 'Src', 'SRC1': 'Src', 'SRC2': 'Src',
'SDST': 'SGPR', 'SBASE': 'SGPR', 'SDATA': 'SGPR', 'SRSRC': 'SGPR', 'VDST': 'VGPR', 'VSRC1': 'VGPR', 'VDATA': 'VGPR',
'VADDR': 'VGPR', 'ADDR': 'VGPR', 'DATA': 'VGPR', 'DATA0': 'VGPR', 'DATA1': 'VGPR', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPR', 'VSRCY1': 'VGPR', 'VDSTX': 'VGPR', 'VDSTY': 'VDSTYEnc'}
'SDST': 'SGPRField', 'SBASE': 'SGPRField', 'SDATA': 'SGPRField', 'SRSRC': 'SGPRField', 'VDST': 'VGPRField', 'VSRC1': 'VGPRField', 'VDATA': 'VGPRField',
'VADDR': 'VGPRField', 'ADDR': 'VGPRField', 'DATA': 'VGPRField', 'DATA0': 'VGPRField', 'DATA1': 'VGPRField', 'SIMM16': 'SImm', 'OFFSET': 'Imm',
'OPX': 'VOPDOp', 'OPY': 'VOPDOp', 'SRCX0': 'Src', 'SRCY0': 'Src', 'VSRCX1': 'VGPRField', 'VSRCY1': 'VGPRField', 'VDSTX': 'VGPRField', 'VDSTY': 'VDSTYEnc'}
FIELD_ORDER = {
'SOP2': ['op', 'sdst', 'ssrc0', 'ssrc1'], 'SOP1': ['op', 'sdst', 'ssrc0'], 'SOPC': ['op', 'ssrc0', 'ssrc1'],
'SOPK': ['op', 'sdst', 'simm16'], 'SOPP': ['op', 'simm16'], 'VOP1': ['op', 'vdst', 'src0'], 'VOPC': ['op', 'src0', 'vsrc1'],
@@ -69,8 +69,8 @@ def generate(output_path: pathlib.Path|str|None = None) -> dict:
for m in re.finditer(r'Table \d+\. (\w+) Opcodes(.*?)(?=Table \d+\.|\n\d+\.\d+\.\d+\.\s+\w+\s*\nDescription|$)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+([A-Z][A-Z0-9_]+)', m.group(2))}:
enums[m.group(1) + "Op"] = ops
if m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', m.group(1))}:
if vopd_m := re.search(r'Table \d+\. VOPD Y-Opcodes\n(.*?)(?=Table \d+\.|15\.\d)', full_text, re.S):
if ops := {int(x.group(1)): x.group(2) for x in re.finditer(r'(\d+)\s+(V_DUAL_\w+)', vopd_m.group(1))}:
enums["VOPDOp"] = ops
enum_names = set(enums.keys())
@@ -118,14 +118,14 @@ def generate(output_path: pathlib.Path|str|None = None) -> dict:
field_names = {f[0] for f in fields}
# check next pages for continuation fields (tables without ENCODING)
for offset in range(1, 3):
if page_idx + offset >= len(pages) or has_header_before_fields(page_texts[page_idx + offset]): break
for t in page_tables[page_idx + offset]:
for pg_offset in range(1, 3):
if page_idx + pg_offset >= len(pages) or has_header_before_fields(page_texts[page_idx + pg_offset]): break
for t in page_tables[page_idx + pg_offset]:
if is_fields_table(t) and (extra := parse_fields_table(t, fmt_name, enum_names)) and not has_encoding(extra):
for f in extra:
if f[0] not in field_names:
fields.append(f)
field_names.add(f[0])
for ef in extra:
if ef[0] not in field_names:
fields.append(ef)
field_names.add(ef[0])
break
formats[fmt_name] = fields
@@ -140,7 +140,8 @@ def generate(output_path: pathlib.Path|str|None = None) -> dict:
return [f"class {name}(IntEnum):"] + [f" {n} = {v}" for v, n in sorted(items.items())] + [""]
def field_key(f): return order.index(f[0].lower()) if f[0].lower() in order else 1000
lines = ["# autogenerated from AMD RDNA3.5 ISA PDF by gen.py - do not edit", "from enum import IntEnum",
"from extra.assembly.rdna3.lib import bits, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, SSrc, Src, SImm, Imm, VDSTYEnc",
"from typing import Annotated",
"from extra.assembly.rdna3.lib import bits, BitField, Inst32, Inst64, SGPR, VGPR, TTMP as TTMP, s as s, v as v, ttmp as ttmp, SSrc, Src, SImm, Imm, VDSTYEnc, SGPRField, VGPRField",
"import functools", ""]
lines += enum_lines("SrcEnum", src_enum) + sum([enum_lines(n, ops) for n, ops in sorted(enums.items())], [])
# Format-specific field defaults (verified against LLVM test vectors)
@@ -156,17 +157,27 @@ def generate(output_path: pathlib.Path|str|None = None) -> dict:
if defaults := format_defaults.get(fmt_name):
lines.append(f" _defaults = {defaults}")
for name, hi, lo, _, ftype in sorted([f for f in fields if f[0] != 'ENCODING'], key=field_key):
typ = f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{typ} = bits[{hi}]" if hi == lo else f" {name.lower()}{typ} = bits[{hi}:{lo}]")
# Wrap IntEnum types (ending in Op) with Annotated[BitField, ...] for correct typing
if ftype and ftype.endswith('Op'):
ann = f":Annotated[BitField, {ftype}]"
else:
ann = f":{ftype}" if ftype else ""
lines.append(f" {name.lower()}{ann} = bits[{hi}]" if hi == lo else f" {name.lower()}{ann} = bits[{hi}:{lo}]")
lines.append("")
lines.append("# instruction helpers")
for cls_name, ops in sorted(enums.items()):
fmt = cls_name[:-2]
for _, name in sorted(ops.items()):
for op_val, name in sorted(ops.items()):
seg = {"GLOBAL": ", seg=2", "SCRATCH": ", seg=2"}.get(fmt, "")
tgt = {"GLOBAL": "FLAT, GLOBALOp", "SCRATCH": "FLAT, SCRATCHOp"}.get(fmt, f"{fmt}, {cls_name}")
if fmt in formats or fmt in ("GLOBAL", "SCRATCH"):
suffix = "_e32" if fmt in ("VOP1", "VOP2") else ""
# VOP1/VOP2/VOPC get _e32 suffix, VOP3 promoted ops (< 512) get _e64 suffix
if fmt in ("VOP1", "VOP2", "VOPC"):
suffix = "_e32"
elif fmt == "VOP3" and op_val < 512:
suffix = "_e64"
else:
suffix = ""
lines.append(f"{name.lower()}{suffix} = functools.partial({tgt}.{name}{seg})")
# export SrcEnum values, but skip DPP8/DPP16 which conflict with class names
skip_exports = {'DPP8', 'DPP16'}

View File

@@ -1,19 +1,36 @@
# library for RDNA3 assembly DSL
from __future__ import annotations
from enum import IntEnum
from typing import overload, Annotated, TypeVar, Generic
# Bit field DSL
class BitField:
def __init__(self, hi: int, lo: int, name: str | None = None): self.hi, self.lo, self.name = hi, lo, name
def __set_name__(self, owner, name): self.name = name
def __set_name__(self, owner, name): self.name, self._owner = name, owner
def __eq__(self, val: int) -> tuple[BitField, int]: return (self, val) # type: ignore
def mask(self) -> int: return (1 << (self.hi - self.lo + 1)) - 1
@property
def marker(self) -> type | None:
# Get marker from Annotated type hint if present
import typing
if hasattr(self, '_owner') and self.name:
hints = typing.get_type_hints(self._owner, include_extras=True)
if self.name in hints:
hint = hints[self.name]
if typing.get_origin(hint) is Annotated:
args = typing.get_args(hint)
return args[1] if len(args) > 1 else None
return None
@overload
def __get__(self, obj: None, objtype: type) -> BitField: ...
@overload
def __get__(self, obj: object, objtype: type | None = None) -> int: ...
def __get__(self, obj, objtype=None):
if obj is None: return self
val = unwrap(obj._values.get(self.name, 0))
ann = getattr(type(obj), '__annotations__', {}).get(self.name)
if ann and isinstance(ann, type) and issubclass(ann, IntEnum):
try: return ann(val)
# Convert to IntEnum if marker is an IntEnum subclass
if self.marker and isinstance(self.marker, type) and issubclass(self.marker, IntEnum):
try: return self.marker(val)
except ValueError: pass
return val
@@ -25,19 +42,42 @@ bits = _Bits()
class Reg:
def __init__(self, idx: int, count: int = 1, hi: bool = False): self.idx, self.count, self.hi = idx, count, hi
def __repr__(self): return f"{self.__class__.__name__.lower()[0]}[{self.idx}]" if self.count == 1 else f"{self.__class__.__name__.lower()[0]}[{self.idx}:{self.idx + self.count}]"
@classmethod
def __class_getitem__(cls, key): return cls(key.start, key.stop - key.start) if isinstance(key, slice) else cls(key)
T = TypeVar('T', bound=Reg)
class _RegFactory(Generic[T]):
def __init__(self, cls: type[T], name: str): self._cls, self._name = cls, name
@overload
def __getitem__(self, key: int) -> Reg: ...
@overload
def __getitem__(self, key: slice) -> Reg: ...
def __getitem__(self, key: int | slice) -> Reg:
return self._cls(key.start, key.stop - key.start) if isinstance(key, slice) else self._cls(key)
def __repr__(self): return f"<{self._name} factory>"
class SGPR(Reg): pass
class VGPR(Reg): pass
class TTMP(Reg): pass
s, v = SGPR, VGPR
s: _RegFactory[SGPR] = _RegFactory(SGPR, "SGPR")
v: _RegFactory[VGPR] = _RegFactory(VGPR, "VGPR")
ttmp: _RegFactory[TTMP] = _RegFactory(TTMP, "TTMP")
# Field type markers
class SSrc: pass
class Src: pass
class Imm: pass
class SImm: pass
class VDSTYEnc: pass # VOPD vdsty: encoded = actual >> 1, actual = (encoded << 1) | ((vdstx & 1) ^ 1)
# Field type markers (runtime classes for validation)
class _SSrc: pass
class _Src: pass
class _Imm: pass
class _SImm: pass
class _VDSTYEnc: pass # VOPD vdsty: encoded = actual >> 1, actual = (encoded << 1) | ((vdstx & 1) ^ 1)
class _SGPRField: pass
class _VGPRField: pass
# Type aliases for annotations - tells mypy it's a BitField while preserving marker info
SSrc = Annotated[BitField, _SSrc]
Src = Annotated[BitField, _Src]
Imm = Annotated[BitField, _Imm]
SImm = Annotated[BitField, _SImm]
VDSTYEnc = Annotated[BitField, _VDSTYEnc]
SGPRField = Annotated[BitField, _SGPRField]
VGPRField = Annotated[BitField, _VGPRField]
class RawImm:
def __init__(self, val: int): self.val = val
def __repr__(self): return f"RawImm({self.val})"
@@ -51,14 +91,15 @@ FLOAT_ENC = {0.5: 240, -0.5: 241, 1.0: 242, -1.0: 243, 2.0: 244, -2.0: 245, 4.0:
SRC_FIELDS = {'src0', 'src1', 'src2', 'ssrc0', 'ssrc1', 'soffset', 'srcx0', 'srcy0'}
RAW_FIELDS = {'vdata', 'vdst', 'vaddr', 'addr', 'data', 'data0', 'data1', 'sdst', 'sdata'}
def encode_src(val) -> int:
if isinstance(val, SGPR): return val.idx | (0x80 if val.hi else 0)
if isinstance(val, VGPR): return 256 + val.idx + (0x80 if val.hi else 0)
def _encode_reg(val) -> int:
if isinstance(val, TTMP): return 108 + val.idx
return val.idx | (0x80 if val.hi else 0)
def encode_src(val) -> int:
if isinstance(val, VGPR): return 256 + _encode_reg(val)
if isinstance(val, Reg): return _encode_reg(val)
if hasattr(val, 'value'): return val.value
if isinstance(val, float):
if val == 0.0: return 128 # 0.0 encodes as integer constant 0
return FLOAT_ENC.get(val, 255)
if isinstance(val, float): return 128 if val == 0.0 else FLOAT_ENC.get(val, 255)
return 128 + val if isinstance(val, int) and 0 <= val <= 64 else 192 + (-val) if isinstance(val, int) and -16 <= val <= -1 else 255
# Instruction base class
@@ -66,6 +107,9 @@ class Inst:
_fields: dict[str, BitField]
_encoding: tuple[BitField, int] | None = None
_defaults: dict[str, int] = {}
_values: dict[str, int | RawImm]
_words: int # size in 32-bit words, set by decode_program
_literal: int | None
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
@@ -76,10 +120,6 @@ class Inst:
self._values, self._literal = dict(self._defaults), literal
self._values.update(zip([n for n in self._fields if n != 'encoding'], args))
self._values.update(kwargs)
# Get annotations from class hierarchy
annotations = {}
for cls in type(self).__mro__:
annotations.update(getattr(cls, '__annotations__', {}))
# Type check and encode values
for name, val in list(self._values.items()):
if name == 'encoding': continue
@@ -87,14 +127,15 @@ class Inst:
if isinstance(val, RawImm):
if name in RAW_FIELDS: self._values[name] = val.val
continue
ann = annotations.get(name)
field = self._fields.get(name)
marker = field.marker if field else None
# Type validation
if ann is SGPR:
if marker is _SGPRField:
if isinstance(val, VGPR): raise TypeError(f"field '{name}' requires SGPR, got VGPR")
if not isinstance(val, (SGPR, TTMP, int, RawImm)): raise TypeError(f"field '{name}' requires SGPR, got {type(val).__name__}")
if ann is VGPR:
if marker is _VGPRField:
if not isinstance(val, VGPR): raise TypeError(f"field '{name}' requires VGPR, got {type(val).__name__}")
if ann is SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR")
if marker is _SSrc and isinstance(val, VGPR): raise TypeError(f"field '{name}' requires scalar source, got VGPR")
# Encode source fields as RawImm for consistent disassembly
if name in SRC_FIELDS:
encoded = encode_src(val)
@@ -104,27 +145,22 @@ class Inst:
self._literal = val
# Encode raw register fields for consistent repr
elif name in RAW_FIELDS:
if isinstance(val, Reg):
self._values[name] = (108 + val.idx) if isinstance(val, TTMP) else (val.idx | (0x80 if val.hi else 0))
elif hasattr(val, 'value'): # IntEnum like SrcEnum.NULL
self._values[name] = val.value
if isinstance(val, Reg): self._values[name] = _encode_reg(val)
elif hasattr(val, 'value'): self._values[name] = val.value # IntEnum like SrcEnum.NULL
# Encode sbase (divided by 2) and srsrc/ssamp (divided by 4)
elif name == 'sbase' and isinstance(val, Reg):
self._values[name] = val.idx // 2
elif name in {'srsrc', 'ssamp'} and isinstance(val, Reg):
self._values[name] = val.idx // 4
# VOPD vdsty: encode as actual >> 1 (constraint: vdsty parity must be opposite of vdstx)
elif ann is VDSTYEnc and isinstance(val, VGPR):
elif marker is _VDSTYEnc and isinstance(val, VGPR):
self._values[name] = val.idx >> 1
def _encode_field(self, name: str, val) -> int:
if isinstance(val, RawImm): return val.val
if name in {'srsrc', 'ssamp'}: return val.idx // 4 if isinstance(val, Reg) else val
if name == 'sbase': return val.idx // 2 if isinstance(val, Reg) else val
if name in RAW_FIELDS:
if isinstance(val, TTMP): return 108 + val.idx
if isinstance(val, Reg): return val.idx | (0x80 if val.hi else 0)
return val
if name in RAW_FIELDS: return _encode_reg(val) if isinstance(val, Reg) else val
if isinstance(val, Reg) or name in SRC_FIELDS: return encode_src(val)
return val.value if hasattr(val, 'value') else val
@@ -178,9 +214,4 @@ class Inst:
return disasm(self)
class Inst32(Inst): pass
class Inst64(Inst):
def to_bytes(self) -> bytes:
result = self.to_int().to_bytes(8, 'little')
return result + (lit & 0xffffffff).to_bytes(4, 'little') if (lit := self._get_literal() or getattr(self, '_literal', None)) else result
@classmethod
def from_bytes(cls, data: bytes): return cls.from_int(int.from_bytes(data[:8], 'little'))
class Inst64(Inst): pass

View File

@@ -26,7 +26,7 @@ def count_instructions(kernel: bytes) -> int:
"""Count instructions in a kernel."""
return len(decode_program(kernel))
def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] = None):
def setup_buffers(buf_sizes: list[int], init_data: dict[int, bytes] | None = None):
"""Allocate buffers and return args pointer + valid ranges."""
if init_data is None: init_data = {}
buffers = []
@@ -46,7 +46,7 @@ def benchmark_emulator(name: str, run_fn, kernel: bytes, global_size, local_size
"""Benchmark an emulator and return average time."""
gx, gy, gz = global_size
lx, ly, lz = local_size
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
lib_ptr = ctypes.addressof(kernel_buf)
# Warmup
@@ -132,7 +132,7 @@ def profile_python_emu(kernel: bytes, global_size, local_size, args_ptr, n_runs:
"""Profile the Python emulator to find bottlenecks."""
gx, gy, gz = global_size
lx, ly, lz = local_size
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
lib_ptr = ctypes.addressof(kernel_buf)
pr = cProfile.Profile()

View File

@@ -68,7 +68,7 @@ class RustEmulator:
self.ctx = None
def create(self, kernel: bytes, n_lanes: int):
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
self.ctx = self.lib.wave_create(ctypes.addressof(kernel_buf), len(kernel), n_lanes)
self._kernel_buf = kernel_buf
@@ -87,9 +87,9 @@ class RustEmulator:
class PythonEmulator:
def __init__(self):
self.state: WaveState = None
self.program = None
self.lds = None
self.state: WaveState | None = None
self.program: dict | None = None
self.lds: bytearray | None = None
self.n_lanes = 0
def create(self, kernel: bytes, n_lanes: int):
@@ -99,11 +99,18 @@ class PythonEmulator:
self.lds = bytearray(65536)
self.n_lanes = n_lanes
def step(self) -> int: return step_wave(self.program, self.state, self.lds, self.n_lanes)
def set_sgpr(self, idx: int, val: int): self.state.sgpr[idx] = val & 0xffffffff
def set_vgpr(self, lane: int, idx: int, val: int): self.state.vgpr[lane][idx] = val & 0xffffffff
def step(self) -> int:
assert self.program is not None and self.state is not None and self.lds is not None
return step_wave(self.program, self.state, self.lds, self.n_lanes)
def set_sgpr(self, idx: int, val: int):
assert self.state is not None
self.state.sgpr[idx] = val & 0xffffffff
def set_vgpr(self, lane: int, idx: int, val: int):
assert self.state is not None
self.state.vgpr[lane][idx] = val & 0xffffffff
def get_snapshot(self) -> StateSnapshot:
assert self.state is not None
return StateSnapshot(pc=self.state.pc, scc=self.state.scc, vcc=self.state.vcc & 0xffffffff,
exec_mask=self.state.exec_mask & 0xffffffff, sgpr=list(self.state.sgpr),
vgpr=[list(self.state.vgpr[i]) for i in range(WAVE_SIZE)])
@@ -129,7 +136,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
# Initialize LDS (64KB, standard size for AMD GPUs)
rust.init_lds(65536)
for emu in [rust, python]:
for emu in (rust, python):
emu.set_sgpr(0, args_ptr & 0xffffffff)
emu.set_sgpr(1, (args_ptr >> 32) & 0xffffffff)
emu.set_sgpr(13, gidx)
@@ -187,7 +194,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
return True, f"Completed {gx*gy*gz} workgroups", total_steps
def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int, int], max_steps: int = 1000,
debug: bool = False, trace_len: int = 10, buf_data: dict[int, bytes] = None) -> tuple[bool, str]:
debug: bool = False, trace_len: int = 10, buf_data: dict[int, bytes] | None = None) -> tuple[bool, str]:
"""Run all kernels through both emulators with shared buffer pool."""
from extra.assembly.rdna3.emu import set_valid_mem_ranges, decode_program
if buf_data is None: buf_data = {}

View File

@@ -16,7 +16,7 @@ def run_kernel(kernel: bytes, n_threads: int = 1, n_outputs: int = 1) -> list[in
output_ptr = ctypes.addressof(output)
args = (ctypes.c_uint64 * 1)(output_ptr)
args_ptr = ctypes.addressof(args)
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
kernel_ptr = ctypes.addressof(kernel_buf)
# Register valid memory ranges for bounds checking
set_valid_mem_ranges({
@@ -344,7 +344,7 @@ class TestMemory(unittest.TestCase):
kernel += global_store_b32(addr=v[2], data=v[1], saddr=s[2]).to_bytes()
kernel += s_endpgm().to_bytes()
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
kernel_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({
(input_ptr, ctypes.sizeof(input_buf)),
@@ -434,7 +434,7 @@ class TestVOP3(unittest.TestCase):
v_mov_b32_e32(v[1], i32(5.0)),
v_mov_b32_e32(v[2], i32(3.0)),
# v_add_f32 with neg on src1: 5 + (-3) = 2
v_add_f32(v[1], v[1], v[2], neg=0b010),
v_add_f32_e64(v[1], v[1], v[2], neg=0b010),
])
out = run_kernel(kernel, n_threads=1)
self.assertEqual(f32(out[0]), 2.0)
@@ -475,7 +475,7 @@ class TestVOP3(unittest.TestCase):
"""Regression test: V_SQRT_F32 should return NaN for negative inputs, not 0."""
kernel = make_store_kernel([
v_mov_b32_e32(v[1], i32(-1.0)),
v_sqrt_f32(v[1], v[1]),
v_sqrt_f32_e32(v[1], v[1]),
])
out = run_kernel(kernel, n_threads=1)
self.assertTrue(math.isnan(f32(out[0])))
@@ -484,7 +484,7 @@ class TestVOP3(unittest.TestCase):
"""Regression test: V_RSQ_F32 should return NaN for negative inputs, not inf."""
kernel = make_store_kernel([
v_mov_b32_e32(v[1], i32(-1.0)),
v_rsq_f32(v[1], v[1]),
v_rsq_f32_e32(v[1], v[1]),
])
out = run_kernel(kernel, n_threads=1)
self.assertTrue(math.isnan(f32(out[0])))
@@ -655,7 +655,7 @@ class TestMultiWave(unittest.TestCase):
kernel += global_store_b32(addr=v[1], data=v[0], saddr=s[2]).to_bytes()
kernel += s_endpgm().to_bytes()
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
kernel_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({
(output_ptr, ctypes.sizeof(output)),
@@ -697,7 +697,7 @@ class TestRegressions(unittest.TestCase):
output_ptr = ctypes.addressof(output)
args = (ctypes.c_uint64 * 1)(output_ptr)
args_ptr = ctypes.addressof(args)
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
kernel_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({(output_ptr, 8), (args_ptr, 8), (kernel_ptr, len(kernel))})
run_asm(kernel_ptr, len(kernel), 1, 1, 1, 1, 1, 1, args_ptr)
@@ -726,7 +726,7 @@ class TestRegressions(unittest.TestCase):
output_ptr = ctypes.addressof(output)
args = (ctypes.c_uint64 * 1)(output_ptr)
args_ptr = ctypes.addressof(args)
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
kernel_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({(output_ptr, 8), (args_ptr, 8), (kernel_ptr, len(kernel))})
run_asm(kernel_ptr, len(kernel), 1, 1, 1, 1, 1, 1, args_ptr)
@@ -754,7 +754,7 @@ class TestRegressions(unittest.TestCase):
kernel += global_store_b32(addr=v[3], data=v[1], saddr=s[0]).to_bytes()
kernel += s_endpgm().to_bytes()
kernel_buf = (ctypes.c_char * len(kernel))(*kernel)
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
kernel_ptr = ctypes.addressof(kernel_buf)
set_valid_mem_ranges({(output_ptr, 4), (src_ptr, 2), (args_ptr, 16), (kernel_ptr, len(kernel))})
run_asm(kernel_ptr, len(kernel), 1, 1, 1, 1, 1, 1, args_ptr)

View File

@@ -3,16 +3,18 @@
import unittest
from extra.assembly.rdna3.autogen import *
from extra.assembly.rdna3.lib import Inst
from extra.assembly.rdna3.asm import asm
from extra.assembly.rdna3.test.test_roundtrip import compile_asm
class TestIntegration(unittest.TestCase):
inst: Inst
def tearDown(self):
if not hasattr(self, 'inst'): return
b = self.inst.to_bytes()
st = self.inst.disasm()
reasm = asm(st)
desc = f"{st:25s} {self.inst} {b} {reasm}"
desc = f"{st:25s} {self.inst} {b!r} {reasm}"
self.assertEqual(b, compile_asm(st), desc)
# TODO: this compare should work for valid things
#self.assertEqual(self.inst, reasm)

View File

@@ -35,7 +35,7 @@ EXPECTED_FORMATS = {
class TestPDFParserGenerate(unittest.TestCase):
"""Test the PDF parser by running generate() and checking results."""
result = None
result: dict
@classmethod
def setUpClass(cls):
@@ -103,6 +103,7 @@ class TestPDFParser(unittest.TestCase):
self.assertNotIn('simm16', SOP1._fields)
self.assertEqual(SOP1._fields['ssrc0'].hi, 7)
self.assertEqual(SOP1._fields['ssrc0'].lo, 0)
assert SOP1._encoding is not None
self.assertEqual(SOP1._encoding[0].hi, 31)
self.assertEqual(SOP1._encoding[1], 0b101111101)
@@ -132,6 +133,7 @@ class TestPDFParser(unittest.TestCase):
(FLAT, 31, 26, 0b110111),
]
for cls, hi, lo, val in tests:
assert cls._encoding is not None
self.assertEqual(cls._encoding[0].hi, hi, f"{cls.__name__} encoding hi")
self.assertEqual(cls._encoding[0].lo, lo, f"{cls.__name__} encoding lo")
self.assertEqual(cls._encoding[1], val, f"{cls.__name__} encoding val")