assembly/amd: clean up pcode, jit pcode instead of static (#14001)

* assembly/amd: clean up pcode

* regen

* lil

* jit the pcode

* sendmsg

* cleanups

* inst prefetch lol
This commit is contained in:
George Hotz
2026-01-04 02:06:15 -05:00
committed by GitHub
parent 280790e438
commit 34ea053b26
12 changed files with 4577 additions and 31749 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -5,7 +5,8 @@ import ctypes, functools
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import COMPILED_FUNCTIONS
from extra.assembly.amd.pcode import compile_pseudocode
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
SrcEnum, SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, SMEMOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, DSOp, FLATOp, GLOBALOp, SCRATCHOp, VOPDOp)
@@ -236,9 +237,8 @@ def exec_vopd(st: WaveState, inst, V: list, lane: int) -> None:
"""VOPD: dual-issue, execute two ops simultaneously (read all inputs before writes)."""
literal, vdstx, vdsty = inst._literal, inst.vdstx, (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
sx0, sx1, dx, sy0, sy1, dy = st.rsrc(inst.srcx0, lane, literal), V[inst.vsrcx1], V[vdstx], st.rsrc(inst.srcy0, lane, literal), V[inst.vsrcy1], V[vdsty]
opx, opy = _VOPD_TO_VOP[inst.opx], _VOPD_TO_VOP[inst.opy]
V[vdstx] = COMPILED_FUNCTIONS[type(opx)][opx](sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
V[vdsty] = COMPILED_FUNCTIONS[type(opy)][opy](sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
V[vdstx] = inst._fnx(sx0, sx1, 0, dx, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
V[vdsty] = inst._fny(sy0, sy1, 0, dy, st.scc, st.vcc, lane, st.exec_mask, literal, None)['D0']
def exec_flat(st: WaveState, inst, V: list, lane: int) -> None:
"""FLAT/GLOBAL/SCRATCH memory ops."""
@@ -359,15 +359,14 @@ def decode_program(data: bytes) -> dict[int, Inst]:
result: dict[int, Inst] = {}
i = 0
while i < len(data):
try: inst_class = detect_format(data[i:])
except ValueError: break # stop at invalid instruction (padding/metadata after code)
inst = inst_class.from_bytes(data[i:i+inst_class._size()+8]) # +8 for potential 64-bit literal
inst = detect_format(data[i:]).from_bytes(data[i:])
inst._words = inst.size() // 4
# Determine dispatch function and pcode function
fn = COMPILED_FUNCTIONS.get(type(inst.op), {}).get(inst.op)
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_CODE_END: break
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: inst._dispatch = dispatch_endpgm
elif isinstance(inst, SOPP) and inst.op == SOPPOp.S_BARRIER: inst._dispatch = dispatch_barrier
elif isinstance(inst, SOPP) and inst.op in (SOPPOp.S_CLAUSE, SOPPOp.S_WAITCNT, SOPPOp.S_WAITCNT_DEPCTR, SOPPOp.S_SENDMSG, SOPPOp.S_SET_INST_PREFETCH_DISTANCE): inst._dispatch = dispatch_nop
elif isinstance(inst, (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM)): inst._dispatch = exec_scalar
elif isinstance(inst, VOP1) and inst.op == VOP1Op.V_NOP: inst._dispatch = dispatch_nop
elif isinstance(inst, VOP3P) and 'WMMA' in inst.op_name: inst._dispatch = dispatch_wmma
@@ -378,11 +377,14 @@ def decode_program(data: bytes) -> dict[int, Inst]:
elif isinstance(inst, DS): inst._dispatch = dispatch_lane(exec_ds)
else: inst._dispatch = dispatch_lane(exec_vop)
# Validate pcode exists for instructions that need it (scalar/wave-level ops and VOPD don't need pcode)
needs_pcode = inst._dispatch not in (dispatch_endpgm, dispatch_barrier, exec_scalar, dispatch_nop, dispatch_wmma,
dispatch_writelane, dispatch_readlane, dispatch_lane(exec_vopd))
if fn is None and inst.op_name and needs_pcode: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
inst._fn = fn if fn else lambda *args, **kwargs: {}
# Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches)
# VOPD needs separate functions for X and Y ops
if isinstance(inst, VOPD):
def _compile_vopd_op(op): return compile_pseudocode(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op])
inst._fnx, inst._fny = _compile_vopd_op(_VOPD_TO_VOP[inst.opx]), _compile_vopd_op(_VOPD_TO_VOP[inst.opy])
elif inst._dispatch not in (dispatch_endpgm, dispatch_barrier, dispatch_nop, dispatch_wmma, dispatch_writelane):
assert type(inst.op) != int, f"inst op of {inst} is int"
inst._fn = compile_pseudocode(type(inst.op).__name__, inst.op.name, PSEUDOCODE_STRINGS[type(inst.op)][inst.op])
result[i // 4] = inst
i += inst._words * 4
return result

View File

@@ -1,9 +1,9 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math
import struct, math, re, functools
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS
# INTERNAL HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def _div(a, b):
@@ -11,143 +11,35 @@ def _div(a, b):
except ZeroDivisionError:
if a == 0.0 or math.isnan(a): return float("nan")
return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
def _to_f16_bits(v): return v if isinstance(v, int) else _i16(v)
def _isnan(x):
try: return math.isnan(float(x))
except (TypeError, ValueError): return False
def _check_nan_type(x, quiet_bit_expected, default):
"""Check NaN type by examining quiet bit. Returns default if can't determine."""
try:
if not math.isnan(float(x)): return False
if hasattr(x, '_reg') and hasattr(x, '_bits'):
bits = x._reg._val & ((1 << x._bits) - 1)
# NaN format: exponent all 1s, quiet bit, mantissa != 0
# f16: exp[14:10]=31, quiet=bit9, mant[8:0] | f32: exp[30:23]=255, quiet=bit22, mant[22:0] | f64: exp[62:52]=2047, quiet=bit51, mant[51:0]
exp_bits, quiet_pos, mant_mask = {16: (0x1f, 9, 0x3ff), 32: (0xff, 22, 0x7fffff), 64: (0x7ff, 51, 0xfffffffffffff)}.get(x._bits, (0,0,0))
exp_shift = {16: 10, 32: 23, 64: 52}.get(x._bits, 0)
if exp_bits and ((bits >> exp_shift) & exp_bits) == exp_bits and (bits & mant_mask) != 0:
return ((bits >> quiet_pos) & 1) == quiet_bit_expected
return default
except (TypeError, ValueError): return False
def _isquietnan(x): return _check_nan_type(x, 1, True) # quiet NaN has quiet bit = 1
def _issignalnan(x): return _check_nan_type(x, 0, False) # signaling NaN has quiet bit = 0
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fma(a, b, c):
try: return math.fma(a, b, c)
except ValueError: return float('nan') # inf * 0 + c is NaN per IEEE 754
def _signext(v): return v
def _fpop(fn):
def wrapper(x):
x = float(x)
if math.isnan(x) or math.isinf(x): return x
result = float(fn(x))
# Preserve sign of zero (IEEE 754: ceil(-0.0) = -0.0, ceil(-0.1) = -0.0)
if result == 0.0: return math.copysign(0.0, x)
return result
return math.copysign(0.0, x) if result == 0.0 else result
return wrapper
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
class _SafeFloat(float):
"""Float subclass that uses _div for division to handle 0/inf correctly."""
def __truediv__(self, o): return _div(float(self), float(o))
def __rtruediv__(self, o): return _div(float(o), float(self))
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
def _f_to_int(f, lo, hi): f = float(f); return 0 if math.isnan(f) else (hi if f >= hi else lo if f <= lo else int(f))
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
def f32_to_f16(f):
f = float(f)
if math.isnan(f): return 0x7e00 # f16 NaN
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
try: return struct.unpack("<H", struct.pack("<e", f))[0]
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def u8_to_u32(v): return int(v) & 0xff
def u4_to_u32(v): return int(v) & 0xf
def _sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def _mantissa_f32(f): return struct.unpack("<I", struct.pack("<f", f))[0] & 0x7fffff if not (math.isinf(f) or math.isnan(f)) else 0
def _ldexp(m, e): return math.ldexp(m, e)
def isEven(x):
x = float(x)
if math.isinf(x) or math.isnan(x): return False
return int(x) % 2 == 0
def fract(x): return x - math.floor(x)
PI = math.pi
def _trig(fn, x):
# V_SIN/COS_F32: hardware does frac on input cycles before computing
if math.isinf(x) or math.isnan(x): return float("nan")
frac_cycles = fract(x / (2 * math.pi))
result = fn(frac_cycles * 2 * math.pi)
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
# Round very small results (below f32 precision) to exactly 0
if abs(result) < 1e-7: return 0.0
return result
def sin(x): return _trig(math.sin, x)
def cos(x): return _trig(math.cos, x)
def pow(a, b):
try: return a ** b
except OverflowError: return float("inf") if b > 0 else 0.0
def _brev(v, bits): return int(bin(v & ((1 << bits) - 1))[2:].zfill(bits)[::-1], 2)
def _brev32(v): return _brev(v, 32)
def _brev64(v): return _brev(v, 64)
def _ctz(v, bits):
v, n = int(v) & ((1 << bits) - 1), 0
if v == 0: return bits
while (v & 1) == 0: v >>= 1; n += 1
return n
def _ctz32(v): return _ctz(v, 32)
def _ctz64(v): return _ctz(v, 64)
def _exponent(f):
# Handle TypedView (f16/f32/f64) to get correct exponent for that type
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
raw = f._val
if f._bits == 16: return (raw >> 10) & 0x1f # f16: 5-bit exponent
if f._bits == 32: return (raw >> 23) & 0xff # f32: 8-bit exponent
if f._bits == 64: return (raw >> 52) & 0x7ff # f64: 11-bit exponent
# Fallback: convert to f32 and get exponent
f = float(f)
if math.isinf(f) or math.isnan(f): return 255
if f == 0.0: return 0
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
except: return 0
def _is_denorm_f32(f):
if not isinstance(f, float): f = _f32(int(f) & 0xffffffff)
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
bits = struct.unpack("<I", struct.pack("<f", float(f)))[0]
return (bits >> 23) & 0xff == 0
def _is_denorm_f64(f):
if not isinstance(f, float): f = _f64(int(f) & 0xffffffffffffffff)
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
bits = struct.unpack("<Q", struct.pack("<d", float(f)))[0]
return (bits >> 52) & 0x7ff == 0
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
v_min_i32, v_max_i32 = min, max
v_min_i16, v_max_i16 = min, max
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
def ABSDIFF(a, b): return abs(int(a) - int(b))
# BF16 (bfloat16) conversion functions
def _bf16(i):
"""Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
@@ -157,95 +49,21 @@ def _ibf16(f):
if math.isinf(f): return 0x7f80 if f > 0 else 0xff80 # bf16 ±infinity
try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
def f32_to_bf16(f): return _ibf16(f)
def _trig(fn, x):
# V_SIN/COS_F32: hardware does frac on input cycles before computing
if math.isinf(x) or math.isnan(x): return float("nan")
frac_cycles = fract(x / (2 * math.pi))
result = fn(frac_cycles * 2 * math.pi)
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
# Round very small results (below f32 precision) to exactly 0
if abs(result) < 1e-7: return 0.0
return result
# BYTE_PERMUTE for V_PERM_B32 - select bytes from 64-bit data based on selector
def BYTE_PERMUTE(data, sel):
"""Select a byte from 64-bit data based on selector value.
sel 0-7: select byte from data (S1 is bytes 0-3, S0 is bytes 4-7 in {S0,S1})
sel 8-11: sign-extend from specific bytes (8->byte1, 9->byte3, 10->byte5, 11->byte7)
sel 12: constant 0x00
sel >= 13: constant 0xFF"""
sel = int(sel) & 0xff
if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00 # sign of byte 1
if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00 # sign of byte 3
if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00 # sign of byte 5
if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00 # sign of byte 7
if sel == 12: return 0x00
return 0xff # sel >= 13
class _SafeFloat(float):
"""Float subclass that uses _div for division to handle 0/inf correctly."""
def __truediv__(self, o): return _div(float(self), float(o))
def __rtruediv__(self, o): return _div(float(o), float(self))
# v_sad_u8 helper for V_SAD instructions (sum of absolute differences of 4 bytes)
def v_sad_u8(s0, s1, s2):
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
s0, s1, s2 = int(s0), int(s1), int(s2)
result = s2
for i in range(4):
a = (s0 >> (i * 8)) & 0xff
b = (s1 >> (i * 8)) & 0xff
result += abs(a - b)
return result & 0xffffffff
# v_msad_u8 helper (masked SAD - skip when reference byte is 0)
def v_msad_u8(s0, s1, s2):
"""V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
s0, s1, s2 = int(s0), int(s1), int(s2)
result = s2
for i in range(4):
a = (s0 >> (i * 8)) & 0xff
b = (s1 >> (i * 8)) & 0xff
if b != 0: # Only add diff if reference (s1) byte is non-zero
result += abs(a - b)
return result & 0xffffffff
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def u32_to_u16(u): return int(u) & 0xffff
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
def SAT8(v): return max(0, min(255, int(v)))
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
def mantissa(f):
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
m, _ = math.frexp(f)
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
def signext_from_bit(val, bit):
bit = int(bit)
if bit == 0: return 0
mask = (1 << bit) - 1
val = int(val) & mask
if val & (1 << (bit - 1)): return val - (1 << bit)
return val
# Aliases used in pseudocode
s_ff1_i32_b32, s_ff1_i32_b64 = _ctz32, _ctz64
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
isNAN = _isnan
isQuietNAN = _isquietnan
isSignalNAN = _issignalnan
fma, ldexp, sign, exponent = _fma, _ldexp, _sign, _exponent
def F(x):
"""32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
if isinstance(x, int): return _f32(x) # int -> interpret as f32 bits
if isinstance(x, TypedView): return x # preserve TypedView for bit-pattern checks
return float(x) # already a float or float-like
signext = lambda x: int(x) # sign-extend to full width - already handled by Python's arbitrary precision ints
pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
_pack, _pack32 = pack, pack32 # Aliases for internal use
WAVE32, WAVE64 = True, False
# Float overflow/underflow constants
OVERFLOW_F32 = float('inf')
UNDERFLOW_F32 = 0.0
OVERFLOW_F64 = float('inf')
UNDERFLOW_F64 = 0.0
MAX_FLOAT_F32 = 3.4028235e+38 # Largest finite float32
# INF object that supports .f16/.f32/.f64 access and comparison with floats
class _Inf:
f16 = f32 = f64 = float('inf')
def __neg__(self): return _NegInf()
@@ -260,26 +78,24 @@ class _NegInf:
def __float__(self): return float('-inf')
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
def __req__(self, other): return self.__eq__(other)
INF = _Inf()
# Rounding mode placeholder
class _RoundMode:
NEAREST_EVEN = 0
ROUND_MODE = _RoundMode()
# Helper functions for pseudocode
def cvtToQuietNAN(x): return float('nan')
DST = None # Placeholder, will be set in context
class _WaveMode:
IEEE = False
WAVE_MODE = _WaveMode()
class _DenormChecker:
"""Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
def __init__(self, bits): self._bits = bits
def _check(self, other):
return _is_denorm_f64(float(other)) if self._bits == 64 else _is_denorm_f32(float(other))
f = float(other)
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
if self._bits == 64:
bits = struct.unpack("<Q", struct.pack("<d", f))[0]
return (bits >> 52) & 0x7ff == 0
bits = struct.unpack("<I", struct.pack("<f", f))[0]
return (bits >> 23) & 0xff == 0
def __eq__(self, other): return self._check(other)
def __req__(self, other): return self._check(other)
def __ne__(self, other): return not self._check(other)
@@ -287,7 +103,9 @@ class _DenormChecker:
class _Denorm:
f32 = _DenormChecker(32)
f64 = _DenormChecker(64)
DENORM = _Denorm()
_pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
_pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
class TypedView:
"""View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
@@ -396,8 +214,6 @@ class TypedView:
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
SliceProxy = TypedView # Alias for compatibility
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
__slots__ = ('_val',)
@@ -466,5 +282,484 @@ class Reg:
def __eq__(s, o): return s._val == int(o)
def __ne__(s, o): return s._val != int(o)
# ═══════════════════════════════════════════════════════════════════════════════
# PSEUDOCODE API - Functions and constants from AMD ISA pseudocode
# ═══════════════════════════════════════════════════════════════════════════════
# Rounding and float operations
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
def fract(x): return x - math.floor(x)
def sin(x): return _trig(math.sin, x)
def cos(x): return _trig(math.cos, x)
def pow(a, b):
try: return a ** b
except OverflowError: return float("inf") if b > 0 else 0.0
def isEven(x):
x = float(x)
if math.isinf(x) or math.isnan(x): return False
return int(x) % 2 == 0
def mantissa(f):
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
m, _ = math.frexp(f)
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
def signext_from_bit(val, bit):
bit = int(bit)
if bit == 0: return 0
mask = (1 << bit) - 1
val = int(val) & mask
if val & (1 << (bit - 1)): return val - (1 << bit)
return val
# Type conversions
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
def f32_to_f16(f):
f = float(f)
if math.isnan(f): return 0x7e00 # f16 NaN
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
try: return struct.unpack("<H", struct.pack("<e", f))[0]
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
def f32_to_bf16(f): return _ibf16(f)
def u8_to_u32(v): return int(v) & 0xff
def u4_to_u32(v): return int(v) & 0xf
def u32_to_u16(u): return int(u) & 0xffff
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def SAT8(v): return max(0, min(255, int(v)))
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
# Min/max operations
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
v_min_i32, v_max_i32 = min, max
v_min_i16, v_max_i16 = min, max
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
# SAD/MSAD operations
def ABSDIFF(a, b): return abs(int(a) - int(b))
def v_sad_u8(s0, s1, s2):
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
s0, s1, s2 = int(s0), int(s1), int(s2)
result = s2
for i in range(4):
a = (s0 >> (i * 8)) & 0xff
b = (s1 >> (i * 8)) & 0xff
result += abs(a - b)
return result & 0xffffffff
def v_msad_u8(s0, s1, s2):
"""V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
s0, s1, s2 = int(s0), int(s1), int(s2)
result = s2
for i in range(4):
a = (s0 >> (i * 8)) & 0xff
b = (s1 >> (i * 8)) & 0xff
if b != 0: # Only add diff if reference (s1) byte is non-zero
result += abs(a - b)
return result & 0xffffffff
def BYTE_PERMUTE(data, sel):
"""Select a byte from 64-bit data based on selector value."""
sel = int(sel) & 0xff
if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00
if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00
if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00
if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00
if sel == 12: return 0x00
return 0xff
# Pseudocode functions
def s_ff1_i32_b32(v): return _ctz(v, 32)
def s_ff1_i32_b64(v): return _ctz(v, 64)
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
def isNAN(x):
try: return math.isnan(float(x))
except (TypeError, ValueError): return False
def isQuietNAN(x): return _check_nan_type(x, 1, True)
def isSignalNAN(x): return _check_nan_type(x, 0, False)
def fma(a, b, c):
try: return math.fma(a, b, c)
except ValueError: return float('nan')
def ldexp(m, e): return math.ldexp(m, e)
def sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def exponent(f):
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
raw = f._val
if f._bits == 16: return (raw >> 10) & 0x1f
if f._bits == 32: return (raw >> 23) & 0xff
if f._bits == 64: return (raw >> 52) & 0x7ff
f = float(f)
if math.isinf(f) or math.isnan(f): return 255
if f == 0.0: return 0
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
except: return 0
def signext(x): return int(x)
def cvtToQuietNAN(x): return float('nan')
def F(x):
"""32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
if isinstance(x, int): return _f32(x)
if isinstance(x, TypedView): return x
return float(x)
# Constants
PI = math.pi
WAVE32, WAVE64 = True, False
OVERFLOW_F32, UNDERFLOW_F32 = float('inf'), 0.0
OVERFLOW_F64, UNDERFLOW_F64 = float('inf'), 0.0
MAX_FLOAT_F32 = 3.4028235e+38
INF = _Inf()
ROUND_MODE = _RoundMode()
WAVE_MODE = _WaveMode()
DENORM = _Denorm()
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════
def _compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
line = line.strip()
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
else:
joined_lines.append(line)
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.split('//')[0].strip() # Strip C-style comments
if not line: continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line == 'else':
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + "else:")
indent += 1
need_pass = True
elif line.startswith('endif'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass = False
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1])
lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line:
lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str:
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
def _expr(e: str) -> str:
e = e.strip()
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
e = re.sub(r'\bB\(', '(', e)
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
def convert_verilog_slice(m):
start, width = m.group(1).strip(), m.group(2).strip()
return f'[({start}) + ({width}) - 1 : ({start})]'
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
def process_brackets(s):
result, i = [], 0
while i < len(s):
if s[i] == '[':
depth, start = 1, i + 1
j = start
while j < len(s) and depth > 0:
if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1
j += 1
inner = _expr(s[start:j-1])
result.append('[' + inner + ']')
i = j
else:
result.append(s[i])
i += 1
return ''.join(result)
e = process_brackets(e)
while '?' in e:
depth, bracket, q = 0, 0, -1
for i, c in enumerate(e):
if c == '(': depth += 1
elif c == ')': depth -= 1
elif c == '[': bracket += 1
elif c == ']': bracket -= 1
elif c == '?' and depth == 0 and bracket == 0: q = i; break
if q < 0: break
depth, bracket, col = 0, 0, -1
for i in range(q + 1, len(e)):
if e[i] == '(': depth += 1
elif e[i] == ')': depth -= 1
elif e[i] == '[': bracket += 1
elif e[i] == ']': bracket -= 1
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
if col < 0: break
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
e = f'(({t}) if ({cond}) else ({f}))'
return e
def _apply_pseudocode_fixes(op_name: str, code: str) -> str:
"""Apply known fixes for PDF pseudocode bugs."""
if op_name == 'V_DIV_FMAS_F32':
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op_name == 'V_DIV_FMAS_F64':
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
if op_name == 'V_DIV_SCALE_F32':
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
if op_name == 'V_DIV_SCALE_F64':
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
if op_name == 'V_DIV_FIXUP_F32':
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
if op_name == 'V_DIV_FIXUP_F64':
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
if op_name == 'V_TRIG_PREOP_F64':
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
return code
def _generate_function(cls_name: str, op_name: str, pc: str, code: str) -> str:
"""Generate a single compiled pseudocode function.
Functions take int parameters and return dict of int values.
Reg wrapping happens inside the function, only for registers actually used."""
has_d1 = '{ D1' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op_name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp'
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
is_smem = cls_name == 'SMEMOp'
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
combined = code + pc
fn_name = f"_{cls_name}_{op_name}"
# Detect which registers are used/modified
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# Build function signature and Reg init lines
if is_smem:
lines = [f"def {fn_name}(MEM, addr):"]
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
special_regs = []
elif is_ds:
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat:
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'VDATA')]
elif has_s_array:
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
special_regs = []
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
if "in[" in combined:
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
code = code.replace("in[", "ins[")
else:
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
# Only create Regs for registers actually used in the pseudocode
reg_inits = []
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
# Build init code
init_parts = reg_inits.copy()
for name, init in special_regs:
if name in combined: init_parts.append(f"{name}={init}")
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=TypedView(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=TypedView(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
# Add init line and separator
if init_parts: lines.append(f" {'; '.join(init_parts)}")
# Add compiled pseudocode
for line in code.split('\n'):
if line.strip(): lines.append(f" {line}")
# Build result dict
result_items = []
if modifies_d0: result_items.append("'D0': D0._val")
if modifies_scc: result_items.append("'SCC': SCC._val")
if modifies_vcc: result_items.append("'VCC': VCC._val")
if modifies_exec: result_items.append("'EXEC': EXEC._val")
if has_d1: result_items.append("'D1': D1._val")
if modifies_pc: result_items.append("'PC': PC._val")
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'SDATA': SDATA._val")
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if is_flat:
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'VDATA': VDATA._val")
lines.append(f" return {{{', '.join(result_items)}}}")
return '\n'.join(lines)
# Build the globals dict for exec() - includes all pcode symbols
_PCODE_GLOBALS = {
'Reg': Reg, 'TypedView': TypedView, '_pack': _pack, '_pack32': _pack32,
'ABSDIFF': ABSDIFF, 'BYTE_PERMUTE': BYTE_PERMUTE, 'DENORM': DENORM, 'F': F,
'GT_NEG_ZERO': GT_NEG_ZERO, 'LT_NEG_ZERO': LT_NEG_ZERO, 'INF': INF,
'MAX_FLOAT_F32': MAX_FLOAT_F32, 'OVERFLOW_F32': OVERFLOW_F32, 'OVERFLOW_F64': OVERFLOW_F64,
'UNDERFLOW_F32': UNDERFLOW_F32, 'UNDERFLOW_F64': UNDERFLOW_F64,
'PI': PI, 'ROUND_MODE': ROUND_MODE, 'WAVE_MODE': WAVE_MODE,
'WAVE32': WAVE32, 'WAVE64': WAVE64, 'TWO_OVER_PI_1201': TWO_OVER_PI_1201,
'SAT8': SAT8, 'trunc': trunc, 'floor': floor, 'ceil': ceil, 'sqrt': sqrt,
'log2': log2, 'fract': fract, 'sin': sin, 'cos': cos, 'pow': pow,
'isEven': isEven, 'mantissa': mantissa, 'signext_from_bit': signext_from_bit,
'i32_to_f32': i32_to_f32, 'u32_to_f32': u32_to_f32, 'i32_to_f64': i32_to_f64,
'u32_to_f64': u32_to_f64, 'f32_to_f64': f32_to_f64, 'f64_to_f32': f64_to_f32,
'f32_to_i32': f32_to_i32, 'f32_to_u32': f32_to_u32, 'f64_to_i32': f64_to_i32,
'f64_to_u32': f64_to_u32, 'f32_to_f16': f32_to_f16, 'f16_to_f32': f16_to_f32,
'i16_to_f16': i16_to_f16, 'u16_to_f16': u16_to_f16, 'f16_to_i16': f16_to_i16,
'f16_to_u16': f16_to_u16, 'bf16_to_f32': bf16_to_f32, 'f32_to_bf16': f32_to_bf16,
'u8_to_u32': u8_to_u32, 'u4_to_u32': u4_to_u32, 'u32_to_u16': u32_to_u16,
'i32_to_i16': i32_to_i16, 'f16_to_snorm': f16_to_snorm, 'f16_to_unorm': f16_to_unorm,
'f32_to_snorm': f32_to_snorm, 'f32_to_unorm': f32_to_unorm,
'v_cvt_i16_f32': v_cvt_i16_f32, 'v_cvt_u16_f32': v_cvt_u16_f32, 'f32_to_u8': f32_to_u8,
'v_min_f32': v_min_f32, 'v_max_f32': v_max_f32, 'v_min_f16': v_min_f16, 'v_max_f16': v_max_f16,
'v_min_i32': v_min_i32, 'v_max_i32': v_max_i32, 'v_min_i16': v_min_i16, 'v_max_i16': v_max_i16,
'v_min_u32': v_min_u32, 'v_max_u32': v_max_u32, 'v_min_u16': v_min_u16, 'v_max_u16': v_max_u16,
'v_min3_f32': v_min3_f32, 'v_max3_f32': v_max3_f32, 'v_min3_f16': v_min3_f16, 'v_max3_f16': v_max3_f16,
'v_min3_i32': v_min3_i32, 'v_max3_i32': v_max3_i32, 'v_min3_i16': v_min3_i16, 'v_max3_i16': v_max3_i16,
'v_min3_u32': v_min3_u32, 'v_max3_u32': v_max3_u32, 'v_min3_u16': v_min3_u16, 'v_max3_u16': v_max3_u16,
'v_sad_u8': v_sad_u8, 'v_msad_u8': v_msad_u8,
's_ff1_i32_b32': s_ff1_i32_b32, 's_ff1_i32_b64': s_ff1_i32_b64,
'isNAN': isNAN, 'isQuietNAN': isQuietNAN, 'isSignalNAN': isSignalNAN,
'fma': fma, 'ldexp': ldexp, 'sign': sign, 'exponent': exponent,
'signext': signext, 'cvtToQuietNAN': cvtToQuietNAN,
}
@functools.cache
def compile_pseudocode(cls_name: str, op_name: str, pseudocode: str):
"""Compile pseudocode string to executable function. Cached for performance."""
code = _compile_pseudocode(pseudocode)
code = _apply_pseudocode_fixes(op_name, code)
fn_code = _generate_function(cls_name, op_name, pseudocode, code)
fn_name = f"_{cls_name}_{op_name}"
local_ns = {}
exec(fn_code, _PCODE_GLOBALS, local_ns)
return local_ns[fn_name]

View File

@@ -38,161 +38,7 @@ FLOAT_MAP = {'0.5': 'POS_HALF', '-0.5': 'NEG_HALF', '1.0': 'POS_ONE', '-1.0': 'N
'4.0': 'POS_FOUR', '-4.0': 'NEG_FOUR', '1/(2*PI)': 'INV_2PI', '0': 'ZERO'}
INST_PATTERN = re.compile(r'^([SVD]S?_[A-Z0-9_]+|(?:FLAT|GLOBAL|SCRATCH)_[A-Z0-9_]+)\s+(\d+)\s*$', re.M)
# Patterns that can't be handled by the DSL (require special handling in emu.py)
UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'thread_',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
'fp6', 'bf6', 'GS_REGS', 'M0.base', 'DS_DATA', '= 0..', 'sign(src', 'if no LDS', 'gds_base', 'vector mask',
'SGPR_ADDR', 'INST_OFFSET', 'laneID'] # FLAT ops with non-standard vars
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════
def compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
line = line.strip()
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
else:
joined_lines.append(line)
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.split('//')[0].strip() # Strip C-style comments
if not line: continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line == 'else':
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + "else:")
indent += 1
need_pass = True
elif line.startswith('endif'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass = False
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1])
lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line:
lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str:
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
def _expr(e: str) -> str:
e = e.strip()
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
e = re.sub(r'\bB\(', '(', e)
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
def convert_verilog_slice(m):
start, width = m.group(1).strip(), m.group(2).strip()
return f'[({start}) + ({width}) - 1 : ({start})]'
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
def process_brackets(s):
result, i = [], 0
while i < len(s):
if s[i] == '[':
depth, start = 1, i + 1
j = start
while j < len(s) and depth > 0:
if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1
j += 1
inner = _expr(s[start:j-1])
result.append('[' + inner + ']')
i = j
else:
result.append(s[i])
i += 1
return ''.join(result)
e = process_brackets(e)
while '?' in e:
depth, bracket, q = 0, 0, -1
for i, c in enumerate(e):
if c == '(': depth += 1
elif c == ')': depth -= 1
elif c == '[': bracket += 1
elif c == ']': bracket -= 1
elif c == '?' and depth == 0 and bracket == 0: q = i; break
if q < 0: break
depth, bracket, col = 0, 0, -1
for i in range(q + 1, len(e)):
if e[i] == '(': depth += 1
elif e[i] == ')': depth -= 1
elif e[i] == '[': bracket += 1
elif e[i] == ']': bracket -= 1
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
if col < 0: break
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
e = f'(({t}) if ({cond}) else ({f}))'
return e
# ═══════════════════════════════════════════════════════════════════════════════
# PDF PARSING WITH PAGE CACHING
@@ -472,8 +318,8 @@ def _generate_ins_py(formats, enums, src_enum, doc_name) -> str:
if "NULL" in src_names: lines.append("OFF = NULL\n")
return '\n'.join(lines)
def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
"""Generate gen_pcode.py content (compiled pseudocode functions)."""
def _generate_str_pcode_py(enums, pseudocode, arch) -> str:
"""Generate str_pcode.py content (raw pseudocode strings)."""
# Get op enums for this arch (import from .ins which re-exports from .enum)
import importlib
autogen = importlib.import_module(f"extra.assembly.amd.autogen.{arch}.ins")
@@ -491,186 +337,35 @@ def _generate_gen_pcode_py(enums, pseudocode, arch) -> str:
if key in defined_ops:
for enum_cls, enum_val in defined_ops[key]: instructions[enum_cls][enum_val] = pc
# First pass: generate all function code
fn_lines: list[str] = []
all_fn_entries: dict = {}
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
if not instructions.get(enum_cls): continue
fn_entries = []
for op, pc in instructions[enum_cls].items():
if any(p in pc for p in UNSUPPORTED): continue
try:
code = compile_pseudocode(pc)
code = _apply_pseudocode_fixes(op, code)
fn_name, fn_code = _generate_function(cls_name, op, pc, code)
fn_lines.append(fn_code)
fn_entries.append((op, fn_name))
except Exception as e: print(f" Warning: Failed to compile {op.name}: {e}")
if fn_entries:
all_fn_entries[enum_cls] = fn_entries
fn_lines.append(f'{cls_name}_FUNCTIONS = {{')
for op, fn_name in fn_entries: fn_lines.append(f" {cls_name}.{op.name}: {fn_name},")
fn_lines.append('}\n')
fn_lines.append('COMPILED_FUNCTIONS = {')
for enum_cls in OP_ENUMS:
if all_fn_entries.get(enum_cls): fn_lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_FUNCTIONS,')
fn_lines.append('}')
# Second pass: scan generated code for pcode imports
fn_code_str = '\n'.join(fn_lines)
import extra.assembly.amd.pcode as pcode_module
pcode_exports = [name for name in dir(pcode_module) if not name.startswith('_') or name.startswith('_') and not name.startswith('__')]
used_imports = sorted(name for name in pcode_exports if re.search(rf'\b{re.escape(name)}\b', fn_code_str))
# Build final output with explicit imports
# Build string dictionaries for each enum
lines = [f'''# autogenerated by pdf.py - do not edit
# to regenerate: python -m extra.assembly.amd.pdf --arch {arch}
# ruff: noqa: E501
# mypy: ignore-errors
from extra.assembly.amd.autogen.{arch}.enum import {", ".join(enum_names)}
from extra.assembly.amd.pcode import {", ".join(used_imports)}
'''] + fn_lines
''']
all_dict_entries: dict = {}
for enum_cls in OP_ENUMS:
cls_name = enum_cls.__name__
if not instructions.get(enum_cls): continue
dict_entries = [(op, repr(pc)) for op, pc in instructions[enum_cls].items()]
if dict_entries:
all_dict_entries[enum_cls] = dict_entries
lines.append(f'{cls_name}_PCODE = {{')
for op, escaped in dict_entries: lines.append(f" {cls_name}.{op.name}: {escaped},")
lines.append('}\n')
lines.append('PSEUDOCODE_STRINGS = {')
for enum_cls in OP_ENUMS:
if all_dict_entries.get(enum_cls): lines.append(f' {enum_cls.__name__}: {enum_cls.__name__}_PCODE,')
lines.append('}')
return '\n'.join(lines)
def _apply_pseudocode_fixes(op, code: str) -> str:
"""Apply known fixes for PDF pseudocode bugs."""
if op.name == 'V_DIV_FMAS_F32':
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op.name == 'V_DIV_FMAS_F64':
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
if op.name == 'V_DIV_SCALE_F32':
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
if op.name == 'V_DIV_SCALE_F64':
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
if op.name == 'V_DIV_FIXUP_F32':
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
if op.name == 'V_DIV_FIXUP_F64':
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
if op.name == 'V_TRIG_PREOP_F64':
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
return code
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
"""Generate a single compiled pseudocode function.
Functions take int parameters and return dict of int values.
Reg wrapping happens inside the function, only for registers actually used."""
has_d1 = '{ D1' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp'
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
is_smem = cls_name == 'SMEMOp'
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
combined = code + pc
fn_name = f"_{cls_name}_{op.name}"
# Detect which registers are used/modified
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# Build function signature and Reg init lines
if is_smem:
lines = [f"def {fn_name}(MEM, addr):"]
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
special_regs = []
elif is_ds:
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat:
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'VDATA')]
elif has_s_array:
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
special_regs = []
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
if "in[" in combined:
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
code = code.replace("in[", "ins[")
else:
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
# Only create Regs for registers actually used in the pseudocode
reg_inits = []
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
# Build init code
init_parts = reg_inits.copy()
for name, init in special_regs:
if name in combined: init_parts.append(f"{name}={init}")
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
# Add init line and separator
if init_parts: lines.append(f" {'; '.join(init_parts)}")
lines.append(" # --- compiled pseudocode ---")
# Add compiled pseudocode
for line in code.split('\n'):
if line.strip(): lines.append(f" {line}")
# Build result dict
result_items = []
if modifies_d0: result_items.append("'D0': D0._val")
if modifies_scc: result_items.append("'SCC': SCC._val")
if modifies_vcc: result_items.append("'VCC': VCC._val")
if modifies_exec: result_items.append("'EXEC': EXEC._val")
if has_d1: result_items.append("'D1': D1._val")
if modifies_pc: result_items.append("'PC': PC._val")
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'SDATA': SDATA._val")
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if is_flat:
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'VDATA': VDATA._val")
lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines)
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN GENERATION
# ═══════════════════════════════════════════════════════════════════════════════
def generate_arch(arch: str) -> dict:
"""Generate enum.py, ins.py and gen_pcode.py for a single architecture."""
"""Generate enum.py, ins.py and str_pcode.py for a single architecture."""
urls = PDF_URLS[arch]
if isinstance(urls, str): urls = [urls]
@@ -696,9 +391,9 @@ def generate_arch(arch: str) -> dict:
ins_path.write_text(ins_content)
print(f"Generated {ins_path}: {len(merged['formats'])} formats")
# Write gen_pcode.py (needs enum.py to exist first for imports)
pcode_path = base_path / "gen_pcode.py"
pcode_content = _generate_gen_pcode_py(merged["enums"], merged["pseudocode"], arch)
# Write str_pcode.py (needs enum.py to exist first for imports)
pcode_path = base_path / "str_pcode.py"
pcode_content = _generate_str_pcode_py(merged["enums"], merged["pseudocode"], arch)
pcode_path.write_text(pcode_content)
print(f"Generated {pcode_path}: {len(merged['pseudocode'])} instructions")

View File

@@ -30,8 +30,8 @@ def get_llvm_objdump():
class ExecContext:
"""Context for running compiled pseudocode in tests."""
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, SliceProxy
self._Reg, self._MASK64, self._SliceProxy = Reg, MASK64, SliceProxy
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, TypedView
self._Reg, self._MASK64, self._TypedView = Reg, MASK64, TypedView
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
self.D0, self.D1 = Reg(d0), Reg(0)
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
@@ -51,7 +51,7 @@ class ExecContext:
ns.update({
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
'EXEC_LO': self._SliceProxy(self.EXEC, 31, 0), 'EXEC_HI': self._SliceProxy(self.EXEC, 63, 32),
'EXEC_LO': self._TypedView(self.EXEC, 31, 0), 'EXEC_HI': self._TypedView(self.EXEC, 63, 32),
'tmp': self.tmp, 'saveexec': self.saveexec,
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,

View File

@@ -47,8 +47,7 @@ dev.synchronize()
elapsed = time.perf_counter() - st
self.assertNotEqual(result.returncode, 0, "should have raised")
self.assertTrue("NotImplementedError" in result.stderr or "ValueError" in result.stderr,
f"expected NotImplementedError or ValueError in stderr")
self.assertTrue("Error" in result.stderr, f"expected an error in stderr, got: {result.stderr[:500]}")
# Should exit immediately, not wait for the full timeout
self.assertLess(elapsed, 9.0, f"should exit immediately on emulator exception, took {elapsed:.1f}s")

View File

@@ -1,12 +1,16 @@
#!/usr/bin/env python3
"""Tests for the RDNA3 pseudocode DSL."""
import unittest
from extra.assembly.amd.pcode import (Reg, TypedView, SliceProxy, MASK32, MASK64,
_f32, _i32, _f16, _i16, f32_to_f16, _isnan, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
BYTE_PERMUTE, v_sad_u8, v_msad_u8)
from extra.assembly.amd.pdf import compile_pseudocode, _expr
from extra.assembly.amd.pcode import (Reg, TypedView, TypedView, MASK32, MASK64,
_f32, _i32, _f16, _i16, f32_to_f16, isNAN, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
BYTE_PERMUTE, v_sad_u8, v_msad_u8, _compile_pseudocode, _expr, compile_pseudocode)
from extra.assembly.amd.test.helpers import ExecContext
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP3SDOp_V_DIV_SCALE_F32, _VOPCOp_V_CMP_CLASS_F32
from extra.assembly.amd.autogen.rdna3.str_pcode import VOP3SDOp_PCODE, VOPCOp_PCODE
from extra.assembly.amd.autogen.rdna3.enum import VOP3SDOp, VOPCOp
# Compile pseudocode functions on demand for regression tests
_VOP3SDOp_V_DIV_SCALE_F32 = compile_pseudocode('VOP3SDOp', 'V_DIV_SCALE_F32', VOP3SDOp_PCODE[VOP3SDOp.V_DIV_SCALE_F32])
_VOPCOp_V_CMP_CLASS_F32 = compile_pseudocode('VOPCOp', 'V_CMP_CLASS_F32', VOPCOp_PCODE[VOPCOp.V_CMP_CLASS_F32])
class TestReg(unittest.TestCase):
def test_u32_read(self):
@@ -42,7 +46,7 @@ class TestReg(unittest.TestCase):
class TestTypedView(unittest.TestCase):
def test_bit_slice(self):
r = Reg(0xDEADBEEF)
# Slices return SliceProxy which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
# Slices return TypedView which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
self.assertEqual(r.u32[7:0].u32, 0xEF)
self.assertEqual(r.u32[15:8].u32, 0xBE)
self.assertEqual(r.u32[23:16].u32, 0xAD)
@@ -67,7 +71,7 @@ class TestTypedView(unittest.TestCase):
# S0.u32[S1.u32[4:0]] - access bit at position from another register
s0 = Reg(0b11010101)
s1 = Reg(3)
bit_pos = s1.u32[4:0] # SliceProxy, int value = 3
bit_pos = s1.u32[4:0] # TypedView, int value = 3
bit_val = s0.u32[int(bit_pos)] # bit 3 of s0 = 0
self.assertEqual(int(bit_pos), 3)
self.assertEqual(bit_val, 0)
@@ -85,7 +89,7 @@ class TestTypedView(unittest.TestCase):
self.assertFalse(r1.u32 < r2.u32)
self.assertTrue(r1.u32 != r2.u32)
class TestSliceProxy(unittest.TestCase):
class TestTypedView(unittest.TestCase):
def test_slice_read(self):
r = Reg(0x56781234)
self.assertEqual(r[15:0].u16, 0x1234)
@@ -154,19 +158,19 @@ class TestExecContext(unittest.TestCase):
self.assertEqual(ctx.SCC._val, 0)
def test_ternary(self):
code = compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
code = _compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
ctx = ExecContext(s0=5, s1=3)
ctx.run(code)
self.assertEqual(ctx.D0._val, 1)
def test_pack(self):
code = compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
code = _compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
ctx = ExecContext(s0=0x1234, s1=0x5678)
ctx.run(code)
self.assertEqual(ctx.D0._val, 0x56781234)
def test_tmp_with_typed_access(self):
code = compile_pseudocode("""tmp = S0.u32 + S1.u32
code = _compile_pseudocode("""tmp = S0.u32 + S1.u32
D0.u32 = tmp.u32""")
ctx = ExecContext(s0=100, s1=200)
ctx.run(code)
@@ -174,7 +178,7 @@ D0.u32 = tmp.u32""")
def test_s_add_u32_pattern(self):
# Real pseudocode pattern from S_ADD_U32
code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
# Test overflow case
@@ -184,7 +188,7 @@ D0.u32 = tmp.u32""")
self.assertEqual(ctx.SCC._val, 1) # Carry set
def test_s_add_u32_no_overflow(self):
code = compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
ctx = ExecContext(s0=100, s1=200)
@@ -206,7 +210,7 @@ D0.u32 = tmp.u32""")
def test_for_loop(self):
# CTZ pattern - find first set bit
code = compile_pseudocode("""tmp = -1
code = _compile_pseudocode("""tmp = -1
for i in 0 : 31 do
if S0.u32[i] == 1 then
tmp = i
@@ -261,15 +265,15 @@ class TestPseudocodeRegressions(unittest.TestCase):
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
def test_isnan_with_typed_view(self):
"""_isnan must work with TypedView objects, not just Python floats.
Bug: _isnan checked isinstance(x, float) which returned False for TypedView."""
def testisNAN_with_typed_view(self):
"""isNAN must work with TypedView objects, not just Python floats.
Bug: isNAN checked isinstance(x, float) which returned False for TypedView."""
nan_reg = Reg(0x7fc00000) # quiet NaN
normal_reg = Reg(0x3f800000) # 1.0
inf_reg = Reg(0x7f800000) # +inf
self.assertTrue(_isnan(nan_reg.f32), "_isnan should return True for NaN TypedView")
self.assertFalse(_isnan(normal_reg.f32), "_isnan should return False for normal TypedView")
self.assertFalse(_isnan(inf_reg.f32), "_isnan should return False for inf TypedView")
self.assertTrue(isNAN(nan_reg.f32), "isNAN should return True for NaN TypedView")
self.assertFalse(isNAN(normal_reg.f32), "isNAN should return False for normal TypedView")
self.assertFalse(isNAN(inf_reg.f32), "isNAN should return False for inf TypedView")
class TestBF16(unittest.TestCase):
"""Tests for BF16 (bfloat16) support."""
@@ -308,7 +312,7 @@ class TestBF16(unittest.TestCase):
self.assertAlmostEqual(float(r.bf16), 3.0, places=1)
def test_bf16_slice_property(self):
"""Test SliceProxy.bf16 property."""
"""Test TypedView.bf16 property."""
r = Reg(0x40404040) # Two bf16 3.0 values
self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1)
self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1)