This commit is contained in:
George Hotz
2026-01-04 18:35:16 -08:00
parent 8147a78d24
commit 57684d2777
10 changed files with 198 additions and 1257 deletions

View File

@@ -43,6 +43,14 @@ def _i64(f):
try: return _struct_Q.unpack(_struct_d.pack(f))[0]
except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
# Float conversion helpers (used by tests)
def f32_to_f16(f):
f = float(f)
if math.isnan(f): return 0x7e00 # f16 NaN
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
try: return _struct_H.unpack(_struct_e.pack(f))[0]
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
# Instruction spec - register counts and dtypes derived from instruction names
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,

View File

@@ -5,7 +5,6 @@ import ctypes, functools
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.pcode import compile_pseudocode
from extra.assembly.amd.ucode import compile_uop
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
@@ -381,7 +380,7 @@ def decode_program(data: bytes) -> dict[int, Inst]:
# Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches)
# Try ucode first (UOp-based), fall back to pcode (Python exec-based)
def _compile_op(cls_name, op_name, pcode):
return compile_uop(op_name, pcode) or compile_pseudocode(cls_name, op_name, pcode)
return compile_uop(op_name, pcode) #or compile_pseudocode(cls_name, op_name, pcode)
# VOPD needs separate functions for X and Y ops
if isinstance(inst, VOPD):
def _compile_vopd_op(op): return _compile_op(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op])

View File

@@ -1,765 +0,0 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math, re, functools
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
# ═══════════════════════════════════════════════════════════════════════════════
# INTERNAL HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def _div(a, b):
try: return a / b
except ZeroDivisionError:
if a == 0.0 or math.isnan(a): return float("nan")
return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
def _check_nan_type(x, quiet_bit_expected, default):
try:
if not math.isnan(float(x)): return False
if hasattr(x, '_reg') and hasattr(x, '_bits'):
bits = x._reg._val & ((1 << x._bits) - 1)
exp_bits, quiet_pos, mant_mask = {16: (0x1f, 9, 0x3ff), 32: (0xff, 22, 0x7fffff), 64: (0x7ff, 51, 0xfffffffffffff)}.get(x._bits, (0,0,0))
exp_shift = {16: 10, 32: 23, 64: 52}.get(x._bits, 0)
if exp_bits and ((bits >> exp_shift) & exp_bits) == exp_bits and (bits & mant_mask) != 0:
return ((bits >> quiet_pos) & 1) == quiet_bit_expected
return default
except (TypeError, ValueError): return False
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fpop(fn):
def wrapper(x):
x = float(x)
if math.isnan(x) or math.isinf(x): return x
result = float(fn(x))
return math.copysign(0.0, x) if result == 0.0 else result
return wrapper
def _f_to_int(f, lo, hi): f = float(f); return 0 if math.isnan(f) else (hi if f >= hi else lo if f <= lo else int(f))
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
def _brev(v, bits): return int(bin(v & ((1 << bits) - 1))[2:].zfill(bits)[::-1], 2)
def _ctz(v, bits):
v, n = int(v) & ((1 << bits) - 1), 0
if v == 0: return bits
while (v & 1) == 0: v >>= 1; n += 1
return n
def _bf16(i):
"""Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
def _ibf16(f):
"""Convert float to bf16 bits (truncate to top 16 bits of f32)."""
if math.isnan(f): return 0x7fc0 # bf16 quiet NaN
if math.isinf(f): return 0x7f80 if f > 0 else 0xff80 # bf16 ±infinity
try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
def _trig(fn, x):
# V_SIN/COS_F32: hardware does frac on input cycles before computing
if math.isinf(x) or math.isnan(x): return float("nan")
frac_cycles = fract(x / (2 * math.pi))
result = fn(frac_cycles * 2 * math.pi)
# Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
# Round very small results (below f32 precision) to exactly 0
if abs(result) < 1e-7: return 0.0
return result
class _SafeFloat(float):
"""Float subclass that uses _div for division to handle 0/inf correctly."""
def __truediv__(self, o): return _div(float(self), float(o))
def __rtruediv__(self, o): return _div(float(o), float(self))
class _Inf:
f16 = f32 = f64 = float('inf')
def __neg__(self): return _NegInf()
def __pos__(self): return self
def __float__(self): return float('inf')
def __eq__(self, other): return float(other) == float('inf') if not isinstance(other, _NegInf) else False
def __req__(self, other): return self.__eq__(other)
class _NegInf:
f16 = f32 = f64 = float('-inf')
def __neg__(self): return _Inf()
def __pos__(self): return self
def __float__(self): return float('-inf')
def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
def __req__(self, other): return self.__eq__(other)
class _RoundMode:
NEAREST_EVEN = 0
class _WaveMode:
IEEE = False
class _DenormChecker:
"""Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
def __init__(self, bits): self._bits = bits
def _check(self, other):
f = float(other)
if math.isinf(f) or math.isnan(f) or f == 0.0: return False
if self._bits == 64:
bits = struct.unpack("<Q", struct.pack("<d", f))[0]
return (bits >> 52) & 0x7ff == 0
bits = struct.unpack("<I", struct.pack("<f", f))[0]
return (bits >> 23) & 0xff == 0
def __eq__(self, other): return self._check(other)
def __req__(self, other): return self._check(other)
def __ne__(self, other): return not self._check(other)
class _Denorm:
f32 = _DenormChecker(32)
f64 = _DenormChecker(64)
_pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
_pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)
class TypedView:
"""View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
__slots__ = ('_reg', '_high', '_low', '_signed', '_float', '_bf16', '_reversed')
def __init__(self, reg, high, low=0, signed=False, is_float=False, is_bf16=False):
# Handle reversed slices like [0:31] which means bit-reverse
if high < low: high, low, reversed = low, high, True
else: reversed = False
self._reg, self._high, self._low, self._reversed = reg, high, low, reversed
self._signed, self._float, self._bf16 = signed, is_float, is_bf16
def _nbits(self): return self._high - self._low + 1
def _mask(self): return (1 << self._nbits()) - 1
def _get(self):
v = (self._reg._val >> self._low) & self._mask()
return _brev(v, self._nbits()) if self._reversed else v
def _set(self, v):
v = int(v)
if self._reversed: v = _brev(v, self._nbits())
self._reg._val = (self._reg._val & ~(self._mask() << self._low)) | ((v & self._mask()) << self._low)
@property
def _val(self): return self._get()
@property
def _bits(self): return self._nbits()
# Type accessors for slices (e.g., D0[31:16].f16)
u8 = property(lambda s: s._get() & 0xff)
u16 = property(lambda s: s._get() & 0xffff, lambda s, v: s._set(v))
u32 = property(lambda s: s._get() & MASK32, lambda s, v: s._set(v))
i16 = property(lambda s: _sext(s._get() & 0xffff, 16), lambda s, v: s._set(v))
i32 = property(lambda s: _sext(s._get() & MASK32, 32), lambda s, v: s._set(v))
f16 = property(lambda s: _f16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _i16(float(v))))
f32 = property(lambda s: _f32(s._get()), lambda s, v: s._set(_i32(float(v))))
bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v))))
b16, b32 = u16, u32
# Chained type access (e.g., jump_addr.i64 when jump_addr is already TypedView)
@property
def i64(s): return s if s._nbits() == 64 and s._signed else int(s)
@property
def u64(s): return s if s._nbits() == 64 and not s._signed else int(s) & MASK64
def __getitem__(self, key):
if isinstance(key, slice):
high, low = int(key.start), int(key.stop)
return TypedView(self._reg, high, low)
return (self._get() >> int(key)) & 1
def __setitem__(self, key, value):
if isinstance(key, slice):
high, low = int(key.start), int(key.stop)
if high < low: high, low, value = low, high, _brev(int(value), low - high + 1)
mask = (1 << (high - low + 1)) - 1
self._reg._val = (self._reg._val & ~(mask << low)) | ((int(value) & mask) << low)
elif value: self._reg._val |= (1 << int(key))
else: self._reg._val &= ~(1 << int(key))
def __int__(self): return _sext(self._get(), self._nbits()) if self._signed else self._get()
def __index__(self): return int(self)
def __trunc__(self): return int(float(self)) if self._float else int(self)
def __float__(self):
if self._float:
if self._bf16: return _bf16(self._get())
bits = self._nbits()
return _f16(self._get()) if bits == 16 else _f32(self._get()) if bits == 32 else _f64(self._get())
return float(int(self))
def __bool__(s): return bool(int(s))
# Arithmetic - floats use float(), ints use int()
def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o)
def __radd__(s, o): return float(o) + float(s) if s._float else int(o) + int(s)
def __sub__(s, o): return float(s) - float(o) if s._float else int(s) - int(o)
def __rsub__(s, o): return float(o) - float(s) if s._float else int(o) - int(s)
def __mul__(s, o): return float(s) * float(o) if s._float else int(s) * int(o)
def __rmul__(s, o): return float(o) * float(s) if s._float else int(o) * int(s)
def __truediv__(s, o): return _div(float(s), float(o)) if s._float else _div(int(s), int(o))
def __rtruediv__(s, o): return _div(float(o), float(s)) if s._float else _div(int(o), int(s))
def __pow__(s, o): return float(s) ** float(o) if s._float else int(s) ** int(o)
def __rpow__(s, o): return float(o) ** float(s) if s._float else int(o) ** int(s)
def __neg__(s): return -float(s) if s._float else -int(s)
def __abs__(s): return abs(float(s)) if s._float else abs(int(s))
# Bitwise - GPU shifts mask the shift amount to valid range
def __and__(s, o): return int(s) & int(o)
def __or__(s, o): return int(s) | int(o)
def __xor__(s, o): return int(s) ^ int(o)
def __invert__(s): return ~int(s)
def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 or s._nbits() > 64 else 0
def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 or s._nbits() > 64 else 0
def __rand__(s, o): return int(o) & int(s)
def __ror__(s, o): return int(o) | int(s)
def __rxor__(s, o): return int(o) ^ int(s)
def __rlshift__(s, o): n = int(s); return int(o) << n if 0 <= n < 64 else 0
def __rrshift__(s, o): n = int(s); return int(o) >> n if 0 <= n < 64 else 0
# Comparison - handle _DenormChecker specially
def __eq__(s, o):
if isinstance(o, _DenormChecker): return o._check(s)
return float(s) == float(o) if s._float else int(s) == int(o)
def __ne__(s, o):
if isinstance(o, _DenormChecker): return not o._check(s)
return float(s) != float(o) if s._float else int(s) != int(o)
def __lt__(s, o): return float(s) < float(o) if s._float else int(s) < int(o)
def __le__(s, o): return float(s) <= float(o) if s._float else int(s) <= int(o)
def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
class Reg:
"""GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
__slots__ = ('_val',)
def __init__(self, val=0): self._val = int(val)
# Typed views - TypedView(reg, high, signed, is_float, is_bf16)
u64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
i64 = property(lambda s: TypedView(s, 63, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
b64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
f64 = property(lambda s: TypedView(s, 63, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v))))
u32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
i32 = property(lambda s: TypedView(s, 31, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
b32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
f32 = property(lambda s: TypedView(s, 31, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v))))
u24 = property(lambda s: TypedView(s, 23))
i24 = property(lambda s: TypedView(s, 23, signed=True))
u16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
i16 = property(lambda s: TypedView(s, 15, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
b16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
f16 = property(lambda s: TypedView(s, 15, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
bf16 = property(lambda s: TypedView(s, 15, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
u8 = property(lambda s: TypedView(s, 7))
i8 = property(lambda s: TypedView(s, 7, signed=True))
u3 = property(lambda s: TypedView(s, 2)) # 3-bit for opsel fields
u1 = property(lambda s: TypedView(s, 0)) # single bit
def __getitem__(s, key):
if isinstance(key, slice): return TypedView(s, int(key.start), int(key.stop))
return (s._val >> int(key)) & 1
def __setitem__(s, key, value):
if isinstance(key, slice):
high, low = int(key.start), int(key.stop)
if high < low: high, low = low, high
mask = (1 << (high - low + 1)) - 1
s._val = (s._val & ~(mask << low)) | ((int(value) & mask) << low)
elif value: s._val |= (1 << int(key))
else: s._val &= ~(1 << int(key))
def __int__(s): return s._val
def __index__(s): return s._val
def __bool__(s): return bool(s._val)
# Arithmetic (for tmp = tmp + 1 patterns). Float operands trigger f32 interpretation.
def __add__(s, o): return (_f32(s._val) + float(o)) if isinstance(o, float) else s._val + int(o)
def __radd__(s, o): return (float(o) + _f32(s._val)) if isinstance(o, float) else int(o) + s._val
def __sub__(s, o): return (_f32(s._val) - float(o)) if isinstance(o, float) else s._val - int(o)
def __rsub__(s, o): return (float(o) - _f32(s._val)) if isinstance(o, float) else int(o) - s._val
def __mul__(s, o): return (_f32(s._val) * float(o)) if isinstance(o, float) else s._val * int(o)
def __rmul__(s, o): return (float(o) * _f32(s._val)) if isinstance(o, float) else int(o) * s._val
def __and__(s, o): return s._val & int(o)
def __rand__(s, o): return int(o) & s._val
def __or__(s, o): return s._val | int(o)
def __ror__(s, o): return int(o) | s._val
def __xor__(s, o): return s._val ^ int(o)
def __rxor__(s, o): return int(o) ^ s._val
def __lshift__(s, o): n = int(o); return s._val << n if 0 <= n < 64 else 0
def __rshift__(s, o): n = int(o); return s._val >> n if 0 <= n < 64 else 0
def __invert__(s): return ~s._val
# Comparison (for tmp >= 0x100000000 patterns)
def __lt__(s, o): return s._val < int(o)
def __le__(s, o): return s._val <= int(o)
def __gt__(s, o): return s._val > int(o)
def __ge__(s, o): return s._val >= int(o)
def __eq__(s, o): return s._val == int(o)
def __ne__(s, o): return s._val != int(o)
# ═══════════════════════════════════════════════════════════════════════════════
# PSEUDOCODE API - Functions and constants from AMD ISA pseudocode
# ═══════════════════════════════════════════════════════════════════════════════
# Rounding and float operations
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
def fract(x): return x - math.floor(x)
def sin(x): return _trig(math.sin, x)
def cos(x): return _trig(math.cos, x)
def pow(a, b):
try: return a ** b
except OverflowError: return float("inf") if b > 0 else 0.0
def isEven(x):
x = float(x)
if math.isinf(x) or math.isnan(x): return False
return int(x) % 2 == 0
def mantissa(f):
if f == 0.0 or math.isinf(f) or math.isnan(f): return f
m, _ = math.frexp(f)
return m # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
def signext_from_bit(val, bit):
bit = int(bit)
if bit == 0: return 0
mask = (1 << bit) - 1
val = int(val) & mask
if val & (1 << (bit - 1)): return val - (1 << bit)
return val
# Type conversions
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
def f32_to_f16(f):
f = float(f)
if math.isnan(f): return 0x7e00 # f16 NaN
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00 # f16 ±infinity
try: return struct.unpack("<H", struct.pack("<e", f))[0]
except OverflowError: return 0x7c00 if f > 0 else 0xfc00 # overflow -> ±infinity
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
def f32_to_bf16(f): return _ibf16(f)
def u8_to_u32(v): return int(v) & 0xff
def u4_to_u32(v): return int(v) & 0xf
def u32_to_u16(u): return int(u) & 0xffff
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def SAT8(v): return max(0, min(255, int(v)))
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
# Min/max operations
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
v_min_i32, v_max_i32 = min, max
v_min_i16, v_max_i16 = min, max
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)
# SAD/MSAD operations
def ABSDIFF(a, b): return abs(int(a) - int(b))
def v_sad_u8(s0, s1, s2):
"""V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
s0, s1, s2 = int(s0), int(s1), int(s2)
result = s2
for i in range(4):
a = (s0 >> (i * 8)) & 0xff
b = (s1 >> (i * 8)) & 0xff
result += abs(a - b)
return result & 0xffffffff
def v_msad_u8(s0, s1, s2):
"""V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
s0, s1, s2 = int(s0), int(s1), int(s2)
result = s2
for i in range(4):
a = (s0 >> (i * 8)) & 0xff
b = (s1 >> (i * 8)) & 0xff
if b != 0: # Only add diff if reference (s1) byte is non-zero
result += abs(a - b)
return result & 0xffffffff
def BYTE_PERMUTE(data, sel):
"""Select a byte from 64-bit data based on selector value."""
sel = int(sel) & 0xff
if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00
if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00
if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00
if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00
if sel == 12: return 0x00
return 0xff
# Pseudocode functions
def s_ff1_i32_b32(v): return _ctz(v, 32)
def s_ff1_i32_b64(v): return _ctz(v, 64)
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
def isNAN(x):
try: return math.isnan(float(x))
except (TypeError, ValueError): return False
def isQuietNAN(x): return _check_nan_type(x, 1, True)
def isSignalNAN(x): return _check_nan_type(x, 0, False)
def fma(a, b, c):
try: return math.fma(a, b, c)
except ValueError: return float('nan')
def ldexp(m, e): return math.ldexp(m, e)
def sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def exponent(f):
if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
raw = f._val
if f._bits == 16: return (raw >> 10) & 0x1f
if f._bits == 32: return (raw >> 23) & 0xff
if f._bits == 64: return (raw >> 52) & 0x7ff
f = float(f)
if math.isinf(f) or math.isnan(f): return 255
if f == 0.0: return 0
try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
except: return 0
def signext(x): return int(x)
def cvtToQuietNAN(x): return float('nan')
def F(x):
"""32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
if isinstance(x, int): return _f32(x)
if isinstance(x, TypedView): return x
return float(x)
# Constants
PI = math.pi
WAVE32, WAVE64 = True, False
OVERFLOW_F32, UNDERFLOW_F32 = float('inf'), 0.0
OVERFLOW_F64, UNDERFLOW_F64 = float('inf'), 0.0
MAX_FLOAT_F32 = 3.4028235e+38
INF = _Inf()
ROUND_MODE = _RoundMode()
WAVE_MODE = _WaveMode()
DENORM = _Denorm()
# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════
def _compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
line = line.strip()
if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
(joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
else:
joined_lines.append(line)
lines = []
indent, need_pass, in_first_match_loop = 0, False, False
for line in joined_lines:
line = line.split('//')[0].strip() # Strip C-style comments
if not line: continue
if line.startswith('if '):
lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line.startswith('elsif '):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
indent += 1
need_pass = True
elif line == 'else':
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
lines.append(' ' * indent + "else:")
indent += 1
need_pass = True
elif line.startswith('endif'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass = False
elif line.startswith('endfor'):
if need_pass: lines.append(' ' * indent + "pass")
indent -= 1
need_pass, in_first_match_loop = False, False
elif line.startswith('declare '):
pass
elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
start, end = _expr(m[2].strip()), _expr(m[3].strip())
lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
indent += 1
need_pass, in_first_match_loop = True, True
elif '=' in line and not line.startswith('=='):
need_pass = False
line = line.rstrip(';')
if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
rhs = _expr(m[1])
lines.append(' ' * indent + f"_full = {rhs}")
lines.append(' ' * indent + f"D0.u64 = int(_full) & 0xffffffffffffffff")
lines.append(' ' * indent + f"D1 = Reg((int(_full) >> 64) & 1)")
elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
if op in line:
lhs, rhs = line.split(op, 1)
lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
lines.append(' ' * indent + stmt)
if need_pass: lines.append(' ' * indent + "pass")
return '\n'.join(lines)
def _assign(lhs: str, rhs: str) -> str:
if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
return f"{lhs} = Reg({rhs})"
return f"{lhs} = {rhs}"
def _expr(e: str) -> str:
e = e.strip()
e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
e = re.sub(r'!([^=])', r' not \1', e)
e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
def pack(m):
hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
return f'_pack({hi}, {lo})'
e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
e = re.sub(r"\d+'[FIBU]\(", "(", e)
e = re.sub(r'\bB\(', '(', e)
e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
def convert_verilog_slice(m):
start, width = m.group(1).strip(), m.group(2).strip()
return f'[({start}) + ({width}) - 1 : ({start})]'
e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
def process_brackets(s):
result, i = [], 0
while i < len(s):
if s[i] == '[':
depth, start = 1, i + 1
j = start
while j < len(s) and depth > 0:
if s[j] == '[': depth += 1
elif s[j] == ']': depth -= 1
j += 1
inner = _expr(s[start:j-1])
result.append('[' + inner + ']')
i = j
else:
result.append(s[i])
i += 1
return ''.join(result)
e = process_brackets(e)
while '?' in e:
depth, bracket, q = 0, 0, -1
for i, c in enumerate(e):
if c == '(': depth += 1
elif c == ')': depth -= 1
elif c == '[': bracket += 1
elif c == ']': bracket -= 1
elif c == '?' and depth == 0 and bracket == 0: q = i; break
if q < 0: break
depth, bracket, col = 0, 0, -1
for i in range(q + 1, len(e)):
if e[i] == '(': depth += 1
elif e[i] == ')': depth -= 1
elif e[i] == '[': bracket += 1
elif e[i] == ']': bracket -= 1
elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
if col < 0: break
cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
e = f'(({t}) if ({cond}) else ({f}))'
return e
def _apply_pseudocode_fixes(op_name: str, code: str) -> str:
"""Apply known fixes for PDF pseudocode bugs."""
if op_name == 'V_DIV_FMAS_F32':
code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
if op_name == 'V_DIV_FMAS_F64':
code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
if op_name == 'V_DIV_SCALE_F32':
code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
if op_name == 'V_DIV_SCALE_F64':
code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
if op_name == 'V_DIV_FIXUP_F32':
code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
if op_name == 'V_DIV_FIXUP_F64':
code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
if op_name == 'V_TRIG_PREOP_F64':
code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
return code
def _generate_function(cls_name: str, op_name: str, pc: str, code: str) -> str:
"""Generate a single compiled pseudocode function.
Functions take int parameters and return dict of int values.
Reg wrapping happens inside the function, only for registers actually used."""
has_d1 = '{ D1' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op_name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
is_ds = cls_name == 'DSOp'
is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
is_smem = cls_name == 'SMEMOp'
has_s_array = 'S[i]' in pc # FMA_MIX style: S[0], S[1], S[2] array access
combined = code + pc
fn_name = f"_{cls_name}_{op_name}"
# Detect which registers are used/modified
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# Build function signature and Reg init lines
if is_smem:
lines = [f"def {fn_name}(MEM, addr):"]
reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
special_regs = []
elif is_ds:
lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
elif is_flat:
lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
special_regs = [('DATA', 'VDATA')]
elif has_s_array:
# FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
special_regs = []
# Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
if "in[" in combined:
reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
code = code.replace("in[", "ins[")
else:
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
# Only create Regs for registers actually used in the pseudocode
reg_inits = []
if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
# Build init code
init_parts = reg_inits.copy()
for name, init in special_regs:
if name in combined: init_parts.append(f"{name}={init}")
if 'EXEC_LO' in code: init_parts.append("EXEC_LO=TypedView(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_parts.append("EXEC_HI=TypedView(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")
# Add init line and separator
if init_parts: lines.append(f" {'; '.join(init_parts)}")
# Add compiled pseudocode
for line in code.split('\n'):
if line.strip(): lines.append(f" {line}")
# Build result dict
result_items = []
if modifies_d0: result_items.append("'D0': D0._val")
if modifies_scc: result_items.append("'SCC': SCC._val")
if modifies_vcc: result_items.append("'VCC': VCC._val")
if modifies_exec: result_items.append("'EXEC': EXEC._val")
if has_d1: result_items.append("'D1': D1._val")
if modifies_pc: result_items.append("'PC': PC._val")
if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'SDATA': SDATA._val")
if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if is_flat:
if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'RETURN_DATA': RETURN_DATA._val")
if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
result_items.append("'VDATA': VDATA._val")
lines.append(f" return {{{', '.join(result_items)}}}")
return '\n'.join(lines)
# Build the globals dict for exec() - includes all pcode symbols
_PCODE_GLOBALS = {
'Reg': Reg, 'TypedView': TypedView, '_pack': _pack, '_pack32': _pack32,
'ABSDIFF': ABSDIFF, 'BYTE_PERMUTE': BYTE_PERMUTE, 'DENORM': DENORM, 'F': F,
'GT_NEG_ZERO': GT_NEG_ZERO, 'LT_NEG_ZERO': LT_NEG_ZERO, 'INF': INF,
'MAX_FLOAT_F32': MAX_FLOAT_F32, 'OVERFLOW_F32': OVERFLOW_F32, 'OVERFLOW_F64': OVERFLOW_F64,
'UNDERFLOW_F32': UNDERFLOW_F32, 'UNDERFLOW_F64': UNDERFLOW_F64,
'PI': PI, 'ROUND_MODE': ROUND_MODE, 'WAVE_MODE': WAVE_MODE,
'WAVE32': WAVE32, 'WAVE64': WAVE64, 'TWO_OVER_PI_1201': TWO_OVER_PI_1201,
'SAT8': SAT8, 'trunc': trunc, 'floor': floor, 'ceil': ceil, 'sqrt': sqrt,
'log2': log2, 'fract': fract, 'sin': sin, 'cos': cos, 'pow': pow,
'isEven': isEven, 'mantissa': mantissa, 'signext_from_bit': signext_from_bit,
'i32_to_f32': i32_to_f32, 'u32_to_f32': u32_to_f32, 'i32_to_f64': i32_to_f64,
'u32_to_f64': u32_to_f64, 'f32_to_f64': f32_to_f64, 'f64_to_f32': f64_to_f32,
'f32_to_i32': f32_to_i32, 'f32_to_u32': f32_to_u32, 'f64_to_i32': f64_to_i32,
'f64_to_u32': f64_to_u32, 'f32_to_f16': f32_to_f16, 'f16_to_f32': f16_to_f32,
'i16_to_f16': i16_to_f16, 'u16_to_f16': u16_to_f16, 'f16_to_i16': f16_to_i16,
'f16_to_u16': f16_to_u16, 'bf16_to_f32': bf16_to_f32, 'f32_to_bf16': f32_to_bf16,
'u8_to_u32': u8_to_u32, 'u4_to_u32': u4_to_u32, 'u32_to_u16': u32_to_u16,
'i32_to_i16': i32_to_i16, 'f16_to_snorm': f16_to_snorm, 'f16_to_unorm': f16_to_unorm,
'f32_to_snorm': f32_to_snorm, 'f32_to_unorm': f32_to_unorm,
'v_cvt_i16_f32': v_cvt_i16_f32, 'v_cvt_u16_f32': v_cvt_u16_f32, 'f32_to_u8': f32_to_u8,
'v_min_f32': v_min_f32, 'v_max_f32': v_max_f32, 'v_min_f16': v_min_f16, 'v_max_f16': v_max_f16,
'v_min_i32': v_min_i32, 'v_max_i32': v_max_i32, 'v_min_i16': v_min_i16, 'v_max_i16': v_max_i16,
'v_min_u32': v_min_u32, 'v_max_u32': v_max_u32, 'v_min_u16': v_min_u16, 'v_max_u16': v_max_u16,
'v_min3_f32': v_min3_f32, 'v_max3_f32': v_max3_f32, 'v_min3_f16': v_min3_f16, 'v_max3_f16': v_max3_f16,
'v_min3_i32': v_min3_i32, 'v_max3_i32': v_max3_i32, 'v_min3_i16': v_min3_i16, 'v_max3_i16': v_max3_i16,
'v_min3_u32': v_min3_u32, 'v_max3_u32': v_max3_u32, 'v_min3_u16': v_min3_u16, 'v_max3_u16': v_max3_u16,
'v_sad_u8': v_sad_u8, 'v_msad_u8': v_msad_u8,
's_ff1_i32_b32': s_ff1_i32_b32, 's_ff1_i32_b64': s_ff1_i32_b64,
'isNAN': isNAN, 'isQuietNAN': isQuietNAN, 'isSignalNAN': isSignalNAN,
'fma': fma, 'ldexp': ldexp, 'sign': sign, 'exponent': exponent,
'signext': signext, 'cvtToQuietNAN': cvtToQuietNAN,
}
@functools.cache
def compile_pseudocode(cls_name: str, op_name: str, pseudocode: str):
"""Compile pseudocode string to executable function. Cached for performance."""
code = _compile_pseudocode(pseudocode)
code = _apply_pseudocode_fixes(op_name, code)
fn_code = _generate_function(cls_name, op_name, pseudocode, code)
fn_name = f"_{cls_name}_{op_name}"
local_ns = {}
exec(fn_code, _PCODE_GLOBALS, local_ns)
return local_ns[fn_name]

View File

@@ -22,45 +22,3 @@ def get_llvm_objdump():
for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']:
if shutil.which(p): return p
raise FileNotFoundError("llvm-objdump not found")
# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION CONTEXT (for testing compiled pseudocode)
# ═══════════════════════════════════════════════════════════════════════════════
class ExecContext:
"""Context for running compiled pseudocode in tests."""
def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
from extra.assembly.amd.pcode import Reg, MASK32, MASK64, TypedView
self._Reg, self._MASK64, self._TypedView = Reg, MASK64, TypedView
self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
self.D0, self.D1 = Reg(d0), Reg(0)
self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
self.lane, self.laneId, self.literal = lane, lane, literal
self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
self.VGPR = vgprs if vgprs is not None else {}
self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)
def run(self, code: str):
"""Execute compiled code."""
import extra.assembly.amd.pcode as pcode
ns = {k: getattr(pcode, k) for k in dir(pcode) if not k.startswith('_')}
# Also include underscore-prefixed helpers that compiled pseudocode uses
for k in ['_pack', '_pack32']:
if hasattr(pcode, k): ns[k] = getattr(pcode, k)
ns.update({
'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
'EXEC_LO': self._TypedView(self.EXEC, 31, 0), 'EXEC_HI': self._TypedView(self.EXEC, 63, 32),
'tmp': self.tmp, 'saveexec': self.saveexec,
'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
})
exec(code, ns)
def _sync(ctx_reg, ns_val):
if isinstance(ns_val, self._Reg): ctx_reg._val = ns_val._val
else: ctx_reg._val = int(ns_val) & self._MASK64
for name in ('SCC', 'VCC', 'EXEC', 'D0', 'D1', 'tmp', 'saveexec'):
if ns.get(name) is not getattr(self, name): _sync(getattr(self, name), ns[name])
def result(self) -> dict: return {"d0": self.D0._val, "scc": self.SCC._val & 1}

View File

@@ -5,9 +5,8 @@ Set USE_HW=1 to run on both emulator and real hardware, comparing results.
"""
import ctypes, os, struct
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import RawImm
from extra.assembly.amd.dsl import RawImm, _i32, _f32
from extra.assembly.amd.emu import WaveState, run_asm, set_valid_mem_ranges
from extra.assembly.amd.pcode import _i32, _f32
VCC = SrcEnum.VCC_LO # For VOP3SD sdst field
USE_HW = os.environ.get("USE_HW", "0") == "1"

View File

@@ -255,7 +255,7 @@ class TestF16Conversions(unittest.TestCase):
def test_v_cvt_f16_f32_small(self):
"""V_CVT_F16_F32 converts small f32 value."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
instructions = [
v_mov_b32_e32(v[0], 0.5),
v_cvt_f16_f32_e32(v[1], v[0]),
@@ -293,7 +293,7 @@ class TestF16Conversions(unittest.TestCase):
def test_v_cvt_f16_f32_reads_full_32bit_source(self):
"""V_CVT_F16_F32 must read full 32-bit f32 source."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x3fc00000), # f32 1.5
v_mov_b32_e32(v[0], s[0]),
@@ -560,7 +560,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_abs_negative(self):
"""V_CVT_F32_F16 with |abs| on negative value."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_neg1 = f32_to_f16(-1.0) # 0xbc00
instructions = [
s_mov_b32(s[0], f16_neg1),
@@ -573,7 +573,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_abs_positive(self):
"""V_CVT_F32_F16 with |abs| on positive value (should stay positive)."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_2 = f32_to_f16(2.0) # 0x4000
instructions = [
s_mov_b32(s[0], f16_2),
@@ -586,7 +586,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_neg_positive(self):
"""V_CVT_F32_F16 with neg on positive value."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_2 = f32_to_f16(2.0) # 0x4000
instructions = [
s_mov_b32(s[0], f16_2),
@@ -599,7 +599,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f32_f16_neg_negative(self):
"""V_CVT_F32_F16 with neg on negative value (double negative)."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_neg2 = f32_to_f16(-2.0) # 0xc000
instructions = [
s_mov_b32(s[0], f16_neg2),
@@ -612,7 +612,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
def test_v_cvt_f16_f32_then_pack_for_wmma(self):
"""CVT F32->F16 followed by pack (common WMMA pattern)."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
f32_val = 3.5
instructions = [
s_mov_b32(s[0], f2i(f32_val)),
@@ -668,7 +668,7 @@ class TestConversionRounding(unittest.TestCase):
def test_f16_to_f32_precision(self):
"""F16 to F32 conversion precision."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_val = f32_to_f16(1.5)
instructions = [
s_mov_b32(s[0], f16_val),
@@ -680,7 +680,7 @@ class TestConversionRounding(unittest.TestCase):
def test_f16_denormal_to_f32(self):
"""F16 denormal converts to small positive f32."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
f16_denorm = 0x0001 # Smallest positive f16 denormal
instructions = [
v_mov_b32_e32(v[0], f16_denorm),

View File

@@ -768,7 +768,7 @@ class TestF16Modifiers(unittest.TestCase):
def test_v_fma_f16_inline_const_1_0(self):
"""V_FMA_F16: a*b + 1.0 should use f16 inline constant."""
from extra.assembly.amd.pcode import f32_to_f16, _f16
from extra.assembly.amd.dsl import f32_to_f16, _f16
f16_a = f32_to_f16(0.325928) # ~0x3537
f16_b = f32_to_f16(-0.486572) # ~0xb7c9
instructions = [
@@ -785,7 +785,7 @@ class TestF16Modifiers(unittest.TestCase):
def test_v_fma_f16_inline_const_0_5(self):
"""V_FMA_F16: a*b + 0.5 should use f16 inline constant."""
from extra.assembly.amd.pcode import f32_to_f16, _f16
from extra.assembly.amd.dsl import f32_to_f16, _f16
f16_a = f32_to_f16(2.0)
f16_b = f32_to_f16(3.0)
instructions = [
@@ -802,7 +802,7 @@ class TestF16Modifiers(unittest.TestCase):
def test_v_fma_f16_inline_const_neg_1_0(self):
"""V_FMA_F16: a*b + (-1.0) should use f16 inline constant."""
from extra.assembly.amd.pcode import f32_to_f16, _f16
from extra.assembly.amd.dsl import f32_to_f16, _f16
f16_a = f32_to_f16(2.0)
f16_b = f32_to_f16(3.0)
instructions = [
@@ -819,7 +819,7 @@ class TestF16Modifiers(unittest.TestCase):
def test_v_add_f16_abs_both(self):
"""V_ADD_F16 with abs on both operands."""
from extra.assembly.amd.pcode import f32_to_f16, _f16
from extra.assembly.amd.dsl import f32_to_f16, _f16
f16_neg2 = f32_to_f16(-2.0)
f16_neg3 = f32_to_f16(-3.0)
instructions = [
@@ -835,7 +835,7 @@ class TestF16Modifiers(unittest.TestCase):
def test_v_mul_f16_neg_abs(self):
"""V_MUL_F16 with neg on one operand and abs on another."""
from extra.assembly.amd.pcode import f32_to_f16, _f16
from extra.assembly.amd.dsl import f32_to_f16, _f16
f16_2 = f32_to_f16(2.0)
f16_neg3 = f32_to_f16(-3.0)
instructions = [
@@ -854,7 +854,7 @@ class TestF16Modifiers(unittest.TestCase):
This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h.
"""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x38003c00), # v0 = {hi=0.5, lo=1.0}
v_mov_b32_e32(v[0], s[0]),

View File

@@ -155,7 +155,7 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_src2_f16_lo(self):
"""V_FMA_MIX_F32 with src2 as f16 from lo bits."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_2 = f32_to_f16(2.0)
instructions = [
s_mov_b32(s[0], f2i(1.0)),
@@ -172,7 +172,7 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mix_f32_src2_f16_hi(self):
"""V_FMA_MIX_F32 with src2 as f16 from hi bits."""
from extra.assembly.amd.pcode import f32_to_f16
from extra.assembly.amd.dsl import f32_to_f16
f16_2 = f32_to_f16(2.0)
val = (f16_2 << 16) | 0
instructions = [
@@ -205,7 +205,7 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mixlo_f16(self):
"""V_FMA_MIXLO_F16 writes to low 16 bits of destination."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], f2i(2.0)),
v_mov_b32_e32(v[0], s[0]),
@@ -225,7 +225,7 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mixlo_f16_all_f32_sources(self):
"""V_FMA_MIXLO_F16 with all f32 sources."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], f2i(1.0)),
v_mov_b32_e32(v[0], s[0]),
@@ -243,7 +243,7 @@ class TestFmaMix(unittest.TestCase):
def test_v_fma_mixlo_f16_sin_case(self):
"""V_FMA_MIXLO_F16 case from sin kernel."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x3f800000), # f32 1.0
v_mov_b32_e32(v[3], s[0]),
@@ -265,7 +265,7 @@ class TestVOP3P(unittest.TestCase):
def test_v_pk_add_f16_basic(self):
"""V_PK_ADD_F16 adds two packed f16 values."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
@@ -282,7 +282,7 @@ class TestVOP3P(unittest.TestCase):
def test_v_pk_mul_f16_basic(self):
"""V_PK_MUL_F16 multiplies two packed f16 values."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
@@ -299,7 +299,7 @@ class TestVOP3P(unittest.TestCase):
def test_v_pk_fma_f16_basic(self):
"""V_PK_FMA_F16: D = A * B + C for packed f16."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0
s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0
@@ -321,7 +321,7 @@ class TestVOP3P(unittest.TestCase):
Inline constants for VOP3P are f16 values in the low 16 bits only.
hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0.
"""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
v_mov_b32_e32(v[0], s[0]),
@@ -339,7 +339,7 @@ class TestVOP3P(unittest.TestCase):
"""V_PK_MUL_F16 with inline constant POS_TWO (2.0).
Inline constant has value only in low 16 bits, hi is 0.
"""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
# v0 = packed (3.0, 4.0), multiply by POS_TWO
# lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0)
instructions = [
@@ -504,7 +504,7 @@ class TestPackedMixedSigns(unittest.TestCase):
def test_pk_add_f16_mixed_signs(self):
"""V_PK_ADD_F16 with mixed positive/negative values."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
@@ -521,7 +521,7 @@ class TestPackedMixedSigns(unittest.TestCase):
def test_pk_mul_f16_zero(self):
"""V_PK_MUL_F16 with zero."""
from extra.assembly.amd.pcode import _f16
from extra.assembly.amd.dsl import _f16
instructions = [
s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0
s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0

View File

@@ -1,403 +0,0 @@
#!/usr/bin/env python3
"""Tests for the RDNA3 pseudocode DSL."""
import unittest
from extra.assembly.amd.pcode import (Reg, TypedView, TypedView, MASK32, MASK64,
_f32, _i32, _f16, _i16, f32_to_f16, isNAN, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
BYTE_PERMUTE, v_sad_u8, v_msad_u8, _compile_pseudocode, _expr, compile_pseudocode)
from extra.assembly.amd.test.helpers import ExecContext
from extra.assembly.amd.autogen.rdna3.str_pcode import VOP3SDOp_PCODE, VOPCOp_PCODE
from extra.assembly.amd.autogen.rdna3.enum import VOP3SDOp, VOPCOp
# Compile pseudocode functions on demand for regression tests
_VOP3SDOp_V_DIV_SCALE_F32 = compile_pseudocode('VOP3SDOp', 'V_DIV_SCALE_F32', VOP3SDOp_PCODE[VOP3SDOp.V_DIV_SCALE_F32])
_VOPCOp_V_CMP_CLASS_F32 = compile_pseudocode('VOPCOp', 'V_CMP_CLASS_F32', VOPCOp_PCODE[VOPCOp.V_CMP_CLASS_F32])
class TestReg(unittest.TestCase):
def test_u32_read(self):
r = Reg(0xDEADBEEF)
self.assertEqual(int(r.u32), 0xDEADBEEF)
def test_u32_write(self):
r = Reg(0)
r.u32 = 0x12345678
self.assertEqual(r._val, 0x12345678)
def test_f32_read(self):
r = Reg(0x40400000) # 3.0f
self.assertAlmostEqual(float(r.f32), 3.0)
def test_f32_write(self):
r = Reg(0)
r.f32 = 3.0
self.assertEqual(r._val, 0x40400000)
def test_i32_signed(self):
r = Reg(0xFFFFFFFF) # -1 as signed
self.assertEqual(int(r.i32), -1)
def test_u64(self):
r = Reg(0xDEADBEEFCAFEBABE)
self.assertEqual(int(r.u64), 0xDEADBEEFCAFEBABE)
def test_f64(self):
r = Reg(0x4008000000000000) # 3.0 as f64
self.assertAlmostEqual(float(r.f64), 3.0)
class TestTypedView(unittest.TestCase):
def test_bit_slice(self):
r = Reg(0xDEADBEEF)
# Slices return TypedView which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
self.assertEqual(r.u32[7:0].u32, 0xEF)
self.assertEqual(r.u32[15:8].u32, 0xBE)
self.assertEqual(r.u32[23:16].u32, 0xAD)
self.assertEqual(r.u32[31:24].u32, 0xDE)
# Also works with int() for arithmetic
self.assertEqual(int(r.u32[7:0]), 0xEF)
def test_single_bit_read(self):
r = Reg(0b11010101)
self.assertEqual(r.u32[0], 1)
self.assertEqual(r.u32[1], 0)
self.assertEqual(r.u32[2], 1)
self.assertEqual(r.u32[3], 0)
def test_single_bit_write(self):
r = Reg(0)
r.u32[5] = 1
r.u32[3] = 1
self.assertEqual(r._val, 0b00101000)
def test_nested_bit_access(self):
# S0.u32[S1.u32[4:0]] - access bit at position from another register
s0 = Reg(0b11010101)
s1 = Reg(3)
bit_pos = s1.u32[4:0] # TypedView, int value = 3
bit_val = s0.u32[int(bit_pos)] # bit 3 of s0 = 0
self.assertEqual(int(bit_pos), 3)
self.assertEqual(bit_val, 0)
def test_arithmetic(self):
r1 = Reg(0x40400000) # 3.0f
r2 = Reg(0x40800000) # 4.0f
result = r1.f32 + r2.f32
self.assertAlmostEqual(result, 7.0)
def test_comparison(self):
r1 = Reg(5)
r2 = Reg(3)
self.assertTrue(r1.u32 > r2.u32)
self.assertFalse(r1.u32 < r2.u32)
self.assertTrue(r1.u32 != r2.u32)
class TestTypedView(unittest.TestCase):
def test_slice_read(self):
r = Reg(0x56781234)
self.assertEqual(r[15:0].u16, 0x1234)
self.assertEqual(r[31:16].u16, 0x5678)
def test_slice_write(self):
r = Reg(0)
r[15:0].u16 = 0x1234
r[31:16].u16 = 0x5678
self.assertEqual(r._val, 0x56781234)
def test_slice_f16(self):
r = Reg(0)
r[15:0].f16 = 3.0
self.assertAlmostEqual(_f16(r._val & 0xffff), 3.0, places=2)
class TestCompiler(unittest.TestCase):
def test_ternary(self):
result = _expr("a > b ? 1 : 0")
self.assertIn("if", result)
self.assertIn("else", result)
def test_type_prefix_strip(self):
self.assertEqual(_expr("1'0U"), "0")
self.assertEqual(_expr("32'1"), "1")
self.assertEqual(_expr("16'0xFFFF"), "0xFFFF")
def test_suffix_strip(self):
self.assertEqual(_expr("0ULL"), "0")
self.assertEqual(_expr("1LL"), "1")
self.assertEqual(_expr("5U"), "5")
self.assertEqual(_expr("3.14F"), "3.14")
def test_boolean_ops(self):
self.assertIn("and", _expr("a && b"))
self.assertIn("or", _expr("a || b"))
self.assertIn("!=", _expr("a <> b"))
def test_pack16(self):
result = _expr("{ a, b }")
self.assertIn("_pack", result)
def test_type_cast_strip(self):
self.assertEqual(_expr("64'U(x)"), "(x)")
self.assertEqual(_expr("32'I(y)"), "(y)")
class TestExecContext(unittest.TestCase):
def test_float_add(self):
ctx = ExecContext(s0=0x40400000, s1=0x40800000) # 3.0f, 4.0f
ctx.D0.f32 = ctx.S0.f32 + ctx.S1.f32
self.assertAlmostEqual(_f32(ctx.D0._val), 7.0)
def test_float_mul(self):
ctx = ExecContext(s0=0x40400000, s1=0x40800000) # 3.0f, 4.0f
ctx.run("D0.f32 = S0.f32 * S1.f32")
self.assertAlmostEqual(_f32(ctx.D0._val), 12.0)
def test_scc_comparison(self):
ctx = ExecContext(s0=42, s1=42)
ctx.run("SCC = S0.u32 == S1.u32")
self.assertEqual(ctx.SCC._val, 1)
def test_scc_comparison_false(self):
ctx = ExecContext(s0=42, s1=43)
ctx.run("SCC = S0.u32 == S1.u32")
self.assertEqual(ctx.SCC._val, 0)
def test_ternary(self):
code = _compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
ctx = ExecContext(s0=5, s1=3)
ctx.run(code)
self.assertEqual(ctx.D0._val, 1)
def test_pack(self):
code = _compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
ctx = ExecContext(s0=0x1234, s1=0x5678)
ctx.run(code)
self.assertEqual(ctx.D0._val, 0x56781234)
def test_tmp_with_typed_access(self):
code = _compile_pseudocode("""tmp = S0.u32 + S1.u32
D0.u32 = tmp.u32""")
ctx = ExecContext(s0=100, s1=200)
ctx.run(code)
self.assertEqual(ctx.D0._val, 300)
def test_s_add_u32_pattern(self):
# Real pseudocode pattern from S_ADD_U32
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
# Test overflow case
ctx = ExecContext(s0=0xFFFFFFFF, s1=0x00000001)
ctx.run(code)
self.assertEqual(ctx.D0._val, 0) # Wraps to 0
self.assertEqual(ctx.SCC._val, 1) # Carry set
def test_s_add_u32_no_overflow(self):
code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
ctx = ExecContext(s0=100, s1=200)
ctx.run(code)
self.assertEqual(ctx.D0._val, 300)
self.assertEqual(ctx.SCC._val, 0) # No carry
def test_vcc_lane_read(self):
ctx = ExecContext(vcc=0b1010, lane=1)
# Lane 1 is set
self.assertEqual(ctx.VCC.u64[1], 1)
self.assertEqual(ctx.VCC.u64[2], 0)
def test_vcc_lane_write(self):
ctx = ExecContext(vcc=0, lane=0)
ctx.VCC.u64[3] = 1
ctx.VCC.u64[1] = 1
self.assertEqual(ctx.VCC._val, 0b1010)
def test_for_loop(self):
# CTZ pattern - find first set bit
code = _compile_pseudocode("""tmp = -1
for i in 0 : 31 do
if S0.u32[i] == 1 then
tmp = i
endif
endfor
D0.i32 = tmp""")
ctx = ExecContext(s0=0b1000) # Bit 3 is set
ctx.run(code)
self.assertEqual(ctx.D0._val & MASK32, 3)
def test_result_dict(self):
ctx = ExecContext(s0=5, s1=3)
ctx.D0.u32 = 42
ctx.SCC._val = 1
result = ctx.result()
self.assertEqual(result['d0'], 42)
self.assertEqual(result['scc'], 1)
class TestPseudocodeRegressions(unittest.TestCase):
"""Regression tests for pseudocode instruction emulation bugs."""
def test_v_div_scale_f32_vcc_always_returned(self):
"""V_DIV_SCALE_F32 must always return VCC, even when VCC=0 (no scaling needed).
Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
This caused division to produce wrong results for multiple lanes."""
# Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
s0 = 0x3f800000 # 1.0
s1 = 0x40400000 # 3.0
s2 = 0x3f800000 # 1.0 (numerator)
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None)
# Must always have VCC in result
self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
self.assertEqual(result['VCC'] & 1, 0, "VCC lane 0 should be 0 when no scaling needed")
def test_v_cmp_class_f32_detects_quiet_nan(self):
"""V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
Bug: isQuietNAN and isSignalNAN both used math.isnan which can't distinguish them."""
quiet_nan = 0x7fc00000 # quiet NaN: exponent=255, bit22=1
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
# Test quiet NaN detection (bit 1 in mask)
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 1, "Should detect quiet NaN with quiet NaN mask")
# Test signaling NaN detection (bit 0 in mask)
s1_signal = 0b0000000001 # bit 0 = signaling NaN
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 1, "Should detect signaling NaN with signaling NaN mask")
# Test that quiet NaN doesn't match signaling NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 0, "Quiet NaN should not match signaling NaN mask")
# Test that signaling NaN doesn't match quiet NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")
def testisNAN_with_typed_view(self):
"""isNAN must work with TypedView objects, not just Python floats.
Bug: isNAN checked isinstance(x, float) which returned False for TypedView."""
nan_reg = Reg(0x7fc00000) # quiet NaN
normal_reg = Reg(0x3f800000) # 1.0
inf_reg = Reg(0x7f800000) # +inf
self.assertTrue(isNAN(nan_reg.f32), "isNAN should return True for NaN TypedView")
self.assertFalse(isNAN(normal_reg.f32), "isNAN should return False for normal TypedView")
self.assertFalse(isNAN(inf_reg.f32), "isNAN should return False for inf TypedView")
class TestBF16(unittest.TestCase):
"""Tests for BF16 (bfloat16) support."""
def test_bf16_conversion(self):
"""Test bf16 <-> f32 conversion."""
# bf16 is just the top 16 bits of f32
# 1.0f = 0x3f800000, bf16 = 0x3f80
self.assertAlmostEqual(_bf16(0x3f80), 1.0, places=2)
self.assertEqual(_ibf16(1.0), 0x3f80)
# 2.0f = 0x40000000, bf16 = 0x4000
self.assertAlmostEqual(_bf16(0x4000), 2.0, places=2)
self.assertEqual(_ibf16(2.0), 0x4000)
# -1.0f = 0xbf800000, bf16 = 0xbf80
self.assertAlmostEqual(_bf16(0xbf80), -1.0, places=2)
self.assertEqual(_ibf16(-1.0), 0xbf80)
def test_bf16_special_values(self):
"""Test bf16 special values (inf, nan)."""
import math
# +inf: f32 = 0x7f800000, bf16 = 0x7f80
self.assertTrue(math.isinf(_bf16(0x7f80)))
self.assertEqual(_ibf16(float('inf')), 0x7f80)
# -inf: f32 = 0xff800000, bf16 = 0xff80
self.assertTrue(math.isinf(_bf16(0xff80)))
self.assertEqual(_ibf16(float('-inf')), 0xff80)
# NaN: quiet NaN bf16 = 0x7fc0
self.assertTrue(math.isnan(_bf16(0x7fc0)))
self.assertEqual(_ibf16(float('nan')), 0x7fc0)
def test_bf16_register_property(self):
"""Test Reg.bf16 property."""
r = Reg(0)
r.bf16 = 3.0 # 3.0f = 0x40400000, bf16 = 0x4040
self.assertEqual(r._val & 0xffff, 0x4040)
self.assertAlmostEqual(float(r.bf16), 3.0, places=1)
def test_bf16_slice_property(self):
"""Test TypedView.bf16 property."""
r = Reg(0x40404040) # Two bf16 3.0 values
self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1)
self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1)
class TestBytePermute(unittest.TestCase):
"""Tests for BYTE_PERMUTE helper function (V_PERM_B32)."""
def test_byte_select_0_to_7(self):
"""Test selecting bytes 0-7 from 64-bit data."""
# data = {s0, s1} where s0 is bytes 0-3, s1 is bytes 4-7
# Combined: 0x0706050403020100 (byte 0 = 0x00, byte 7 = 0x07)
data = 0x0706050403020100
for i in range(8):
self.assertEqual(BYTE_PERMUTE(data, i), i, f"byte {i} should be {i}")
def test_sign_extend_bytes(self):
"""Test sign extension selectors 8-11."""
# sel 8: sign of byte 1 (bits 15:8)
# sel 9: sign of byte 3 (bits 31:24)
# sel 10: sign of byte 5 (bits 47:40)
# sel 11: sign of byte 7 (bits 63:56)
data = 0x8000800080008000 # All relevant bytes have sign bit set
self.assertEqual(BYTE_PERMUTE(data, 8), 0xff)
self.assertEqual(BYTE_PERMUTE(data, 9), 0xff)
self.assertEqual(BYTE_PERMUTE(data, 10), 0xff)
self.assertEqual(BYTE_PERMUTE(data, 11), 0xff)
data = 0x7f007f007f007f00 # No sign bits set
self.assertEqual(BYTE_PERMUTE(data, 8), 0x00)
self.assertEqual(BYTE_PERMUTE(data, 9), 0x00)
self.assertEqual(BYTE_PERMUTE(data, 10), 0x00)
self.assertEqual(BYTE_PERMUTE(data, 11), 0x00)
def test_constant_zero(self):
"""Test selector 12 returns 0x00."""
self.assertEqual(BYTE_PERMUTE(0xffffffffffffffff, 12), 0x00)
def test_constant_ff(self):
"""Test selectors >= 13 return 0xFF."""
for sel in [13, 14, 15, 255]:
self.assertEqual(BYTE_PERMUTE(0, sel), 0xff, f"sel {sel} should be 0xff")
class TestSADHelpers(unittest.TestCase):
"""Tests for V_SAD_U8 and V_MSAD_U8 helper functions."""
def test_v_sad_u8_basic(self):
"""Test v_sad_u8 with simple values."""
# s0 = 0x04030201, s1 = 0x04030201 -> diff = 0 for all bytes
result = v_sad_u8(0x04030201, 0x04030201, 0)
self.assertEqual(result, 0)
# s0 = 0x05040302, s1 = 0x04030201 -> diff = 1+1+1+1 = 4
result = v_sad_u8(0x05040302, 0x04030201, 0)
self.assertEqual(result, 4)
def test_v_sad_u8_with_accumulator(self):
"""Test v_sad_u8 with non-zero accumulator."""
# s0 = 0x05040302, s1 = 0x04030201, s2 = 100 -> 4 + 100 = 104
result = v_sad_u8(0x05040302, 0x04030201, 100)
self.assertEqual(result, 104)
def test_v_sad_u8_large_diff(self):
"""Test v_sad_u8 with maximum byte differences."""
# s0 = 0xffffffff, s1 = 0x00000000 -> diff = 255*4 = 1020
result = v_sad_u8(0xffffffff, 0x00000000, 0)
self.assertEqual(result, 1020)
def test_v_msad_u8_basic(self):
"""Test v_msad_u8 masks when reference byte is 0."""
# s0 = 0x10101010, s1 = 0x00000000 -> all masked, result = 0
result = v_msad_u8(0x10101010, 0x00000000, 0)
self.assertEqual(result, 0)
# s0 = 0x10101010, s1 = 0x01010101 -> diff = |0x10-0x01|*4 = 15*4 = 60
result = v_msad_u8(0x10101010, 0x01010101, 0)
self.assertEqual(result, 60)
def test_v_msad_u8_partial_mask(self):
"""Test v_msad_u8 with partial masking."""
# s0 = 0x10101010, s1 = 0x00010001 -> bytes 1 and 3 masked
# diff = |0x10-0x01| + |0x10-0x01| = 15 + 15 = 30
result = v_msad_u8(0x10101010, 0x00010001, 0)
self.assertEqual(result, 30)
def test_v_msad_u8_with_accumulator(self):
"""Test v_msad_u8 with non-zero accumulator."""
result = v_msad_u8(0x10101010, 0x01010101, 50)
self.assertEqual(result, 110) # 60 + 50
if __name__ == '__main__':
unittest.main()

View File

@@ -47,10 +47,14 @@ INPUT_VARS = {
'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)),
'OPSEL': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL', 0, 7)),
'OPSEL_HI': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL_HI', 0, 7)),
'SRC0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('SRC0', 0, 0xffffffff)), # Source register index
'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('VDST', 0, 0xffffffff)), # Dest register index (for writelane)
'M0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('M0', 0, 0xffffffff)), # M0 register
}
MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0)
VGPR_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(addrspace=AddrSpace.GLOBAL), arg=1) # VGPR[lane][reg] as flat array
class Ctx:
def __init__(self, mem_buf: UOp = MEM_BUF):
@@ -85,6 +89,25 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
if name == 'NAN.f32': return UOp.const(dtypes.float32, float('nan'))
if name in ('VCCZ', 'EXECZ'):
return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32)
if name == 'EXEC_LO':
return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 0xffffffff))), hint or dtypes.uint32)
if name == 'EXEC_HI':
return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 32))), hint or dtypes.uint32)
if name == 'VCC_LO':
return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 0xffffffff))), hint or dtypes.uint32)
if name == 'VCC_HI':
return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 32))), hint or dtypes.uint32)
if name == 'laneID' or name == 'laneId':
return ctx.vars.get('laneId', UOp.const(dtypes.uint32, 0))
if name == 'ThreadMask':
# ThreadMask is the same as EXEC for wave32
return _cast(ctx.vars.get('EXEC'), hint or dtypes.uint32)
if name == 'DST':
# DST is the raw destination register index from the instruction
return ctx.vars.get('VDST', UOp.const(dtypes.uint32, 0))
if name == 'LDS':
# LDS is the local data share memory - treat as memory buffer
return UOp.const(dtypes.uint64, 0) # Base address placeholder
if name.startswith('eval '): return ctx.vars.get('_eval', UOp.const(dtypes.uint32, 0))
if name not in ctx.vars: raise ValueError(f"Unknown variable: {name}")
return _cast(ctx.vars[name], hint or ctx.vars[name].dtype)
@@ -109,6 +132,18 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
return UOp.const(dt, denorm)
if name in ('VCCZ', 'EXECZ'):
return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32)
if name == 'EXEC_LO':
return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 0xffffffff))), dt)
if name == 'EXEC_HI':
return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 32))), dt)
if name == 'VCC_LO':
return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 0xffffffff))), dt)
if name == 'VCC_HI':
return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 32))), dt)
if name == 'DST':
return _cast(ctx.vars.get('VDST', UOp.const(dtypes.uint32, 0)), dt)
if name == 'laneID' or name == 'laneId':
return _cast(ctx.vars.get('laneId', UOp.const(dtypes.uint32, 0)), dt)
if name.startswith('WAVE_STATUS.COND_DBG'): return UOp.const(dtypes.uint32, 0)
vn = name + '_64' if dt.itemsize == 8 and name.isupper() else name
base = ctx.vars.get(vn) if vn in ctx.vars else ctx.vars.get(name)
@@ -132,6 +167,20 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
return _cast(inner_resolved, dt)
case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)): # Slice or array access
# Check for VGPR[lane][reg] access pattern (nested CUSTOMI where inner base is VGPR)
if base_expr.op == Ops.CUSTOMI and hi_expr is lo_expr:
inner_base, inner_idx, _ = base_expr.src
if inner_base.op == Ops.DEFINE_VAR and inner_base.arg[0] == 'VGPR':
# VGPR[lane][reg] -> load from VGPR buffer at index (lane * 256 + reg)
lane_uop = _expr(inner_idx, ctx, dtypes.uint32)
reg_uop = _expr(hi_expr, ctx, dtypes.uint32)
# Compute flat index: lane * 256 + reg (256 VGPRs per lane)
idx = UOp(Ops.ADD, dtypes.uint32, (UOp(Ops.MUL, dtypes.uint32, (lane_uop, UOp.const(dtypes.uint32, 256))), reg_uop))
return UOp(Ops.CUSTOM, dtypes.uint32, (idx,), arg='vgpr_read')
# Check for SGPR[idx] access pattern (scalar register file access)
if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[0] == 'SGPR' and hi_expr is lo_expr:
idx_uop = _expr(hi_expr, ctx, dtypes.uint32)
return UOp(Ops.CUSTOM, dtypes.uint32, (idx_uop,), arg='sgpr_read')
# Check for array element access first: arr[idx] where arr is a vector type
if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[1] is None and hi_expr is lo_expr:
name = base_expr.arg[0]
@@ -377,6 +426,14 @@ def _call_trig_preop_result(shift):
# Returns CUSTOM op that gets evaluated at runtime with the 1201-bit constant
return UOp(Ops.CUSTOM, dtypes.float64, (shift,), arg='trig_preop_result')
def _call_s_ff1_i32_b32(v):
# Find first 1 bit (count trailing zeros) - returns CUSTOM op evaluated at runtime
return UOp(Ops.CUSTOM, dtypes.int32, (_cast(v, dtypes.uint32),), arg='s_ff1_i32_b32')
def _call_s_ff1_i32_b64(v):
# Find first 1 bit in 64-bit value (count trailing zeros) - returns CUSTOM op evaluated at runtime
return UOp(Ops.CUSTOM, dtypes.int32, (_cast(v, dtypes.uint64),), arg='s_ff1_i32_b64')
CALL_DISPATCH = {
'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isQuietNAN,
@@ -384,6 +441,9 @@ CALL_DISPATCH = {
'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
'BYTE_PERMUTE': _call_BYTE_PERMUTE, 'bf16_to_f32': _call_bf16_to_f32, 'trig_preop_result': _call_trig_preop_result,
's_ff1_i32_b32': _call_s_ff1_i32_b32, 's_ff1_i32_b64': _call_s_ff1_i32_b64,
'u8_to_u32': lambda v: _cast(UOp(Ops.AND, dtypes.uint32, (_cast(v, dtypes.uint32), UOp.const(dtypes.uint32, 0xff))), dtypes.uint32),
'u4_to_u32': lambda v: _cast(UOp(Ops.AND, dtypes.uint32, (_cast(v, dtypes.uint32), UOp.const(dtypes.uint32, 0xf))), dtypes.uint32),
}
def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
@@ -432,40 +492,53 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
return result
raise ValueError(f"Unknown function: {name}")
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None, int|None]:
"""Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx)"""
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, UOp|str|None, int|None, UOp|None]:
"""Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx, dynamic_idx_uop)
dynamic_idx_uop is set when the bit index is a runtime expression (not constant or simple variable)"""
match lhs:
case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None
case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None, None
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))),)):
return name, dt, int(hi), int(lo), None, None
return name, dt, int(hi), int(lo), None, None, None
case UOp(Ops.BITCAST, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
return name, dtypes.uint64, None, None, idx, None
return name, dtypes.uint64, None, None, idx, None, None
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
return name, dt, None, None, idx, None
return name, dt, None, None, idx, None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
return name, dtypes.uint32, int(hi), int(lo), None, None
return name, dtypes.uint32, int(hi), int(lo), None, None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, idx), _)) if lhs.src[1] is lhs.src[2]:
# Check if this is array element access (variable is a vector type)
var_dtype = ctx.decls.get(name)
if var_dtype is not None and var_dtype.count > 1:
return name, var_dtype.scalar(), None, None, None, int(idx)
return name, dtypes.uint32, int(idx), int(idx), None, None
return name, var_dtype.scalar(), None, None, None, int(idx), None
return name, dtypes.uint32, int(idx), int(idx), None, None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
return name, dtypes.uint32, int(hi), int(lo), None, None
return name, dtypes.uint32, int(hi), int(lo), None, None, None
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
return name, dt, None, None, idx, None
return name, dt, None, None, idx, None, None
# Handle arr[i] where i is a variable - check if it's array element or bit index
case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
var_dtype = ctx.decls.get(name)
if var_dtype is not None and var_dtype.count > 1:
# Array element access with variable index
return name, var_dtype.scalar(), None, None, None, idx # Return idx as variable name for array_idx
return name, dtypes.uint32, None, None, idx, None
return name, var_dtype.scalar(), None, None, None, idx, None # Return idx as variable name for array_idx
return name, dtypes.uint32, None, None, idx, None, None
# Handle D0.u32[expr] where expr is a complex expression (dynamic bit index)
case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), idx_expr, idx_expr2)) if lhs.src[1] is lhs.src[2]:
return name, dt, None, None, None, None, idx_expr # Return expression as dynamic_idx_uop
# Handle VGPR[lane][reg] = value (nested CUSTOMI for VGPR write)
case UOp(Ops.CUSTOMI, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, ('VGPR', None, None)), lane_expr, _)), reg_expr, _)):
return 'VGPR', dtypes.uint32, None, None, None, None, (lane_expr, reg_expr) # Return tuple for VGPR write
# Handle VGPR[laneId][addr].type = value (with BITCAST wrapper)
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, ('VGPR', None, None)), lane_expr, _)), reg_expr, _)),)):
return 'VGPR', dt, None, None, None, None, (lane_expr, reg_expr) # Return tuple for VGPR write
# Handle SGPR[addr].type = value (scalar register write with BITCAST)
case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, ('SGPR', None, None)), reg_expr, _)),)):
return 'SGPR', dt, None, None, None, None, reg_expr # Return expr for SGPR write
case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)):
# If the variable already exists, use its dtype; otherwise default to uint32
existing = ctx.vars.get(name)
dtype = existing.dtype if existing is not None else dtypes.uint32
return name, dtype, None, None, None, None
return name, dtype, None, None, None, None, None
raise ValueError(f"Cannot parse LHS: {lhs}")
def _stmt(stmt, ctx: Ctx):
@@ -511,9 +584,47 @@ def _stmt(stmt, ctx: Ctx):
offset += bits
return
var, dtype, hi, lo, idx_var, array_idx = _get_lhs_info(lhs, ctx)
var, dtype, hi, lo, idx_var, array_idx, dynamic_idx = _get_lhs_info(lhs, ctx)
out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')
# Handle VGPR write: VGPR[lane][reg] = value
if var == 'VGPR' and isinstance(dynamic_idx, tuple):
lane_expr, reg_expr = dynamic_idx
lane_uop = _expr(lane_expr, ctx, dtypes.uint32)
reg_uop = _expr(reg_expr, ctx, dtypes.uint32)
val_uop = _expr(rhs, ctx, dtype)
# Compute flat index: lane * 256 + reg
idx = UOp(Ops.ADD, dtypes.uint32, (UOp(Ops.MUL, dtypes.uint32, (lane_uop, UOp.const(dtypes.uint32, 256))), reg_uop))
ctx.outputs.append(('VGPR_WRITE', UOp(Ops.CUSTOM, dtypes.uint32, (idx, _cast(val_uop, dtypes.uint32)), arg='vgpr_write'), dtypes.uint32))
return
# Handle SGPR write: SGPR[reg] = value
if var == 'SGPR' and dynamic_idx is not None and not isinstance(dynamic_idx, tuple):
reg_uop = _expr(dynamic_idx, ctx, dtypes.uint32)
val_uop = _expr(rhs, ctx, dtype)
ctx.outputs.append(('SGPR_WRITE', UOp(Ops.CUSTOM, dtypes.uint32, (reg_uop, _cast(val_uop, dtypes.uint32)), arg='sgpr_write'), dtypes.uint32))
return
# Handle dynamic bit index: D0.u32[expr] = value (where expr is runtime expression)
if dynamic_idx is not None and not isinstance(dynamic_idx, tuple):
idx_uop = _expr(dynamic_idx, ctx, dtypes.uint32)
rhs_uop = _expr(rhs, ctx, dtypes.uint32)
op_dt = dtype if dtype.itemsize >= 4 else dtypes.uint32
if dtype.itemsize == 8: op_dt = dtypes.uint64
base = ctx.vars.get(var, UOp.const(op_dt, 0))
if base.dtype != op_dt: base = _cast(base, op_dt)
# Set single bit at dynamic index: base = (base & ~(1 << idx)) | ((val & 1) << idx)
one = UOp.const(op_dt, 1)
bit_mask = UOp(Ops.SHL, op_dt, (one, _cast(idx_uop, op_dt)))
inv_mask = UOp(Ops.XOR, op_dt, (bit_mask, UOp.const(op_dt, -1)))
val_bit = UOp(Ops.SHL, op_dt, (UOp(Ops.AND, op_dt, (_cast(rhs_uop, op_dt), one)), _cast(idx_uop, op_dt)))
result = UOp(Ops.OR, op_dt, (UOp(Ops.AND, op_dt, (base, inv_mask)), val_bit))
ctx.vars[var] = result
if var in out_vars:
ctx.outputs = [(n, u, d) for n, u, d in ctx.outputs if n != var]
ctx.outputs.append((var, result, op_dt))
return
# Handle array element assignment: arr[idx] = value
if array_idx is not None:
var_dtype = ctx.decls.get(var)
@@ -717,6 +828,28 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
shifted = (TWO_OVER_PI_1201 << int(shift)) >> (1201 - 53)
mantissa = shifted & 0x1fffffffffffff
return float(mantissa)
if u.op == Ops.CUSTOM and u.arg == 's_ff1_i32_b32':
# Find first 1 bit (count trailing zeros) in 32-bit value
v = _eval_uop(u.src[0])
if v is None: return None
v = int(v) & 0xffffffff
if v == 0: return 32
n = 0
while (v & 1) == 0: v >>= 1; n += 1
return n
if u.op == Ops.CUSTOM and u.arg == 's_ff1_i32_b64':
# Find first 1 bit (count trailing zeros) in 64-bit value
v = _eval_uop(u.src[0])
if v is None: return None
v = int(v) & 0xffffffffffffffff
if v == 0: return 64
n = 0
while (v & 1) == 0: v >>= 1; n += 1
return n
if u.op == Ops.CUSTOM and u.arg == 'vgpr_read':
# VGPR read - returns CUSTOM that will be resolved with VGPR data at runtime
# This can't be evaluated statically - needs VGPR substitution
return None
return None
def _extract_results(s, MEM=None):
@@ -790,8 +923,20 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
input_vars['SIMM16']: UOp.const(dtypes.int32, simm16), input_vars['SIMM32']: UOp.const(dtypes.uint32, literal or 0),
input_vars['PC']: UOp.const(dtypes.uint64, pc or 0),
input_vars['OPSEL']: UOp.const(dtypes.uint32, opsel), input_vars['OPSEL_HI']: UOp.const(dtypes.uint32, opsel_hi),
input_vars['SRC0']: UOp.const(dtypes.uint32, src0_idx),
}
return _extract_results(sink.substitute(dvars).simplify())
s1_sub = sink.substitute(dvars).simplify()
# Handle VGPR reads - substitute vgpr_read CUSTOM ops with actual values
if VGPR is not None:
vgpr_subs = {}
for u in s1_sub.toposort():
if u.op == Ops.CUSTOM and u.arg == 'vgpr_read':
idx = _eval_uop(u.src[0])
if idx is not None:
lane, reg = int(idx) // 256, int(idx) % 256
vgpr_subs[u] = UOp.const(dtypes.uint32, VGPR[lane][reg] if lane < len(VGPR) and reg < len(VGPR[lane]) else 0)
if vgpr_subs: s1_sub = s1_sub.substitute(vgpr_subs).simplify()
return _extract_results(s1_sub)
return fn
# Ops that need Python exec features (inline conditionals, complex PDF fixes) - fall back to pcode.py