mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)

commit: no pcode
@@ -43,6 +43,14 @@ def _i64(f):
  try: return _struct_Q.unpack(_struct_d.pack(f))[0]
  except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
+
+# Float conversion helpers (used by tests)
+def f32_to_f16(f):
+  f = float(f)
+  if math.isnan(f): return 0x7e00  # f16 NaN
+  if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00  # f16 ±infinity
+  try: return _struct_H.unpack(_struct_e.pack(f))[0]
+  except OverflowError: return 0x7c00 if f > 0 else 0xfc00  # overflow -> ±infinity
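
# Example: a few f32_to_f16 bit patterns (normal value, infinity, and the overflow path).
#   f32_to_f16(1.5)          == 0x3e00
#   f32_to_f16(float("inf")) == 0x7c00
#   f32_to_f16(1e9)          == 0x7c00  (struct.pack("<e", ...) overflows -> +inf)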

# Instruction spec - register counts and dtypes derived from instruction names
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
         'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,

@@ -5,7 +5,6 @@ import ctypes, functools
from tinygrad.runtime.autogen import hsa
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.asm import detect_format
-from extra.assembly.amd.pcode import compile_pseudocode
from extra.assembly.amd.ucode import compile_uop
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,

@@ -381,7 +380,7 @@ def decode_program(data: bytes) -> dict[int, Inst]:
  # Compile pcode for instructions that use it (not VOPD which has _fnx/_fny, not special dispatches)
  # Try ucode first (UOp-based), fall back to pcode (Python exec-based)
  def _compile_op(cls_name, op_name, pcode):
-    return compile_uop(op_name, pcode) or compile_pseudocode(cls_name, op_name, pcode)
+    return compile_uop(op_name, pcode) #or compile_pseudocode(cls_name, op_name, pcode)
  # VOPD needs separate functions for X and Y ops
  if isinstance(inst, VOPD):
    def _compile_vopd_op(op): return _compile_op(type(op).__name__, op.name, PSEUDOCODE_STRINGS[type(op)][op])

@@ -1,765 +0,0 @@
# DSL for RDNA3 pseudocode - makes pseudocode expressions work directly as Python
import struct, math, re, functools
from extra.assembly.amd.dsl import MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64

# ═══════════════════════════════════════════════════════════════════════════════
# INTERNAL HELPERS
# ═══════════════════════════════════════════════════════════════════════════════

def _div(a, b):
  try: return a / b
  except ZeroDivisionError:
    if a == 0.0 or math.isnan(a): return float("nan")
    return math.copysign(float("inf"), a * b) if b == 0.0 else float("inf") if a > 0 else float("-inf")
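
# Example: _div follows IEEE semantics instead of raising ZeroDivisionError.
#   _div(1.0, 0.0)  == inf
#   _div(-1.0, 0.0) == -inf  (sign recovered via copysign on a * b)
#   _div(0.0, 0.0)  -> nan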
def _check_nan_type(x, quiet_bit_expected, default):
  try:
    if not math.isnan(float(x)): return False
    if hasattr(x, '_reg') and hasattr(x, '_bits'):
      bits = x._reg._val & ((1 << x._bits) - 1)
      exp_bits, quiet_pos, mant_mask = {16: (0x1f, 9, 0x3ff), 32: (0xff, 22, 0x7fffff), 64: (0x7ff, 51, 0xfffffffffffff)}.get(x._bits, (0,0,0))
      exp_shift = {16: 10, 32: 23, 64: 52}.get(x._bits, 0)
      if exp_bits and ((bits >> exp_shift) & exp_bits) == exp_bits and (bits & mant_mask) != 0:
        return ((bits >> quiet_pos) & 1) == quiet_bit_expected
    return default
  except (TypeError, ValueError): return False
def _gt_neg_zero(a, b): return (a > b) or (a == 0 and b == 0 and not math.copysign(1, a) < 0 and math.copysign(1, b) < 0)
def _lt_neg_zero(a, b): return (a < b) or (a == 0 and b == 0 and math.copysign(1, a) < 0 and not math.copysign(1, b) < 0)
def _fpop(fn):
  def wrapper(x):
    x = float(x)
    if math.isnan(x) or math.isinf(x): return x
    result = float(fn(x))
    return math.copysign(0.0, x) if result == 0.0 else result
  return wrapper
def _f_to_int(f, lo, hi): f = float(f); return 0 if math.isnan(f) else (hi if f >= hi else lo if f <= lo else int(f))
def _f16_to_f32_bits(bits): return struct.unpack("<e", struct.pack("<H", int(bits) & 0xffff))[0]
def _brev(v, bits): return int(bin(v & ((1 << bits) - 1))[2:].zfill(bits)[::-1], 2)
def _ctz(v, bits):
  v, n = int(v) & ((1 << bits) - 1), 0
  if v == 0: return bits
  while (v & 1) == 0: v >>= 1; n += 1
  return n
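
# Example: the bit helpers on small values.
#   _brev(0b0001, 4) == 0b1000  (reverse within a 4-bit field)
#   _ctz(0b1000, 32) == 3       (count trailing zeros)
#   _ctz(0, 32)      == 32      (all-zero input returns the full width)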

def _bf16(i):
  """Convert bf16 bits to float. BF16 is just the top 16 bits of f32."""
  return struct.unpack("<f", struct.pack("<I", (i & 0xffff) << 16))[0]
def _ibf16(f):
  """Convert float to bf16 bits (truncate to top 16 bits of f32)."""
  if math.isnan(f): return 0x7fc0  # bf16 quiet NaN
  if math.isinf(f): return 0x7f80 if f > 0 else 0xff80  # bf16 ±infinity
  try: return (struct.unpack("<I", struct.pack("<f", float(f)))[0] >> 16) & 0xffff
  except (OverflowError, struct.error): return 0x7f80 if f > 0 else 0xff80
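
# Example: bf16 round-trip. 1.0 is 0x3f800000 in f32, so its bf16 bits are 0x3f80.
#   _ibf16(1.0)          == 0x3f80
#   _bf16(0x3f80)        == 1.0
#   _ibf16(float("nan")) == 0x7fc0  (canonical quiet NaN)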
def _trig(fn, x):
  # V_SIN/COS_F32: hardware does frac on input cycles before computing
  if math.isinf(x) or math.isnan(x): return float("nan")
  frac_cycles = fract(x / (2 * math.pi))
  result = fn(frac_cycles * 2 * math.pi)
  # Hardware returns exactly 0 for cos(π/2), sin(π), etc. due to lookup table
  # Round very small results (below f32 precision) to exactly 0
  if abs(result) < 1e-7: return 0.0
  return result

class _SafeFloat(float):
  """Float subclass that uses _div for division to handle 0/inf correctly."""
  def __truediv__(self, o): return _div(float(self), float(o))
  def __rtruediv__(self, o): return _div(float(o), float(self))

class _Inf:
  f16 = f32 = f64 = float('inf')
  def __neg__(self): return _NegInf()
  def __pos__(self): return self
  def __float__(self): return float('inf')
  def __eq__(self, other): return float(other) == float('inf') if not isinstance(other, _NegInf) else False
  def __req__(self, other): return self.__eq__(other)
class _NegInf:
  f16 = f32 = f64 = float('-inf')
  def __neg__(self): return _Inf()
  def __pos__(self): return self
  def __float__(self): return float('-inf')
  def __eq__(self, other): return float(other) == float('-inf') if not isinstance(other, _Inf) else False
  def __req__(self, other): return self.__eq__(other)

class _RoundMode:
  NEAREST_EVEN = 0

class _WaveMode:
  IEEE = False

class _DenormChecker:
  """Comparator for denormalized floats. x == DENORM.f32 checks if x is denormalized."""
  def __init__(self, bits): self._bits = bits
  def _check(self, other):
    f = float(other)
    if math.isinf(f) or math.isnan(f) or f == 0.0: return False
    if self._bits == 64:
      bits = struct.unpack("<Q", struct.pack("<d", f))[0]
      return (bits >> 52) & 0x7ff == 0
    bits = struct.unpack("<I", struct.pack("<f", f))[0]
    return (bits >> 23) & 0xff == 0
  def __eq__(self, other): return self._check(other)
  def __req__(self, other): return self._check(other)
  def __ne__(self, other): return not self._check(other)

class _Denorm:
  f32 = _DenormChecker(32)
  f64 = _DenormChecker(64)

_pack = lambda hi, lo: ((int(hi) & 0xffff) << 16) | (int(lo) & 0xffff)
_pack32 = lambda hi, lo: ((int(hi) & 0xffffffff) << 32) | (int(lo) & 0xffffffff)

class TypedView:
  """View into a Reg with typed access. Used for both full-width (Reg.u32) and slices (Reg[31:16])."""
  __slots__ = ('_reg', '_high', '_low', '_signed', '_float', '_bf16', '_reversed')
  def __init__(self, reg, high, low=0, signed=False, is_float=False, is_bf16=False):
    # Handle reversed slices like [0:31] which means bit-reverse
    if high < low: high, low, reversed = low, high, True
    else: reversed = False
    self._reg, self._high, self._low, self._reversed = reg, high, low, reversed
    self._signed, self._float, self._bf16 = signed, is_float, is_bf16

  def _nbits(self): return self._high - self._low + 1
  def _mask(self): return (1 << self._nbits()) - 1
  def _get(self):
    v = (self._reg._val >> self._low) & self._mask()
    return _brev(v, self._nbits()) if self._reversed else v
  def _set(self, v):
    v = int(v)
    if self._reversed: v = _brev(v, self._nbits())
    self._reg._val = (self._reg._val & ~(self._mask() << self._low)) | ((v & self._mask()) << self._low)

  @property
  def _val(self): return self._get()
  @property
  def _bits(self): return self._nbits()

  # Type accessors for slices (e.g., D0[31:16].f16)
  u8 = property(lambda s: s._get() & 0xff)
  u16 = property(lambda s: s._get() & 0xffff, lambda s, v: s._set(v))
  u32 = property(lambda s: s._get() & MASK32, lambda s, v: s._set(v))
  i16 = property(lambda s: _sext(s._get() & 0xffff, 16), lambda s, v: s._set(v))
  i32 = property(lambda s: _sext(s._get() & MASK32, 32), lambda s, v: s._set(v))
  f16 = property(lambda s: _f16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _i16(float(v))))
  f32 = property(lambda s: _f32(s._get()), lambda s, v: s._set(_i32(float(v))))
  bf16 = property(lambda s: _bf16(s._get()), lambda s, v: s._set(v if isinstance(v, int) else _ibf16(float(v))))
  b16, b32 = u16, u32

  # Chained type access (e.g., jump_addr.i64 when jump_addr is already TypedView)
  @property
  def i64(s): return s if s._nbits() == 64 and s._signed else int(s)
  @property
  def u64(s): return s if s._nbits() == 64 and not s._signed else int(s) & MASK64

  def __getitem__(self, key):
    if isinstance(key, slice):
      high, low = int(key.start), int(key.stop)
      return TypedView(self._reg, high, low)
    return (self._get() >> int(key)) & 1

  def __setitem__(self, key, value):
    if isinstance(key, slice):
      high, low = int(key.start), int(key.stop)
      if high < low: high, low, value = low, high, _brev(int(value), low - high + 1)
      mask = (1 << (high - low + 1)) - 1
      self._reg._val = (self._reg._val & ~(mask << low)) | ((int(value) & mask) << low)
    elif value: self._reg._val |= (1 << int(key))
    else: self._reg._val &= ~(1 << int(key))

  def __int__(self): return _sext(self._get(), self._nbits()) if self._signed else self._get()
  def __index__(self): return int(self)
  def __trunc__(self): return int(float(self)) if self._float else int(self)
  def __float__(self):
    if self._float:
      if self._bf16: return _bf16(self._get())
      bits = self._nbits()
      return _f16(self._get()) if bits == 16 else _f32(self._get()) if bits == 32 else _f64(self._get())
    return float(int(self))
  def __bool__(s): return bool(int(s))

  # Arithmetic - floats use float(), ints use int()
  def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o)
  def __radd__(s, o): return float(o) + float(s) if s._float else int(o) + int(s)
  def __sub__(s, o): return float(s) - float(o) if s._float else int(s) - int(o)
  def __rsub__(s, o): return float(o) - float(s) if s._float else int(o) - int(s)
  def __mul__(s, o): return float(s) * float(o) if s._float else int(s) * int(o)
  def __rmul__(s, o): return float(o) * float(s) if s._float else int(o) * int(s)
  def __truediv__(s, o): return _div(float(s), float(o)) if s._float else _div(int(s), int(o))
  def __rtruediv__(s, o): return _div(float(o), float(s)) if s._float else _div(int(o), int(s))
  def __pow__(s, o): return float(s) ** float(o) if s._float else int(s) ** int(o)
  def __rpow__(s, o): return float(o) ** float(s) if s._float else int(o) ** int(s)
  def __neg__(s): return -float(s) if s._float else -int(s)
  def __abs__(s): return abs(float(s)) if s._float else abs(int(s))

  # Bitwise - GPU shifts mask the shift amount to valid range
  def __and__(s, o): return int(s) & int(o)
  def __or__(s, o): return int(s) | int(o)
  def __xor__(s, o): return int(s) ^ int(o)
  def __invert__(s): return ~int(s)
  def __lshift__(s, o): n = int(o); return int(s) << n if 0 <= n < 64 or s._nbits() > 64 else 0
  def __rshift__(s, o): n = int(o); return int(s) >> n if 0 <= n < 64 or s._nbits() > 64 else 0
  def __rand__(s, o): return int(o) & int(s)
  def __ror__(s, o): return int(o) | int(s)
  def __rxor__(s, o): return int(o) ^ int(s)
  def __rlshift__(s, o): n = int(s); return int(o) << n if 0 <= n < 64 else 0
  def __rrshift__(s, o): n = int(s); return int(o) >> n if 0 <= n < 64 else 0

  # Comparison - handle _DenormChecker specially
  def __eq__(s, o):
    if isinstance(o, _DenormChecker): return o._check(s)
    return float(s) == float(o) if s._float else int(s) == int(o)
  def __ne__(s, o):
    if isinstance(o, _DenormChecker): return not o._check(s)
    return float(s) != float(o) if s._float else int(s) != int(o)
  def __lt__(s, o): return float(s) < float(o) if s._float else int(s) < int(o)
  def __le__(s, o): return float(s) <= float(o) if s._float else int(s) <= int(o)
  def __gt__(s, o): return float(s) > float(o) if s._float else int(s) > int(o)
  def __ge__(s, o): return float(s) >= float(o) if s._float else int(s) >= int(o)
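
# Example: TypedView slices share storage with the underlying Reg.
#   r = Reg(0x3c003c00)          # packed f16 pair (hi=1.0, lo=1.0)
#   float(r[15:0].f16) == 1.0    # low half read as f16
#   r[31:16] = 0x4000            # overwrite high half with the f16 bits for 2.0
#   r._val == 0x40003c00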

class Reg:
  """GPU register: D0.f32 = S0.f32 + S1.f32 just works. Supports up to 128 bits for DS_LOAD_B128."""
  __slots__ = ('_val',)
  def __init__(self, val=0): self._val = int(val)

  # Typed views - TypedView(reg, high, signed, is_float, is_bf16)
  u64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
  i64 = property(lambda s: TypedView(s, 63, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
  b64 = property(lambda s: TypedView(s, 63), lambda s, v: setattr(s, '_val', int(v) & MASK64))
  f64 = property(lambda s: TypedView(s, 63, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v))))
  u32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
  i32 = property(lambda s: TypedView(s, 31, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
  b32 = property(lambda s: TypedView(s, 31), lambda s, v: setattr(s, '_val', int(v) & MASK32))
  f32 = property(lambda s: TypedView(s, 31, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v))))
  u24 = property(lambda s: TypedView(s, 23))
  i24 = property(lambda s: TypedView(s, 23, signed=True))
  u16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
  i16 = property(lambda s: TypedView(s, 15, signed=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
  b16 = property(lambda s: TypedView(s, 15), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
  f16 = property(lambda s: TypedView(s, 15, is_float=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _i16(float(v))) & 0xffff)))
  bf16 = property(lambda s: TypedView(s, 15, is_float=True, is_bf16=True), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | ((v if isinstance(v, int) else _ibf16(float(v))) & 0xffff)))
  u8 = property(lambda s: TypedView(s, 7))
  i8 = property(lambda s: TypedView(s, 7, signed=True))
  u3 = property(lambda s: TypedView(s, 2))  # 3-bit for opsel fields
  u1 = property(lambda s: TypedView(s, 0))  # single bit

  def __getitem__(s, key):
    if isinstance(key, slice): return TypedView(s, int(key.start), int(key.stop))
    return (s._val >> int(key)) & 1

  def __setitem__(s, key, value):
    if isinstance(key, slice):
      high, low = int(key.start), int(key.stop)
      if high < low: high, low = low, high
      mask = (1 << (high - low + 1)) - 1
      s._val = (s._val & ~(mask << low)) | ((int(value) & mask) << low)
    elif value: s._val |= (1 << int(key))
    else: s._val &= ~(1 << int(key))

  def __int__(s): return s._val
  def __index__(s): return s._val
  def __bool__(s): return bool(s._val)

  # Arithmetic (for tmp = tmp + 1 patterns). Float operands trigger f32 interpretation.
  def __add__(s, o): return (_f32(s._val) + float(o)) if isinstance(o, float) else s._val + int(o)
  def __radd__(s, o): return (float(o) + _f32(s._val)) if isinstance(o, float) else int(o) + s._val
  def __sub__(s, o): return (_f32(s._val) - float(o)) if isinstance(o, float) else s._val - int(o)
  def __rsub__(s, o): return (float(o) - _f32(s._val)) if isinstance(o, float) else int(o) - s._val
  def __mul__(s, o): return (_f32(s._val) * float(o)) if isinstance(o, float) else s._val * int(o)
  def __rmul__(s, o): return (float(o) * _f32(s._val)) if isinstance(o, float) else int(o) * s._val
  def __and__(s, o): return s._val & int(o)
  def __rand__(s, o): return int(o) & s._val
  def __or__(s, o): return s._val | int(o)
  def __ror__(s, o): return int(o) | s._val
  def __xor__(s, o): return s._val ^ int(o)
  def __rxor__(s, o): return int(o) ^ s._val
  def __lshift__(s, o): n = int(o); return s._val << n if 0 <= n < 64 else 0
  def __rshift__(s, o): n = int(o); return s._val >> n if 0 <= n < 64 else 0
  def __invert__(s): return ~s._val

  # Comparison (for tmp >= 0x100000000 patterns)
  def __lt__(s, o): return s._val < int(o)
  def __le__(s, o): return s._val <= int(o)
  def __gt__(s, o): return s._val > int(o)
  def __ge__(s, o): return s._val >= int(o)
  def __eq__(s, o): return s._val == int(o)
  def __ne__(s, o): return s._val != int(o)
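
# Example: pseudocode assignments run directly against Reg instances.
#   S0, S1, D0 = Reg(0x40400000), Reg(0x40000000), Reg(0)  # f32 3.0, 2.0
#   D0.f32 = S0.f32 + S1.f32
#   D0._val == 0x40a00000  # f32 5.0
#   D0[15:0] == 0x0000     # slice read of the low half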

# ═══════════════════════════════════════════════════════════════════════════════
# PSEUDOCODE API - Functions and constants from AMD ISA pseudocode
# ═══════════════════════════════════════════════════════════════════════════════

# Rounding and float operations
trunc, floor, ceil = _fpop(math.trunc), _fpop(math.floor), _fpop(math.ceil)
def sqrt(x): return _SafeFloat(math.sqrt(x)) if x >= 0 else _SafeFloat(float("nan"))
def log2(x): return math.log2(x) if x > 0 else (float("-inf") if x == 0 else float("nan"))
def fract(x): return x - math.floor(x)
def sin(x): return _trig(math.sin, x)
def cos(x): return _trig(math.cos, x)
def pow(a, b):
  try: return a ** b
  except OverflowError: return float("inf") if b > 0 else 0.0
def isEven(x):
  x = float(x)
  if math.isinf(x) or math.isnan(x): return False
  return int(x) % 2 == 0
def mantissa(f):
  if f == 0.0 or math.isinf(f) or math.isnan(f): return f
  m, _ = math.frexp(f)
  return m  # AMD V_FREXP_MANT returns mantissa in [0.5, 1.0) range
def signext_from_bit(val, bit):
  bit = int(bit)
  if bit == 0: return 0
  mask = (1 << bit) - 1
  val = int(val) & mask
  if val & (1 << (bit - 1)): return val - (1 << bit)
  return val
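
# Example: signext_from_bit treats `bit` as the field width (V_BFE-style).
#   signext_from_bit(0xff, 8) == -1   # bit 7 set -> negative
#   signext_from_bit(0x7f, 8) == 127
#   signext_from_bit(0xff, 0) == 0    # zero-width field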

# Type conversions
i32_to_f32 = u32_to_f32 = i32_to_f64 = u32_to_f64 = f32_to_f64 = f64_to_f32 = float
def f32_to_i32(f): return _f_to_int(f, -2147483648, 2147483647)
def f32_to_u32(f): return _f_to_int(f, 0, 4294967295)
f64_to_i32, f64_to_u32 = f32_to_i32, f32_to_u32
def f32_to_f16(f):
  f = float(f)
  if math.isnan(f): return 0x7e00  # f16 NaN
  if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00  # f16 ±infinity
  try: return struct.unpack("<H", struct.pack("<e", f))[0]
  except OverflowError: return 0x7c00 if f > 0 else 0xfc00  # overflow -> ±infinity
def f16_to_f32(v): return v if isinstance(v, float) else _f16_to_f32_bits(v)
def i16_to_f16(v): return f32_to_f16(float(_sext(int(v) & 0xffff, 16)))
def u16_to_f16(v): return f32_to_f16(float(int(v) & 0xffff))
def f16_to_i16(bits): f = _f16_to_f32_bits(bits); return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def f16_to_u16(bits): f = _f16_to_f32_bits(bits); return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def bf16_to_f32(v): return _bf16(v) if isinstance(v, int) else float(v)
def f32_to_bf16(f): return _ibf16(f)
def u8_to_u32(v): return int(v) & 0xff
def u4_to_u32(v): return int(v) & 0xf
def u32_to_u16(u): return int(u) & 0xffff
def i32_to_i16(i): return ((int(i) + 32768) & 0xffff) - 32768
def f16_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f16_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def f32_to_snorm(f): return max(-32768, min(32767, int(round(max(-1.0, min(1.0, f)) * 32767))))
def f32_to_unorm(f): return max(0, min(65535, int(round(max(0.0, min(1.0, f)) * 65535))))
def v_cvt_i16_f32(f): return max(-32768, min(32767, int(f))) if not math.isnan(f) else 0
def v_cvt_u16_f32(f): return max(0, min(65535, int(f))) if not math.isnan(f) else 0
def SAT8(v): return max(0, min(255, int(v)))
def f32_to_u8(f): return max(0, min(255, int(f))) if not math.isnan(f) else 0
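
# Example: float->int conversions clamp, int->int conversions wrap.
#   f32_to_i32(3e9)          == 2147483647  # clamped to INT32_MAX
#   f32_to_u32(-1.0)         == 0
#   f32_to_i32(float("nan")) == 0
#   i32_to_i16(0x1ffff)      == -1          # wraps mod 2^16, then sign-extends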

# Min/max operations
def v_min_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _lt_neg_zero(a, b) else b)
def v_max_f32(a, b): return a if math.isnan(b) else b if math.isnan(a) else (a if _gt_neg_zero(a, b) else b)
v_min_f16, v_max_f16 = v_min_f32, v_max_f32
v_min_i32, v_max_i32 = min, max
v_min_i16, v_max_i16 = min, max
def v_min_u32(a, b): return min(a & MASK32, b & MASK32)
def v_max_u32(a, b): return max(a & MASK32, b & MASK32)
def v_min_u16(a, b): return min(a & 0xffff, b & 0xffff)
def v_max_u16(a, b): return max(a & 0xffff, b & 0xffff)
def v_min3_f32(a, b, c): return v_min_f32(v_min_f32(a, b), c)
def v_max3_f32(a, b, c): return v_max_f32(v_max_f32(a, b), c)
v_min3_f16, v_max3_f16 = v_min3_f32, v_max3_f32
v_min3_i32, v_max3_i32, v_min3_i16, v_max3_i16 = min, max, min, max
def v_min3_u32(a, b, c): return min(a & MASK32, b & MASK32, c & MASK32)
def v_max3_u32(a, b, c): return max(a & MASK32, b & MASK32, c & MASK32)
def v_min3_u16(a, b, c): return min(a & 0xffff, b & 0xffff, c & 0xffff)
def v_max3_u16(a, b, c): return max(a & 0xffff, b & 0xffff, c & 0xffff)

# SAD/MSAD operations
def ABSDIFF(a, b): return abs(int(a) - int(b))
def v_sad_u8(s0, s1, s2):
  """V_SAD_U8: Sum of absolute differences of 4 byte pairs plus accumulator."""
  s0, s1, s2 = int(s0), int(s1), int(s2)
  result = s2
  for i in range(4):
    a = (s0 >> (i * 8)) & 0xff
    b = (s1 >> (i * 8)) & 0xff
    result += abs(a - b)
  return result & 0xffffffff
def v_msad_u8(s0, s1, s2):
  """V_MSAD_U8: Masked sum of absolute differences (skip if reference byte is 0)."""
  s0, s1, s2 = int(s0), int(s1), int(s2)
  result = s2
  for i in range(4):
    a = (s0 >> (i * 8)) & 0xff
    b = (s1 >> (i * 8)) & 0xff
    if b != 0:  # Only add diff if reference (s1) byte is non-zero
      result += abs(a - b)
  return result & 0xffffffff
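
# Example: byte-wise SAD, and MSAD masking on zero reference bytes.
#   v_sad_u8(0x01020304, 0x04030201, 0)  == 8  # |4-1|+|3-2|+|2-3|+|1-4| per byte lane
#   v_msad_u8(0x01020304, 0x04030200, 0) == 5  # low lane skipped: its s1 byte is 0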

def BYTE_PERMUTE(data, sel):
  """Select a byte from 64-bit data based on selector value."""
  sel = int(sel) & 0xff
  if sel <= 7: return (int(data) >> (sel * 8)) & 0xff
  if sel == 8: return 0xff if ((int(data) >> 15) & 1) else 0x00
  if sel == 9: return 0xff if ((int(data) >> 31) & 1) else 0x00
  if sel == 10: return 0xff if ((int(data) >> 47) & 1) else 0x00
  if sel == 11: return 0xff if ((int(data) >> 63) & 1) else 0x00
  if sel == 12: return 0x00
  return 0xff
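
# Example: the selector ranges of BYTE_PERMUTE.
#   BYTE_PERMUTE(0x1122334455667788, 0) == 0x88  # byte 0 (LSB)
#   BYTE_PERMUTE(0x1122334455667788, 7) == 0x11  # byte 7 (MSB)
#   BYTE_PERMUTE(0x80000000, 9)         == 0xff  # broadcast of sign bit 31
#   BYTE_PERMUTE(0, 12) == 0x00; BYTE_PERMUTE(0, 13) == 0xff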

# Pseudocode functions
def s_ff1_i32_b32(v): return _ctz(v, 32)
def s_ff1_i32_b64(v): return _ctz(v, 64)
GT_NEG_ZERO, LT_NEG_ZERO = _gt_neg_zero, _lt_neg_zero
def isNAN(x):
  try: return math.isnan(float(x))
  except (TypeError, ValueError): return False
def isQuietNAN(x): return _check_nan_type(x, 1, True)
def isSignalNAN(x): return _check_nan_type(x, 0, False)
def fma(a, b, c):
  try: return math.fma(a, b, c)  # math.fma requires Python >= 3.13
  except ValueError: return float('nan')
def ldexp(m, e): return math.ldexp(m, e)
def sign(f): return 1 if math.copysign(1.0, f) < 0 else 0
def exponent(f):
  if hasattr(f, '_bits') and hasattr(f, '_float') and f._float:
    raw = f._val
    if f._bits == 16: return (raw >> 10) & 0x1f
    if f._bits == 32: return (raw >> 23) & 0xff
    if f._bits == 64: return (raw >> 52) & 0x7ff
  f = float(f)
  if math.isinf(f) or math.isnan(f): return 255
  if f == 0.0: return 0
  try: bits = struct.unpack("<I", struct.pack("<f", f))[0]; return (bits >> 23) & 0xff
  except (OverflowError, struct.error): return 0
def signext(x): return int(x)
def cvtToQuietNAN(x): return float('nan')

def F(x):
  """32'F(x) or 64'F(x) - interpret x as float. If x is int, treat as bit pattern."""
  if isinstance(x, int): return _f32(x)
  if isinstance(x, TypedView): return x
  return float(x)

# Constants
PI = math.pi
WAVE32, WAVE64 = True, False
OVERFLOW_F32, UNDERFLOW_F32 = float('inf'), 0.0
OVERFLOW_F64, UNDERFLOW_F64 = float('inf'), 0.0
MAX_FLOAT_F32 = 3.4028235e+38
INF = _Inf()
ROUND_MODE = _RoundMode()
WAVE_MODE = _WaveMode()
DENORM = _Denorm()

# 2/PI with 1201 bits of precision for V_TRIG_PREOP_F64
TWO_OVER_PI_1201 = Reg(0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6)

# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
# ═══════════════════════════════════════════════════════════════════════════════

def _compile_pseudocode(pseudocode: str) -> str:
  """Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
  pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode)  # 'pass' is Python keyword
  raw_lines = pseudocode.strip().split('\n')
  joined_lines: list[str] = []
  for line in raw_lines:
    line = line.strip()
    if joined_lines and (joined_lines[-1].rstrip().endswith(('||', '&&', '(', ',')) or
                         (joined_lines[-1].count('(') > joined_lines[-1].count(')'))):
      joined_lines[-1] = joined_lines[-1].rstrip() + ' ' + line
    else:
      joined_lines.append(line)

  lines = []
  indent, need_pass, in_first_match_loop = 0, False, False
  for line in joined_lines:
    line = line.split('//')[0].strip()  # Strip C-style comments
    if not line: continue
    if line.startswith('if '):
      lines.append(' ' * indent + f"if {_expr(line[3:].rstrip(' then'))}:")
      indent += 1
      need_pass = True
    elif line.startswith('elsif '):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + f"elif {_expr(line[6:].rstrip(' then'))}:")
      indent += 1
      need_pass = True
    elif line == 'else':
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      lines.append(' ' * indent + "else:")
      indent += 1
      need_pass = True
    elif line.startswith('endif'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass = False
    elif line.startswith('endfor'):
      if need_pass: lines.append(' ' * indent + "pass")
      indent -= 1
      need_pass, in_first_match_loop = False, False
    elif line.startswith('declare '):
      pass
    elif m := re.match(r'for (\w+) in (.+?)\s*:\s*(.+?) do', line):
      start, end = _expr(m[2].strip()), _expr(m[3].strip())
      lines.append(' ' * indent + f"for {m[1]} in range({start}, int({end})+1):")
      indent += 1
      need_pass, in_first_match_loop = True, True
    elif '=' in line and not line.startswith('=='):
      need_pass = False
      line = line.rstrip(';')
      if m := re.match(r'\{\s*D1\.[ui]1\s*,\s*D0\.[ui]64\s*\}\s*=\s*(.+)', line):
        rhs = _expr(m[1])
        lines.append(' ' * indent + f"_full = {rhs}")
        lines.append(' ' * indent + "D0.u64 = int(_full) & 0xffffffffffffffff")
        lines.append(' ' * indent + "D1 = Reg((int(_full) >> 64) & 1)")
      elif any(op in line for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^=')):
        for op in ('+=', '-=', '*=', '/=', '|=', '&=', '^='):
          if op in line:
            lhs, rhs = line.split(op, 1)
            lines.append(' ' * indent + f"{lhs.strip()} {op} {_expr(rhs.strip())}")
            break
      else:
        lhs, rhs = line.split('=', 1)
        lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
        stmt = _assign(lhs_s, _expr(rhs_s))
        if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
          stmt += "; break"
        lines.append(' ' * indent + stmt)
  if need_pass: lines.append(' ' * indent + "pass")
  return '\n'.join(lines)
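
# Example: the line-oriented transform on a small if/else block. Pseudocode like
#   if S0.u32 > S1.u32 then
#     D0.u32 = S0.u32
#   else
#     D0.u32 = S1.u32
#   endif
# compiles (with the one-space-per-level indentation used above) to roughly:
#   if S0.u32 > S1.u32:
#    D0.u32 = S0.u32
#   else:
#    D0.u32 = S1.u32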

def _assign(lhs: str, rhs: str) -> str:
  if lhs in ('tmp', 'SCC', 'VCC', 'EXEC', 'D0', 'D1', 'saveexec', 'PC'):
    return f"{lhs} = Reg({rhs})"
  return f"{lhs} = {rhs}"

def _expr(e: str) -> str:
  e = e.strip()
  e = e.replace('&&', ' and ').replace('||', ' or ').replace('<>', ' != ')
  e = re.sub(r'!([^=])', r' not \1', e)
  e = re.sub(r'\{\s*(\w+\.u32)\s*,\s*(\w+\.u32)\s*\}', r'_pack32(\1, \2)', e)
  def pack(m):
    hi, lo = _expr(m[1].strip()), _expr(m[2].strip())
    return f'_pack({hi}, {lo})'
  e = re.sub(r'\{\s*([^,{}]+)\s*,\s*([^,{}]+)\s*\}', pack, e)
  e = re.sub(r"1201'B\(2\.0\s*/\s*PI\)", "TWO_OVER_PI_1201", e)
  e = re.sub(r"\d+'([0-9a-fA-Fx]+)[UuFf]*", r'\1', e)
  e = re.sub(r"\d+'[FIBU]\(", "(", e)
  e = re.sub(r'\bB\(', '(', e)
  e = re.sub(r'([0-9a-fA-Fx])ULL\b', r'\1', e)
  e = re.sub(r'([0-9a-fA-Fx])LL\b', r'\1', e)
  e = re.sub(r'([0-9a-fA-Fx])U\b', r'\1', e)
  e = re.sub(r'(\d\.?\d*)F\b', r'\1', e)
  e = re.sub(r'(\[laneId\])\.[uib]\d+', r'\1', e)
  e = e.replace('+INF', 'INF').replace('-INF', '(-INF)')
  e = re.sub(r'NAN\.f\d+', 'float("nan")', e)
  def convert_verilog_slice(m):
    start, width = m.group(1).strip(), m.group(2).strip()
    return f'[({start}) + ({width}) - 1 : ({start})]'
  e = re.sub(r'\[([^:\[\]]+)\s*\+:\s*([^:\[\]]+)\]', convert_verilog_slice, e)
  def process_brackets(s):
    result, i = [], 0
    while i < len(s):
      if s[i] == '[':
        depth, start = 1, i + 1
        j = start
        while j < len(s) and depth > 0:
          if s[j] == '[': depth += 1
          elif s[j] == ']': depth -= 1
          j += 1
        inner = _expr(s[start:j-1])
        result.append('[' + inner + ']')
        i = j
      else:
        result.append(s[i])
        i += 1
    return ''.join(result)
  e = process_brackets(e)
  while '?' in e:
    depth, bracket, q = 0, 0, -1
    for i, c in enumerate(e):
      if c == '(': depth += 1
      elif c == ')': depth -= 1
      elif c == '[': bracket += 1
      elif c == ']': bracket -= 1
      elif c == '?' and depth == 0 and bracket == 0: q = i; break
    if q < 0: break
    depth, bracket, col = 0, 0, -1
    for i in range(q + 1, len(e)):
      if e[i] == '(': depth += 1
      elif e[i] == ')': depth -= 1
      elif e[i] == '[': bracket += 1
      elif e[i] == ']': bracket -= 1
      elif e[i] == ':' and depth == 0 and bracket == 0: col = i; break
    if col < 0: break
    cond, t, f = e[:q].strip(), e[q+1:col].strip(), e[col+1:].strip()
    e = f'(({t}) if ({cond}) else ({f}))'
  return e
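
# Example: _expr rewrites C/Verilog-isms into Python.
#   _expr("S0.u32 >= S1.u32 ? 1 : 0") == "((1) if (S0.u32 >= S1.u32) else (0))"
#   _expr("{ S1.u32, S0.u32 }")       == "_pack32(S1.u32, S0.u32)"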

def _apply_pseudocode_fixes(op_name: str, code: str) -> str:
  """Apply known fixes for PDF pseudocode bugs."""
  if op_name == 'V_DIV_FMAS_F32':
    code = code.replace('D0.f32 = 2.0 ** 32 * fma(S0.f32, S1.f32, S2.f32)',
                        'D0.f32 = (2.0 ** 64 if exponent(S2.f32) > 127 else 2.0 ** -64) * fma(S0.f32, S1.f32, S2.f32)')
  if op_name == 'V_DIV_FMAS_F64':
    code = code.replace('D0.f64 = 2.0 ** 64 * fma(S0.f64, S1.f64, S2.f64)',
                        'D0.f64 = (2.0 ** 128 if exponent(S2.f64) > 1023 else 2.0 ** -128) * fma(S0.f64, S1.f64, S2.f64)')
  if op_name == 'V_DIV_SCALE_F32':
    code = code.replace('D0.f32 = float("nan")', 'VCC = Reg(0x1); D0.f32 = float("nan")')
    code = code.replace('elif S1.f32 == DENORM.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif False:\n pass')
    code += '\nif S1.f32 == DENORM.f32:\n D0.f32 = float("nan")'
    code = code.replace('elif exponent(S2.f32) <= 23:\n D0.f32 = ldexp(S0.f32, 64)', 'elif exponent(S2.f32) <= 23:\n VCC = Reg(0x1); D0.f32 = ldexp(S0.f32, 64)')
    code = code.replace('elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)\n if S0.f32 == S2.f32:\n D0.f32 = ldexp(S0.f32, 64)', 'elif S2.f32 / S1.f32 == DENORM.f32:\n VCC = Reg(0x1)')
  if op_name == 'V_DIV_SCALE_F64':
    code = code.replace('D0.f64 = float("nan")', 'VCC = Reg(0x1); D0.f64 = float("nan")')
    code = code.replace('elif S1.f64 == DENORM.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif False:\n pass')
    code += '\nif S1.f64 == DENORM.f64:\n D0.f64 = float("nan")'
    code = code.replace('elif exponent(S2.f64) <= 52:\n D0.f64 = ldexp(S0.f64, 128)', 'elif exponent(S2.f64) <= 52:\n VCC = Reg(0x1); D0.f64 = ldexp(S0.f64, 128)')
    code = code.replace('elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)\n if S0.f64 == S2.f64:\n D0.f64 = ldexp(S0.f64, 128)', 'elif S2.f64 / S1.f64 == DENORM.f64:\n VCC = Reg(0x1)')
  if op_name == 'V_DIV_FIXUP_F32':
    code = code.replace('D0.f32 = ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))',
                        'D0.f32 = ((-OVERFLOW_F32) if (sign_out) else (OVERFLOW_F32)) if isNAN(S0.f32) else ((-abs(S0.f32)) if (sign_out) else (abs(S0.f32)))')
  if op_name == 'V_DIV_FIXUP_F64':
    code = code.replace('D0.f64 = ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))',
                        'D0.f64 = ((-OVERFLOW_F64) if (sign_out) else (OVERFLOW_F64)) if isNAN(S0.f64) else ((-abs(S0.f64)) if (sign_out) else (abs(S0.f64)))')
  if op_name == 'V_TRIG_PREOP_F64':
    code = code.replace('result = F((TWO_OVER_PI_1201[1200 : 0] << shift.u32) & 0x1fffffffffffff)',
                        'result = float(((TWO_OVER_PI_1201[1200 : 0] << int(shift)) >> (1201 - 53)) & 0x1fffffffffffff)')
  return code

def _generate_function(cls_name: str, op_name: str, pc: str, code: str) -> str:
  """Generate a single compiled pseudocode function.

  Functions take int parameters and return dict of int values.
  Reg wrapping happens inside the function, only for registers actually used."""
  has_d1 = '{ D1' in pc
  is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
  is_div_scale = 'DIV_SCALE' in op_name
  has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
  is_ds = cls_name == 'DSOp'
  is_flat = cls_name in ('FLATOp', 'GLOBALOp', 'SCRATCHOp')
  is_smem = cls_name == 'SMEMOp'
  has_s_array = 'S[i]' in pc  # FMA_MIX style: S[0], S[1], S[2] array access
  combined = code + pc

  fn_name = f"_{cls_name}_{op_name}"

  # Detect which registers are used/modified
  def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
  modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
  modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
  modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
  modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
  modifies_pc = bool(re.search(r'\bPC\s*=', combined))

  # Build function signature and Reg init lines
  if is_smem:
    lines = [f"def {fn_name}(MEM, addr):"]
    reg_inits = ["ADDR=Reg(addr)", "SDATA=Reg(0)"]
    special_regs = []
  elif is_ds:
    lines = [f"def {fn_name}(MEM, addr, data0, data1, offset0, offset1):"]
    reg_inits = ["ADDR=Reg(addr)", "DATA0=Reg(data0)", "DATA1=Reg(data1)", "OFFSET0=Reg(offset0)", "OFFSET1=Reg(offset1)", "RETURN_DATA=Reg(0)"]
    special_regs = [('DATA', 'DATA0'), ('DATA2', 'DATA1'), ('OFFSET', 'OFFSET0'), ('ADDR_BASE', 'ADDR')]
  elif is_flat:
    lines = [f"def {fn_name}(MEM, addr, vdata, vdst):"]
    reg_inits = ["ADDR=addr", "VDATA=Reg(vdata)", "VDST=Reg(vdst)", "RETURN_DATA=Reg(0)"]
    special_regs = [('DATA', 'VDATA')]
  elif has_s_array:
    # FMA_MIX style: needs S[i] array, opsel, opsel_hi for source selection (neg/neg_hi applied in emu.py before call)
    lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None, opsel=0, opsel_hi=0):"]
    reg_inits = ["S0=Reg(s0)", "S1=Reg(s1)", "S2=Reg(s2)", "S=[S0,S1,S2]", "D0=Reg(d0)", "OPSEL=Reg(opsel)", "OPSEL_HI=Reg(opsel_hi)"]
    special_regs = []
    # Detect array declarations like "declare in : 32'F[3]" and create them (rename 'in' to 'ins' since 'in' is a keyword)
    if "in[" in combined:
      reg_inits.append("ins=[Reg(0),Reg(0),Reg(0)]")
      code = code.replace("in[", "ins[")
  else:
    lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None):"]
    # Only create Regs for registers actually used in the pseudocode
    reg_inits = []
    if 'S0' in combined: reg_inits.append("S0=Reg(s0)")
    if 'S1' in combined: reg_inits.append("S1=Reg(s1)")
    if 'S2' in combined: reg_inits.append("S2=Reg(s2)")
    if modifies_d0 or 'D0' in combined: reg_inits.append("D0=Reg(s0)" if is_div_scale else "D0=Reg(d0)")
    if modifies_scc or 'SCC' in combined: reg_inits.append("SCC=Reg(scc)")
    if modifies_vcc or 'VCC' in combined: reg_inits.append("VCC=Reg(vcc)")
    if modifies_exec or 'EXEC' in combined: reg_inits.append("EXEC=Reg(exec_mask)")
    if modifies_pc or 'PC' in combined: reg_inits.append("PC=Reg(pc) if pc is not None else None")
    special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
                    ('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
    if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
    if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))

  # Build init code
  init_parts = reg_inits.copy()
  for name, init in special_regs:
    if name in combined: init_parts.append(f"{name}={init}")
  if 'EXEC_LO' in code: init_parts.append("EXEC_LO=TypedView(EXEC, 31, 0)")
  if 'EXEC_HI' in code: init_parts.append("EXEC_HI=TypedView(EXEC, 63, 32)")
  if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_parts.append("VCCZ=Reg(1 if VCC._val == 0 else 0)")
  if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_parts.append("EXECZ=Reg(1 if EXEC._val == 0 else 0)")

  # Add init line and separator
  if init_parts: lines.append(f" {'; '.join(init_parts)}")

  # Add compiled pseudocode
  for line in code.split('\n'):
    if line.strip(): lines.append(f" {line}")

  # Build result dict
  result_items = []
  if modifies_d0: result_items.append("'D0': D0._val")
  if modifies_scc: result_items.append("'SCC': SCC._val")
  if modifies_vcc: result_items.append("'VCC': VCC._val")
  if modifies_exec: result_items.append("'EXEC': EXEC._val")
  if has_d1: result_items.append("'D1': D1._val")
  if modifies_pc: result_items.append("'PC': PC._val")
  if is_smem and 'SDATA' in combined and re.search(r'^\s*SDATA[\.\[].*=', code, re.MULTILINE):
    result_items.append("'SDATA': SDATA._val")
  if is_ds and 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
    result_items.append("'RETURN_DATA': RETURN_DATA._val")
  if is_flat:
    if 'RETURN_DATA' in combined and re.search(r'^\s*RETURN_DATA[\.\[].*=', code, re.MULTILINE):
      result_items.append("'RETURN_DATA': RETURN_DATA._val")
    if re.search(r'^\s*VDATA[\.\[].*=', code, re.MULTILINE):
      result_items.append("'VDATA': VDATA._val")
  lines.append(f" return {{{', '.join(result_items)}}}")
  return '\n'.join(lines)

# Build the globals dict for exec() - includes all pcode symbols
_PCODE_GLOBALS = {
  'Reg': Reg, 'TypedView': TypedView, '_pack': _pack, '_pack32': _pack32,
  'ABSDIFF': ABSDIFF, 'BYTE_PERMUTE': BYTE_PERMUTE, 'DENORM': DENORM, 'F': F,
  'GT_NEG_ZERO': GT_NEG_ZERO, 'LT_NEG_ZERO': LT_NEG_ZERO, 'INF': INF,
  'MAX_FLOAT_F32': MAX_FLOAT_F32, 'OVERFLOW_F32': OVERFLOW_F32, 'OVERFLOW_F64': OVERFLOW_F64,
  'UNDERFLOW_F32': UNDERFLOW_F32, 'UNDERFLOW_F64': UNDERFLOW_F64,
  'PI': PI, 'ROUND_MODE': ROUND_MODE, 'WAVE_MODE': WAVE_MODE,
  'WAVE32': WAVE32, 'WAVE64': WAVE64, 'TWO_OVER_PI_1201': TWO_OVER_PI_1201,
  'SAT8': SAT8, 'trunc': trunc, 'floor': floor, 'ceil': ceil, 'sqrt': sqrt,
  'log2': log2, 'fract': fract, 'sin': sin, 'cos': cos, 'pow': pow,
  'isEven': isEven, 'mantissa': mantissa, 'signext_from_bit': signext_from_bit,
  'i32_to_f32': i32_to_f32, 'u32_to_f32': u32_to_f32, 'i32_to_f64': i32_to_f64,
  'u32_to_f64': u32_to_f64, 'f32_to_f64': f32_to_f64, 'f64_to_f32': f64_to_f32,
  'f32_to_i32': f32_to_i32, 'f32_to_u32': f32_to_u32, 'f64_to_i32': f64_to_i32,
  'f64_to_u32': f64_to_u32, 'f32_to_f16': f32_to_f16, 'f16_to_f32': f16_to_f32,
  'i16_to_f16': i16_to_f16, 'u16_to_f16': u16_to_f16, 'f16_to_i16': f16_to_i16,
  'f16_to_u16': f16_to_u16, 'bf16_to_f32': bf16_to_f32, 'f32_to_bf16': f32_to_bf16,
  'u8_to_u32': u8_to_u32, 'u4_to_u32': u4_to_u32, 'u32_to_u16': u32_to_u16,
  'i32_to_i16': i32_to_i16, 'f16_to_snorm': f16_to_snorm, 'f16_to_unorm': f16_to_unorm,
  'f32_to_snorm': f32_to_snorm, 'f32_to_unorm': f32_to_unorm,
  'v_cvt_i16_f32': v_cvt_i16_f32, 'v_cvt_u16_f32': v_cvt_u16_f32, 'f32_to_u8': f32_to_u8,
  'v_min_f32': v_min_f32, 'v_max_f32': v_max_f32, 'v_min_f16': v_min_f16, 'v_max_f16': v_max_f16,
  'v_min_i32': v_min_i32, 'v_max_i32': v_max_i32, 'v_min_i16': v_min_i16, 'v_max_i16': v_max_i16,
  'v_min_u32': v_min_u32, 'v_max_u32': v_max_u32, 'v_min_u16': v_min_u16, 'v_max_u16': v_max_u16,
  'v_min3_f32': v_min3_f32, 'v_max3_f32': v_max3_f32, 'v_min3_f16': v_min3_f16, 'v_max3_f16': v_max3_f16,
  'v_min3_i32': v_min3_i32, 'v_max3_i32': v_max3_i32, 'v_min3_i16': v_min3_i16, 'v_max3_i16': v_max3_i16,
  'v_min3_u32': v_min3_u32, 'v_max3_u32': v_max3_u32, 'v_min3_u16': v_min3_u16, 'v_max3_u16': v_max3_u16,
  'v_sad_u8': v_sad_u8, 'v_msad_u8': v_msad_u8,
  's_ff1_i32_b32': s_ff1_i32_b32, 's_ff1_i32_b64': s_ff1_i32_b64,
  'isNAN': isNAN, 'isQuietNAN': isQuietNAN, 'isSignalNAN': isSignalNAN,
  'fma': fma, 'ldexp': ldexp, 'sign': sign, 'exponent': exponent,
  'signext': signext, 'cvtToQuietNAN': cvtToQuietNAN,
}

@functools.cache
def compile_pseudocode(cls_name: str, op_name: str, pseudocode: str):
  """Compile pseudocode string to executable function. Cached for performance."""
  code = _compile_pseudocode(pseudocode)
  code = _apply_pseudocode_fixes(op_name, code)
  fn_code = _generate_function(cls_name, op_name, pseudocode, code)
  fn_name = f"_{cls_name}_{op_name}"
  local_ns = {}
  exec(fn_code, _PCODE_GLOBALS, local_ns)
  return local_ns[fn_name]
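
# Example (sketch, with a made-up op name and one-line pseudocode string):
#   fn = compile_pseudocode('VOP2Op', 'V_ADD_U32_DEMO', "D0.u32 = S0.u32 + S1.u32")
#   fn(s0=2, s1=3, s2=0, d0=0, scc=0, vcc=0, laneId=0, exec_mask=1, literal=0, VGPR={})
#   -> {'D0': 5}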

@@ -22,45 +22,3 @@ def get_llvm_objdump():
  for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']:
    if shutil.which(p): return p
  raise FileNotFoundError("llvm-objdump not found")

# ═══════════════════════════════════════════════════════════════════════════════
# EXECUTION CONTEXT (for testing compiled pseudocode)
# ═══════════════════════════════════════════════════════════════════════════════

class ExecContext:
  """Context for running compiled pseudocode in tests."""
  def __init__(self, s0=0, s1=0, s2=0, d0=0, scc=0, vcc=0, lane=0, exec_mask=0xffffffff, literal=0, vgprs=None, src0_idx=0, vdst_idx=0):
    from extra.assembly.amd.pcode import Reg, MASK32, MASK64, TypedView
    self._Reg, self._MASK64, self._TypedView = Reg, MASK64, TypedView
    self.S0, self.S1, self.S2 = Reg(s0), Reg(s1), Reg(s2)
    self.D0, self.D1 = Reg(d0), Reg(0)
    self.SCC, self.VCC, self.EXEC = Reg(scc), Reg(vcc), Reg(exec_mask)
    self.tmp, self.saveexec = Reg(0), Reg(exec_mask)
    self.lane, self.laneId, self.literal = lane, lane, literal
    self.SIMM16, self.SIMM32 = Reg(literal), Reg(literal)
    self.VGPR = vgprs if vgprs is not None else {}
    self.SRC0, self.VDST = Reg(src0_idx), Reg(vdst_idx)

  def run(self, code: str):
    """Execute compiled code."""
    import extra.assembly.amd.pcode as pcode
    ns = {k: getattr(pcode, k) for k in dir(pcode) if not k.startswith('_')}
    # Also include underscore-prefixed helpers that compiled pseudocode uses
    for k in ['_pack', '_pack32']:
      if hasattr(pcode, k): ns[k] = getattr(pcode, k)
    ns.update({
      'S0': self.S0, 'S1': self.S1, 'S2': self.S2, 'D0': self.D0, 'D1': self.D1,
      'SCC': self.SCC, 'VCC': self.VCC, 'EXEC': self.EXEC,
      'EXEC_LO': self._TypedView(self.EXEC, 31, 0), 'EXEC_HI': self._TypedView(self.EXEC, 63, 32),
      'tmp': self.tmp, 'saveexec': self.saveexec,
      'lane': self.lane, 'laneId': self.laneId, 'literal': self.literal,
      'SIMM16': self.SIMM16, 'SIMM32': self.SIMM32, 'VGPR': self.VGPR, 'SRC0': self.SRC0, 'VDST': self.VDST,
    })
    exec(code, ns)
    def _sync(ctx_reg, ns_val):
      if isinstance(ns_val, self._Reg): ctx_reg._val = ns_val._val
      else: ctx_reg._val = int(ns_val) & self._MASK64
    for name in ('SCC', 'VCC', 'EXEC', 'D0', 'D1', 'tmp', 'saveexec'):
      if ns.get(name) is not getattr(self, name): _sync(getattr(self, name), ns[name])

  def result(self) -> dict: return {"d0": self.D0._val, "scc": self.SCC._val & 1}
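
# Example: driving compiled (or hand-written) Python through ExecContext.
#   ctx = ExecContext(s0=7, s1=5)
#   ctx.run("D0.u32 = S0.u32 - S1.u32")
#   ctx.result() == {"d0": 2, "scc": 0}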

@@ -5,9 +5,8 @@ Set USE_HW=1 to run on both emulator and real hardware, comparing results.
"""
import ctypes, os, struct
from extra.assembly.amd.autogen.rdna3.ins import *
-from extra.assembly.amd.dsl import RawImm
+from extra.assembly.amd.dsl import RawImm, _i32, _f32
from extra.assembly.amd.emu import WaveState, run_asm, set_valid_mem_ranges
-from extra.assembly.amd.pcode import _i32, _f32

VCC = SrcEnum.VCC_LO  # For VOP3SD sdst field
USE_HW = os.environ.get("USE_HW", "0") == "1"
@@ -255,7 +255,7 @@ class TestF16Conversions(unittest.TestCase):
  def test_v_cvt_f16_f32_small(self):
    """V_CVT_F16_F32 converts small f32 value."""
-    from extra.assembly.amd.pcode import f32_to_f16
+    from extra.assembly.amd.dsl import f32_to_f16
    instructions = [
      v_mov_b32_e32(v[0], 0.5),
      v_cvt_f16_f32_e32(v[1], v[0]),

@@ -293,7 +293,7 @@ class TestF16Conversions(unittest.TestCase):
  def test_v_cvt_f16_f32_reads_full_32bit_source(self):
    """V_CVT_F16_F32 must read full 32-bit f32 source."""
-    from extra.assembly.amd.pcode import _f16
+    from extra.assembly.amd.dsl import _f16
    instructions = [
      s_mov_b32(s[0], 0x3fc00000),  # f32 1.5
      v_mov_b32_e32(v[0], s[0]),

@@ -560,7 +560,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
  def test_v_cvt_f32_f16_abs_negative(self):
    """V_CVT_F32_F16 with |abs| on negative value."""
-    from extra.assembly.amd.pcode import f32_to_f16
+    from extra.assembly.amd.dsl import f32_to_f16
    f16_neg1 = f32_to_f16(-1.0)  # 0xbc00
    instructions = [
      s_mov_b32(s[0], f16_neg1),

@@ -573,7 +573,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
  def test_v_cvt_f32_f16_abs_positive(self):
    """V_CVT_F32_F16 with |abs| on positive value (should stay positive)."""
-    from extra.assembly.amd.pcode import f32_to_f16
+    from extra.assembly.amd.dsl import f32_to_f16
    f16_2 = f32_to_f16(2.0)  # 0x4000
    instructions = [
      s_mov_b32(s[0], f16_2),

@@ -586,7 +586,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
  def test_v_cvt_f32_f16_neg_positive(self):
    """V_CVT_F32_F16 with neg on positive value."""
-    from extra.assembly.amd.pcode import f32_to_f16
+    from extra.assembly.amd.dsl import f32_to_f16
    f16_2 = f32_to_f16(2.0)  # 0x4000
    instructions = [
      s_mov_b32(s[0], f16_2),

@@ -599,7 +599,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
  def test_v_cvt_f32_f16_neg_negative(self):
    """V_CVT_F32_F16 with neg on negative value (double negative)."""
-    from extra.assembly.amd.pcode import f32_to_f16
+    from extra.assembly.amd.dsl import f32_to_f16
    f16_neg2 = f32_to_f16(-2.0)  # 0xc000
    instructions = [
      s_mov_b32(s[0], f16_neg2),

@@ -612,7 +612,7 @@ class TestCvtF16Modifiers(unittest.TestCase):
  def test_v_cvt_f16_f32_then_pack_for_wmma(self):
    """CVT F32->F16 followed by pack (common WMMA pattern)."""
-    from extra.assembly.amd.pcode import _f16
+    from extra.assembly.amd.dsl import _f16
    f32_val = 3.5
    instructions = [
      s_mov_b32(s[0], f2i(f32_val)),

@@ -668,7 +668,7 @@ class TestConversionRounding(unittest.TestCase):
  def test_f16_to_f32_precision(self):
    """F16 to F32 conversion precision."""
-    from extra.assembly.amd.pcode import f32_to_f16
+    from extra.assembly.amd.dsl import f32_to_f16
    f16_val = f32_to_f16(1.5)
    instructions = [
      s_mov_b32(s[0], f16_val),

@@ -680,7 +680,7 @@ class TestConversionRounding(unittest.TestCase):
  def test_f16_denormal_to_f32(self):
    """F16 denormal converts to small positive f32."""
-    from extra.assembly.amd.pcode import _f16
+    from extra.assembly.amd.dsl import _f16
    f16_denorm = 0x0001  # Smallest positive f16 denormal
    instructions = [
      v_mov_b32_e32(v[0], f16_denorm),
@@ -768,7 +768,7 @@ class TestF16Modifiers(unittest.TestCase):
  def test_v_fma_f16_inline_const_1_0(self):
    """V_FMA_F16: a*b + 1.0 should use f16 inline constant."""
-    from extra.assembly.amd.pcode import f32_to_f16, _f16
+    from extra.assembly.amd.dsl import f32_to_f16, _f16
    f16_a = f32_to_f16(0.325928)  # ~0x3537
    f16_b = f32_to_f16(-0.486572)  # ~0xb7c9
    instructions = [

@@ -785,7 +785,7 @@ class TestF16Modifiers(unittest.TestCase):
  def test_v_fma_f16_inline_const_0_5(self):
    """V_FMA_F16: a*b + 0.5 should use f16 inline constant."""
-    from extra.assembly.amd.pcode import f32_to_f16, _f16
+    from extra.assembly.amd.dsl import f32_to_f16, _f16
    f16_a = f32_to_f16(2.0)
    f16_b = f32_to_f16(3.0)
    instructions = [

@@ -802,7 +802,7 @@ class TestF16Modifiers(unittest.TestCase):
  def test_v_fma_f16_inline_const_neg_1_0(self):
    """V_FMA_F16: a*b + (-1.0) should use f16 inline constant."""
-    from extra.assembly.amd.pcode import f32_to_f16, _f16
+    from extra.assembly.amd.dsl import f32_to_f16, _f16
    f16_a = f32_to_f16(2.0)
    f16_b = f32_to_f16(3.0)
    instructions = [

@@ -819,7 +819,7 @@ class TestF16Modifiers(unittest.TestCase):
  def test_v_add_f16_abs_both(self):
    """V_ADD_F16 with abs on both operands."""
-    from extra.assembly.amd.pcode import f32_to_f16, _f16
+    from extra.assembly.amd.dsl import f32_to_f16, _f16
    f16_neg2 = f32_to_f16(-2.0)
    f16_neg3 = f32_to_f16(-3.0)
    instructions = [

@@ -835,7 +835,7 @@ class TestF16Modifiers(unittest.TestCase):
  def test_v_mul_f16_neg_abs(self):
    """V_MUL_F16 with neg on one operand and abs on another."""
-    from extra.assembly.amd.pcode import f32_to_f16, _f16
+    from extra.assembly.amd.dsl import f32_to_f16, _f16
    f16_2 = f32_to_f16(2.0)
    f16_neg3 = f32_to_f16(-3.0)
    instructions = [

@@ -854,7 +854,7 @@ class TestF16Modifiers(unittest.TestCase):
    This tests the case from AMD_LLVM sin(0) where V_FMAC_F16 writes to v0.h.
    """
-    from extra.assembly.amd.pcode import _f16
+    from extra.assembly.amd.dsl import _f16
    instructions = [
      s_mov_b32(s[0], 0x38003c00),  # v0 = {hi=0.5, lo=1.0}
      v_mov_b32_e32(v[0], s[0]),
@@ -155,7 +155,7 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_src2_f16_lo(self):
|
||||
"""V_FMA_MIX_F32 with src2 as f16 from lo bits."""
|
||||
from extra.assembly.amd.pcode import f32_to_f16
|
||||
from extra.assembly.amd.dsl import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(1.0)),
|
||||
@@ -172,7 +172,7 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mix_f32_src2_f16_hi(self):
|
||||
"""V_FMA_MIX_F32 with src2 as f16 from hi bits."""
|
||||
from extra.assembly.amd.pcode import f32_to_f16
|
||||
from extra.assembly.amd.dsl import f32_to_f16
|
||||
f16_2 = f32_to_f16(2.0)
|
||||
val = (f16_2 << 16) | 0
|
||||
instructions = [
|
||||
@@ -205,7 +205,7 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mixlo_f16(self):
|
||||
"""V_FMA_MIXLO_F16 writes to low 16 bits of destination."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(2.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -225,7 +225,7 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mixlo_f16_all_f32_sources(self):
|
||||
"""V_FMA_MIXLO_F16 with all f32 sources."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], f2i(1.0)),
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -243,7 +243,7 @@ class TestFmaMix(unittest.TestCase):
|
||||
|
||||
def test_v_fma_mixlo_f16_sin_case(self):
|
||||
"""V_FMA_MIXLO_F16 case from sin kernel."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3f800000), # f32 1.0
|
||||
v_mov_b32_e32(v[3], s[0]),
|
||||
@@ -265,7 +265,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
|
||||
def test_v_pk_add_f16_basic(self):
|
||||
"""V_PK_ADD_F16 adds two packed f16 values."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x40003c00), # hi=2.0, lo=1.0
|
||||
s_mov_b32(s[1], 0x44004200), # hi=4.0, lo=3.0
|
||||
@@ -282,7 +282,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
|
||||
def test_v_pk_mul_f16_basic(self):
|
||||
"""V_PK_MUL_F16 multiplies two packed f16 values."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x42004000), # hi=3.0, lo=2.0
|
||||
s_mov_b32(s[1], 0x45004400), # hi=5.0, lo=4.0
|
||||
@@ -299,7 +299,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
|
||||
def test_v_pk_fma_f16_basic(self):
|
||||
"""V_PK_FMA_F16: D = A * B + C for packed f16."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x42004000), # A: hi=3.0, lo=2.0
|
||||
s_mov_b32(s[1], 0x45004400), # B: hi=5.0, lo=4.0
|
||||
@@ -321,7 +321,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
Inline constants for VOP3P are f16 values in the low 16 bits only.
|
||||
hi half of inline constant is 0, so hi result = v0.hi + 0 = 1.0.
|
||||
"""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x3c003c00), # packed f16: hi=1.0, lo=1.0
|
||||
v_mov_b32_e32(v[0], s[0]),
|
||||
@@ -339,7 +339,7 @@ class TestVOP3P(unittest.TestCase):
|
||||
"""V_PK_MUL_F16 with inline constant POS_TWO (2.0).
|
||||
Inline constant has value only in low 16 bits, hi is 0.
|
||||
"""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
# v0 = packed (3.0, 4.0), multiply by POS_TWO
|
||||
# lo = 3.0 * 2.0 = 6.0, hi = 4.0 * 0.0 = 0.0 (inline const hi is 0)
|
||||
instructions = [
|
||||
@@ -504,7 +504,7 @@ class TestPackedMixedSigns(unittest.TestCase):
|
||||
|
||||
def test_pk_add_f16_mixed_signs(self):
|
||||
"""V_PK_ADD_F16 with mixed positive/negative values."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0xc0003c00), # packed: hi=-2.0, lo=1.0
|
||||
s_mov_b32(s[1], 0x3c003c00), # packed: hi=1.0, lo=1.0
|
||||
@@ -521,7 +521,7 @@ class TestPackedMixedSigns(unittest.TestCase):
|
||||
|
||||
def test_pk_mul_f16_zero(self):
|
||||
"""V_PK_MUL_F16 with zero."""
|
||||
from extra.assembly.amd.pcode import _f16
|
||||
from extra.assembly.amd.dsl import _f16
|
||||
instructions = [
|
||||
s_mov_b32(s[0], 0x40004000), # packed: 2.0, 2.0
|
||||
s_mov_b32(s[1], 0x00000000), # packed: 0.0, 0.0
|
||||
|
||||
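
# --- Editor's note (illustrative, not part of the commit): the packed-f16
# constants in the tests above encode two IEEE half floats in one 32-bit word,
# lo in bits 15:0 and hi in bits 31:16. A minimal sketch using only the stdlib:
import struct

def pack_f16x2(lo: float, hi: float) -> int:
  # '<e' is little-endian binary16; unpack the raw bits as uint16
  (lo_bits,) = struct.unpack('<H', struct.pack('<e', lo))
  (hi_bits,) = struct.unpack('<H', struct.pack('<e', hi))
  return (hi_bits << 16) | lo_bits

assert pack_f16x2(1.0, 2.0) == 0x40003c00  # matches the s_mov_b32 constants above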
@@ -1,403 +0,0 @@
#!/usr/bin/env python3
"""Tests for the RDNA3 pseudocode DSL."""
import unittest
from extra.assembly.amd.pcode import (Reg, TypedView, MASK32, MASK64,
  _f32, _i32, _f16, _i16, f32_to_f16, isNAN, _bf16, _ibf16, bf16_to_f32, f32_to_bf16,
  BYTE_PERMUTE, v_sad_u8, v_msad_u8, _compile_pseudocode, _expr, compile_pseudocode)
from extra.assembly.amd.test.helpers import ExecContext
from extra.assembly.amd.autogen.rdna3.str_pcode import VOP3SDOp_PCODE, VOPCOp_PCODE
from extra.assembly.amd.autogen.rdna3.enum import VOP3SDOp, VOPCOp

# Compile pseudocode functions on demand for regression tests
_VOP3SDOp_V_DIV_SCALE_F32 = compile_pseudocode('VOP3SDOp', 'V_DIV_SCALE_F32', VOP3SDOp_PCODE[VOP3SDOp.V_DIV_SCALE_F32])
_VOPCOp_V_CMP_CLASS_F32 = compile_pseudocode('VOPCOp', 'V_CMP_CLASS_F32', VOPCOp_PCODE[VOPCOp.V_CMP_CLASS_F32])
class TestReg(unittest.TestCase):
  def test_u32_read(self):
    r = Reg(0xDEADBEEF)
    self.assertEqual(int(r.u32), 0xDEADBEEF)

  def test_u32_write(self):
    r = Reg(0)
    r.u32 = 0x12345678
    self.assertEqual(r._val, 0x12345678)

  def test_f32_read(self):
    r = Reg(0x40400000)  # 3.0f
    self.assertAlmostEqual(float(r.f32), 3.0)

  def test_f32_write(self):
    r = Reg(0)
    r.f32 = 3.0
    self.assertEqual(r._val, 0x40400000)

  def test_i32_signed(self):
    r = Reg(0xFFFFFFFF)  # -1 as signed
    self.assertEqual(int(r.i32), -1)

  def test_u64(self):
    r = Reg(0xDEADBEEFCAFEBABE)
    self.assertEqual(int(r.u64), 0xDEADBEEFCAFEBABE)

  def test_f64(self):
    r = Reg(0x4008000000000000)  # 3.0 as f64
    self.assertAlmostEqual(float(r.f64), 3.0)

class TestTypedView(unittest.TestCase):
  def test_bit_slice(self):
    r = Reg(0xDEADBEEF)
    # Slices return TypedView which supports .u32, .u16 etc (matching pseudocode like S1.u32[1:0].u32)
    self.assertEqual(r.u32[7:0].u32, 0xEF)
    self.assertEqual(r.u32[15:8].u32, 0xBE)
    self.assertEqual(r.u32[23:16].u32, 0xAD)
    self.assertEqual(r.u32[31:24].u32, 0xDE)
    # Also works with int() for arithmetic
    self.assertEqual(int(r.u32[7:0]), 0xEF)

  def test_single_bit_read(self):
    r = Reg(0b11010101)
    self.assertEqual(r.u32[0], 1)
    self.assertEqual(r.u32[1], 0)
    self.assertEqual(r.u32[2], 1)
    self.assertEqual(r.u32[3], 0)

  def test_single_bit_write(self):
    r = Reg(0)
    r.u32[5] = 1
    r.u32[3] = 1
    self.assertEqual(r._val, 0b00101000)

  def test_nested_bit_access(self):
    # S0.u32[S1.u32[4:0]] - access bit at position from another register
    s0 = Reg(0b11010101)
    s1 = Reg(3)
    bit_pos = s1.u32[4:0]  # TypedView, int value = 3
    bit_val = s0.u32[int(bit_pos)]  # bit 3 of s0 = 0
    self.assertEqual(int(bit_pos), 3)
    self.assertEqual(bit_val, 0)

  def test_arithmetic(self):
    r1 = Reg(0x40400000)  # 3.0f
    r2 = Reg(0x40800000)  # 4.0f
    result = r1.f32 + r2.f32
    self.assertAlmostEqual(result, 7.0)

  def test_comparison(self):
    r1 = Reg(5)
    r2 = Reg(3)
    self.assertTrue(r1.u32 > r2.u32)
    self.assertFalse(r1.u32 < r2.u32)
    self.assertTrue(r1.u32 != r2.u32)
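
# --- Editor's note (hedged sketch, not the pcode.py source): the u32[hi:lo]
# views exercised above reduce to plain mask-and-shift bit-field math:
def read_bits(val: int, hi: int, lo: int) -> int:
  return (val >> lo) & ((1 << (hi - lo + 1)) - 1)

def write_bits(val: int, hi: int, lo: int, field: int) -> int:
  mask = ((1 << (hi - lo + 1)) - 1) << lo
  return (val & ~mask) | ((field << lo) & mask)

assert read_bits(0xDEADBEEF, 15, 8) == 0xBE
assert write_bits(write_bits(0, 15, 0, 0x1234), 31, 16, 0x5678) == 0x56781234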
class TestTypedViewSlice(unittest.TestCase):
  def test_slice_read(self):
    r = Reg(0x56781234)
    self.assertEqual(r[15:0].u16, 0x1234)
    self.assertEqual(r[31:16].u16, 0x5678)

  def test_slice_write(self):
    r = Reg(0)
    r[15:0].u16 = 0x1234
    r[31:16].u16 = 0x5678
    self.assertEqual(r._val, 0x56781234)

  def test_slice_f16(self):
    r = Reg(0)
    r[15:0].f16 = 3.0
    self.assertAlmostEqual(_f16(r._val & 0xffff), 3.0, places=2)
class TestCompiler(unittest.TestCase):
  def test_ternary(self):
    result = _expr("a > b ? 1 : 0")
    self.assertIn("if", result)
    self.assertIn("else", result)

  def test_type_prefix_strip(self):
    self.assertEqual(_expr("1'0U"), "0")
    self.assertEqual(_expr("32'1"), "1")
    self.assertEqual(_expr("16'0xFFFF"), "0xFFFF")

  def test_suffix_strip(self):
    self.assertEqual(_expr("0ULL"), "0")
    self.assertEqual(_expr("1LL"), "1")
    self.assertEqual(_expr("5U"), "5")
    self.assertEqual(_expr("3.14F"), "3.14")

  def test_boolean_ops(self):
    self.assertIn("and", _expr("a && b"))
    self.assertIn("or", _expr("a || b"))
    self.assertIn("!=", _expr("a <> b"))

  def test_pack16(self):
    result = _expr("{ a, b }")
    self.assertIn("_pack", result)

  def test_type_cast_strip(self):
    self.assertEqual(_expr("64'U(x)"), "(x)")
    self.assertEqual(_expr("32'I(y)"), "(y)")
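
# --- Editor's note (toy sketch under stated assumptions): the ternary rewrite
# tested above turns C-style `cond ? a : b` into a Python conditional expression.
# The real _expr handled nesting and precedence; this illustrative version does not.
import re

def ternary_to_python(expr: str) -> str:
  m = re.fullmatch(r"(.+?)\?(.+?):(.+)", expr)
  if m is None: return expr
  cond, a, b = (g.strip() for g in m.groups())
  return f"({a}) if ({cond}) else ({b})"

assert ternary_to_python("a > b ? 1 : 0") == "(1) if (a > b) else (0)"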
class TestExecContext(unittest.TestCase):
  def test_float_add(self):
    ctx = ExecContext(s0=0x40400000, s1=0x40800000)  # 3.0f, 4.0f
    ctx.D0.f32 = ctx.S0.f32 + ctx.S1.f32
    self.assertAlmostEqual(_f32(ctx.D0._val), 7.0)

  def test_float_mul(self):
    ctx = ExecContext(s0=0x40400000, s1=0x40800000)  # 3.0f, 4.0f
    ctx.run("D0.f32 = S0.f32 * S1.f32")
    self.assertAlmostEqual(_f32(ctx.D0._val), 12.0)

  def test_scc_comparison(self):
    ctx = ExecContext(s0=42, s1=42)
    ctx.run("SCC = S0.u32 == S1.u32")
    self.assertEqual(ctx.SCC._val, 1)

  def test_scc_comparison_false(self):
    ctx = ExecContext(s0=42, s1=43)
    ctx.run("SCC = S0.u32 == S1.u32")
    self.assertEqual(ctx.SCC._val, 0)

  def test_ternary(self):
    code = _compile_pseudocode("D0.u32 = S0.u32 > S1.u32 ? 1'1U : 1'0U")
    ctx = ExecContext(s0=5, s1=3)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 1)

  def test_pack(self):
    code = _compile_pseudocode("D0 = { S1[15:0].u16, S0[15:0].u16 }")
    ctx = ExecContext(s0=0x1234, s1=0x5678)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 0x56781234)

  def test_tmp_with_typed_access(self):
    code = _compile_pseudocode("""tmp = S0.u32 + S1.u32
D0.u32 = tmp.u32""")
    ctx = ExecContext(s0=100, s1=200)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 300)

  def test_s_add_u32_pattern(self):
    # Real pseudocode pattern from S_ADD_U32
    code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
    # Test overflow case
    ctx = ExecContext(s0=0xFFFFFFFF, s1=0x00000001)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 0)  # Wraps to 0
    self.assertEqual(ctx.SCC._val, 1)  # Carry set

  def test_s_add_u32_no_overflow(self):
    code = _compile_pseudocode("""tmp = 64'U(S0.u32) + 64'U(S1.u32)
SCC = tmp >= 0x100000000ULL ? 1'1U : 1'0U
D0.u32 = tmp.u32""")
    ctx = ExecContext(s0=100, s1=200)
    ctx.run(code)
    self.assertEqual(ctx.D0._val, 300)
    self.assertEqual(ctx.SCC._val, 0)  # No carry

  def test_vcc_lane_read(self):
    ctx = ExecContext(vcc=0b1010, lane=1)
    # Lane 1 is set
    self.assertEqual(ctx.VCC.u64[1], 1)
    self.assertEqual(ctx.VCC.u64[2], 0)

  def test_vcc_lane_write(self):
    ctx = ExecContext(vcc=0, lane=0)
    ctx.VCC.u64[3] = 1
    ctx.VCC.u64[1] = 1
    self.assertEqual(ctx.VCC._val, 0b1010)

  def test_for_loop(self):
    # CTZ-style pattern - with a single set bit, the loop finds its index
    code = _compile_pseudocode("""tmp = -1
for i in 0 : 31 do
if S0.u32[i] == 1 then
tmp = i
endif
endfor
D0.i32 = tmp""")
    ctx = ExecContext(s0=0b1000)  # Bit 3 is set
    ctx.run(code)
    self.assertEqual(ctx.D0._val & MASK32, 3)

  def test_result_dict(self):
    ctx = ExecContext(s0=5, s1=3)
    ctx.D0.u32 = 42
    ctx.SCC._val = 1
    result = ctx.result()
    self.assertEqual(result['d0'], 42)
    self.assertEqual(result['scc'], 1)
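
# --- Editor's note (worked example of the S_ADD_U32 pattern tested above):
# widening both operands to 64 bits makes the carry a simple >= 2**32 test.
def s_add_u32(a: int, b: int) -> tuple[int, int]:
  tmp = (a & 0xffffffff) + (b & 0xffffffff)
  return tmp & 0xffffffff, int(tmp >= 0x100000000)  # (D0, SCC)

assert s_add_u32(0xFFFFFFFF, 1) == (0, 1)  # wraps, carry set
assert s_add_u32(100, 200) == (300, 0)     # no carry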
class TestPseudocodeRegressions(unittest.TestCase):
  """Regression tests for pseudocode instruction emulation bugs."""

  def test_v_div_scale_f32_vcc_always_returned(self):
    """V_DIV_SCALE_F32 must always return VCC, even when VCC=0 (no scaling needed).
    Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
    This caused division to produce wrong results for multiple lanes."""
    # Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
    s0 = 0x3f800000  # 1.0
    s1 = 0x40400000  # 3.0
    s2 = 0x3f800000  # 1.0 (numerator)
    result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None)
    # Must always have VCC in result
    self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
    self.assertEqual(result['VCC'] & 1, 0, "VCC lane 0 should be 0 when no scaling needed")

  def test_v_cmp_class_f32_detects_quiet_nan(self):
    """V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
    Bug: isQuietNAN and isSignalNAN both used math.isnan, which can't distinguish them."""
    quiet_nan = 0x7fc00000  # quiet NaN: exponent=255, bit22=1
    signal_nan = 0x7f800001  # signaling NaN: exponent=255, bit22=0
    # Test quiet NaN detection (bit 1 in mask)
    s1_quiet = 0b0000000010  # bit 1 = quiet NaN
    result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
    self.assertEqual(result['D0'] & 1, 1, "Should detect quiet NaN with quiet NaN mask")
    # Test signaling NaN detection (bit 0 in mask)
    s1_signal = 0b0000000001  # bit 0 = signaling NaN
    result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
    self.assertEqual(result['D0'] & 1, 1, "Should detect signaling NaN with signaling NaN mask")
    # Test that quiet NaN doesn't match signaling NaN mask
    result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
    self.assertEqual(result['D0'] & 1, 0, "Quiet NaN should not match signaling NaN mask")
    # Test that signaling NaN doesn't match quiet NaN mask
    result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None)
    self.assertEqual(result['D0'] & 1, 0, "Signaling NaN should not match quiet NaN mask")

  def test_isNAN_with_typed_view(self):
    """isNAN must work with TypedView objects, not just Python floats.
    Bug: isNAN checked isinstance(x, float), which returned False for TypedView."""
    nan_reg = Reg(0x7fc00000)  # quiet NaN
    normal_reg = Reg(0x3f800000)  # 1.0
    inf_reg = Reg(0x7f800000)  # +inf
    self.assertTrue(isNAN(nan_reg.f32), "isNAN should return True for NaN TypedView")
    self.assertFalse(isNAN(normal_reg.f32), "isNAN should return False for normal TypedView")
    self.assertFalse(isNAN(inf_reg.f32), "isNAN should return False for inf TypedView")
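
# --- Editor's note (illustrative, f32 layout only): the quiet/signaling split
# tested above hinges on mantissa bit 22 when the exponent field is all ones.
def is_quiet_nan_f32(bits: int) -> bool:
  exp, mant = (bits >> 23) & 0xff, bits & 0x7fffff
  return exp == 0xff and mant != 0 and ((bits >> 22) & 1) == 1

assert is_quiet_nan_f32(0x7fc00000)      # quiet NaN
assert not is_quiet_nan_f32(0x7f800001)  # signaling NaN (bit 22 clear)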
class TestBF16(unittest.TestCase):
  """Tests for BF16 (bfloat16) support."""

  def test_bf16_conversion(self):
    """Test bf16 <-> f32 conversion."""
    # bf16 is just the top 16 bits of f32
    # 1.0f = 0x3f800000, bf16 = 0x3f80
    self.assertAlmostEqual(_bf16(0x3f80), 1.0, places=2)
    self.assertEqual(_ibf16(1.0), 0x3f80)
    # 2.0f = 0x40000000, bf16 = 0x4000
    self.assertAlmostEqual(_bf16(0x4000), 2.0, places=2)
    self.assertEqual(_ibf16(2.0), 0x4000)
    # -1.0f = 0xbf800000, bf16 = 0xbf80
    self.assertAlmostEqual(_bf16(0xbf80), -1.0, places=2)
    self.assertEqual(_ibf16(-1.0), 0xbf80)

  def test_bf16_special_values(self):
    """Test bf16 special values (inf, nan)."""
    import math
    # +inf: f32 = 0x7f800000, bf16 = 0x7f80
    self.assertTrue(math.isinf(_bf16(0x7f80)))
    self.assertEqual(_ibf16(float('inf')), 0x7f80)
    # -inf: f32 = 0xff800000, bf16 = 0xff80
    self.assertTrue(math.isinf(_bf16(0xff80)))
    self.assertEqual(_ibf16(float('-inf')), 0xff80)
    # NaN: quiet NaN bf16 = 0x7fc0
    self.assertTrue(math.isnan(_bf16(0x7fc0)))
    self.assertEqual(_ibf16(float('nan')), 0x7fc0)

  def test_bf16_register_property(self):
    """Test Reg.bf16 property."""
    r = Reg(0)
    r.bf16 = 3.0  # 3.0f = 0x40400000, bf16 = 0x4040
    self.assertEqual(r._val & 0xffff, 0x4040)
    self.assertAlmostEqual(float(r.bf16), 3.0, places=1)

  def test_bf16_slice_property(self):
    """Test TypedView.bf16 property."""
    r = Reg(0x40404040)  # Two bf16 3.0 values
    self.assertAlmostEqual(r[15:0].bf16, 3.0, places=1)
    self.assertAlmostEqual(r[31:16].bf16, 3.0, places=1)
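
# --- Editor's note (minimal sketch of the bf16 convention the tests rely on):
# bf16 keeps only the top 16 bits of the f32 encoding, so a truncating
# conversion is a shift (the real _ibf16 may round rather than truncate).
import struct

def f32_to_bf16_trunc(f: float) -> int:
  (bits,) = struct.unpack('<I', struct.pack('<f', f))
  return bits >> 16

assert f32_to_bf16_trunc(1.0) == 0x3f80 and f32_to_bf16_trunc(-1.0) == 0xbf80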
class TestBytePermute(unittest.TestCase):
  """Tests for BYTE_PERMUTE helper function (V_PERM_B32)."""

  def test_byte_select_0_to_7(self):
    """Test selecting bytes 0-7 from 64-bit data."""
    # data = {s0, s1} where s0 is bytes 0-3, s1 is bytes 4-7
    # Combined: 0x0706050403020100 (byte 0 = 0x00, byte 7 = 0x07)
    data = 0x0706050403020100
    for i in range(8):
      self.assertEqual(BYTE_PERMUTE(data, i), i, f"byte {i} should be {i}")

  def test_sign_extend_bytes(self):
    """Test sign extension selectors 8-11."""
    # sel 8: sign of byte 1 (bits 15:8)
    # sel 9: sign of byte 3 (bits 31:24)
    # sel 10: sign of byte 5 (bits 47:40)
    # sel 11: sign of byte 7 (bits 63:56)
    data = 0x8000800080008000  # All relevant bytes have sign bit set
    self.assertEqual(BYTE_PERMUTE(data, 8), 0xff)
    self.assertEqual(BYTE_PERMUTE(data, 9), 0xff)
    self.assertEqual(BYTE_PERMUTE(data, 10), 0xff)
    self.assertEqual(BYTE_PERMUTE(data, 11), 0xff)
    data = 0x7f007f007f007f00  # No sign bits set
    self.assertEqual(BYTE_PERMUTE(data, 8), 0x00)
    self.assertEqual(BYTE_PERMUTE(data, 9), 0x00)
    self.assertEqual(BYTE_PERMUTE(data, 10), 0x00)
    self.assertEqual(BYTE_PERMUTE(data, 11), 0x00)

  def test_constant_zero(self):
    """Test selector 12 returns 0x00."""
    self.assertEqual(BYTE_PERMUTE(0xffffffffffffffff, 12), 0x00)

  def test_constant_ff(self):
    """Test selectors >= 13 return 0xFF."""
    for sel in [13, 14, 15, 255]:
      self.assertEqual(BYTE_PERMUTE(0, sel), 0xff, f"sel {sel} should be 0xff")
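
# --- Editor's note (hedged sketch of the selector semantics as the tests
# describe them; not the pcode.py source): sel 0-7 picks a byte, 8-11
# broadcasts a sign bit, 12 yields 0x00, and 13+ yield 0xff.
def byte_permute(data: int, sel: int) -> int:
  if sel <= 7: return (data >> (8 * sel)) & 0xff
  if sel <= 11: return 0xff if (data >> (16 * (sel - 8) + 15)) & 1 else 0x00
  return 0x00 if sel == 12 else 0xff

assert byte_permute(0x0706050403020100, 3) == 3
assert byte_permute(0x8000800080008000, 9) == 0xff  # sign of byte 3 (bit 31)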
class TestSADHelpers(unittest.TestCase):
  """Tests for V_SAD_U8 and V_MSAD_U8 helper functions."""

  def test_v_sad_u8_basic(self):
    """Test v_sad_u8 with simple values."""
    # s0 = 0x04030201, s1 = 0x04030201 -> diff = 0 for all bytes
    result = v_sad_u8(0x04030201, 0x04030201, 0)
    self.assertEqual(result, 0)
    # s0 = 0x05040302, s1 = 0x04030201 -> diff = 1+1+1+1 = 4
    result = v_sad_u8(0x05040302, 0x04030201, 0)
    self.assertEqual(result, 4)

  def test_v_sad_u8_with_accumulator(self):
    """Test v_sad_u8 with non-zero accumulator."""
    # s0 = 0x05040302, s1 = 0x04030201, s2 = 100 -> 4 + 100 = 104
    result = v_sad_u8(0x05040302, 0x04030201, 100)
    self.assertEqual(result, 104)

  def test_v_sad_u8_large_diff(self):
    """Test v_sad_u8 with maximum byte differences."""
    # s0 = 0xffffffff, s1 = 0x00000000 -> diff = 255*4 = 1020
    result = v_sad_u8(0xffffffff, 0x00000000, 0)
    self.assertEqual(result, 1020)

  def test_v_msad_u8_basic(self):
    """Test v_msad_u8 masks when reference byte is 0."""
    # s0 = 0x10101010, s1 = 0x00000000 -> all masked, result = 0
    result = v_msad_u8(0x10101010, 0x00000000, 0)
    self.assertEqual(result, 0)
    # s0 = 0x10101010, s1 = 0x01010101 -> diff = |0x10-0x01|*4 = 15*4 = 60
    result = v_msad_u8(0x10101010, 0x01010101, 0)
    self.assertEqual(result, 60)

  def test_v_msad_u8_partial_mask(self):
    """Test v_msad_u8 with partial masking."""
    # s0 = 0x10101010, s1 = 0x00010001 -> bytes 1 and 3 masked
    # diff = |0x10-0x01| + |0x10-0x01| = 15 + 15 = 30
    result = v_msad_u8(0x10101010, 0x00010001, 0)
    self.assertEqual(result, 30)

  def test_v_msad_u8_with_accumulator(self):
    """Test v_msad_u8 with non-zero accumulator."""
    result = v_msad_u8(0x10101010, 0x01010101, 50)
    self.assertEqual(result, 110)  # 60 + 50

if __name__ == '__main__':
  unittest.main()
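
# --- Editor's note (minimal sketch matching the tested behavior): SAD sums
# |byte difference| over four byte lanes plus an accumulator; the MSAD variant
# additionally skips lanes where the reference byte in s1 is zero.
def sad_u8(s0: int, s1: int, acc: int, masked: bool = False) -> int:
  for i in range(4):
    a, b = (s0 >> (8 * i)) & 0xff, (s1 >> (8 * i)) & 0xff
    if masked and b == 0: continue
    acc += abs(a - b)
  return acc

assert sad_u8(0x05040302, 0x04030201, 100) == 104            # v_sad_u8 with accumulator
assert sad_u8(0x10101010, 0x00010001, 0, masked=True) == 30  # v_msad_u8 partial mask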
@@ -47,10 +47,14 @@ INPUT_VARS = {
  'OFFSET1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OFFSET1', 0, 0xff)),
  'OPSEL': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL', 0, 7)),
  'OPSEL_HI': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('OPSEL_HI', 0, 7)),
  'SRC0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('SRC0', 0, 0xffffffff)),  # Source register index
  'VDST': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('VDST', 0, 0xffffffff)),  # Dest register index (for writelane)
  'M0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('M0', 0, 0xffffffff)),  # M0 register
}

MEM_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint8.ptr(addrspace=AddrSpace.GLOBAL), arg=0)
LDS_BUF = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(addrspace=AddrSpace.LOCAL), arg=0)
VGPR_BUF = UOp(Ops.DEFINE_GLOBAL, dtypes.uint32.ptr(addrspace=AddrSpace.GLOBAL), arg=1)  # VGPR[lane][reg] as flat array

class Ctx:
  def __init__(self, mem_buf: UOp = MEM_BUF):
@@ -85,6 +89,25 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
    if name == 'NAN.f32': return UOp.const(dtypes.float32, float('nan'))
    if name in ('VCCZ', 'EXECZ'):
      return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32)
    if name == 'EXEC_LO':
      return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 0xffffffff))), hint or dtypes.uint32)
    if name == 'EXEC_HI':
      return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 32))), hint or dtypes.uint32)
    if name == 'VCC_LO':
      return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 0xffffffff))), hint or dtypes.uint32)
    if name == 'VCC_HI':
      return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 32))), hint or dtypes.uint32)
    if name == 'laneID' or name == 'laneId':
      return ctx.vars.get('laneId', UOp.const(dtypes.uint32, 0))
    if name == 'ThreadMask':
      # ThreadMask is the same as EXEC for wave32
      return _cast(ctx.vars.get('EXEC'), hint or dtypes.uint32)
    if name == 'DST':
      # DST is the raw destination register index from the instruction
      return ctx.vars.get('VDST', UOp.const(dtypes.uint32, 0))
    if name == 'LDS':
      # LDS is the local data share memory - treat as memory buffer
      return UOp.const(dtypes.uint64, 0)  # Base address placeholder
    if name.startswith('eval '): return ctx.vars.get('_eval', UOp.const(dtypes.uint32, 0))
    if name not in ctx.vars: raise ValueError(f"Unknown variable: {name}")
    return _cast(ctx.vars[name], hint or ctx.vars[name].dtype)
@@ -109,6 +132,18 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
      return UOp.const(dt, denorm)
    if name in ('VCCZ', 'EXECZ'):
      return _cast(UOp(Ops.CMPEQ, dtypes.bool, (ctx.vars.get('VCC' if name == 'VCCZ' else 'EXEC'), UOp.const(dtypes.uint64, 0))), dtypes.uint32)
    if name == 'EXEC_LO':
      return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 0xffffffff))), dt)
    if name == 'EXEC_HI':
      return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('EXEC'), UOp.const(dtypes.uint64, 32))), dt)
    if name == 'VCC_LO':
      return _cast(UOp(Ops.AND, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 0xffffffff))), dt)
    if name == 'VCC_HI':
      return _cast(UOp(Ops.SHR, dtypes.uint64, (ctx.vars.get('VCC'), UOp.const(dtypes.uint64, 32))), dt)
    if name == 'DST':
      return _cast(ctx.vars.get('VDST', UOp.const(dtypes.uint32, 0)), dt)
    if name == 'laneID' or name == 'laneId':
      return _cast(ctx.vars.get('laneId', UOp.const(dtypes.uint32, 0)), dt)
    if name.startswith('WAVE_STATUS.COND_DBG'): return UOp.const(dtypes.uint32, 0)
    vn = name + '_64' if dt.itemsize == 8 and name.isupper() else name
    base = ctx.vars.get(vn) if vn in ctx.vars else ctx.vars.get(name)
@@ -132,6 +167,20 @@ def _expr(node: UOp, ctx: Ctx, hint: DType = None) -> UOp:
      return _cast(inner_resolved, dt)

    case UOp(Ops.CUSTOMI, _, (base_expr, hi_expr, lo_expr)):  # Slice or array access
      # Check for VGPR[lane][reg] access pattern (nested CUSTOMI where inner base is VGPR)
      if base_expr.op == Ops.CUSTOMI and hi_expr is lo_expr:
        inner_base, inner_idx, _ = base_expr.src
        if inner_base.op == Ops.DEFINE_VAR and inner_base.arg[0] == 'VGPR':
          # VGPR[lane][reg] -> load from VGPR buffer at index (lane * 256 + reg)
          lane_uop = _expr(inner_idx, ctx, dtypes.uint32)
          reg_uop = _expr(hi_expr, ctx, dtypes.uint32)
          # Compute flat index: lane * 256 + reg (256 VGPRs per lane)
          idx = UOp(Ops.ADD, dtypes.uint32, (UOp(Ops.MUL, dtypes.uint32, (lane_uop, UOp.const(dtypes.uint32, 256))), reg_uop))
          return UOp(Ops.CUSTOM, dtypes.uint32, (idx,), arg='vgpr_read')
      # Check for SGPR[idx] access pattern (scalar register file access)
      if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[0] == 'SGPR' and hi_expr is lo_expr:
        idx_uop = _expr(hi_expr, ctx, dtypes.uint32)
        return UOp(Ops.CUSTOM, dtypes.uint32, (idx_uop,), arg='sgpr_read')
      # Check for array element access first: arr[idx] where arr is a vector type
      if base_expr.op == Ops.DEFINE_VAR and base_expr.arg[1] is None and hi_expr is lo_expr:
        name = base_expr.arg[0]
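
# --- Editor's note (worked example of the flat-index convention used above,
# assuming 256 VGPRs per lane as in the diff): VGPR[lane][reg] -> lane * 256 + reg.
def vgpr_flat_index(lane: int, reg: int) -> int:
  return lane * 256 + reg

assert vgpr_flat_index(2, 5) == 517  # lane 2, v5
assert divmod(517, 256) == (2, 5)    # the inverse used when substituting reads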
@@ -377,6 +426,14 @@ def _call_trig_preop_result(shift):
  # Returns CUSTOM op that gets evaluated at runtime with the 1201-bit constant
  return UOp(Ops.CUSTOM, dtypes.float64, (shift,), arg='trig_preop_result')

def _call_s_ff1_i32_b32(v):
  # Find first 1 bit (count trailing zeros) - returns CUSTOM op evaluated at runtime
  return UOp(Ops.CUSTOM, dtypes.int32, (_cast(v, dtypes.uint32),), arg='s_ff1_i32_b32')

def _call_s_ff1_i32_b64(v):
  # Find first 1 bit in 64-bit value (count trailing zeros) - returns CUSTOM op evaluated at runtime
  return UOp(Ops.CUSTOM, dtypes.int32, (_cast(v, dtypes.uint64),), arg='s_ff1_i32_b64')

CALL_DISPATCH = {
  'MEM': _call_MEM, 'fma': _call_fma, 'abs': _call_abs, 'cos': _call_cos, 'rsqrt': _call_rsqrt,
  'clamp': _call_clamp, 'floor': _call_floor, 'fract': _call_fract, 'isNAN': _call_isNAN, 'isQuietNAN': _call_isQuietNAN,
@@ -384,6 +441,9 @@ CALL_DISPATCH = {
  'sign': _call_sign, 'exponent': _call_exponent, 'mantissa': _call_mantissa, 'isEven': _call_isEven,
  'signext': _call_signext, 'signext_from_bit': _call_signext_from_bit, 'ABSDIFF': _call_ABSDIFF, 'SAT8': _call_SAT8,
  'BYTE_PERMUTE': _call_BYTE_PERMUTE, 'bf16_to_f32': _call_bf16_to_f32, 'trig_preop_result': _call_trig_preop_result,
  's_ff1_i32_b32': _call_s_ff1_i32_b32, 's_ff1_i32_b64': _call_s_ff1_i32_b64,
  'u8_to_u32': lambda v: _cast(UOp(Ops.AND, dtypes.uint32, (_cast(v, dtypes.uint32), UOp.const(dtypes.uint32, 0xff))), dtypes.uint32),
  'u4_to_u32': lambda v: _cast(UOp(Ops.AND, dtypes.uint32, (_cast(v, dtypes.uint32), UOp.const(dtypes.uint32, 0xf))), dtypes.uint32),
}

def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
@@ -432,40 +492,53 @@ def _transform_call(name: str, a: list[UOp], hint: DType) -> UOp:
    return result
  raise ValueError(f"Unknown function: {name}")

def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, str|None, int|None]:
  """Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx)"""
def _get_lhs_info(lhs: UOp, ctx: Ctx) -> tuple[str, DType, int|None, int|None, UOp|str|None, int|None, UOp|None]:
  """Extract assignment target: (var_name, dtype, hi_bit, lo_bit, idx_var, array_idx, dynamic_idx_uop)
  dynamic_idx_uop is set when the bit index is a runtime expression (not constant or simple variable)"""
  match lhs:
    case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None
    case UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)): return name, dt, None, None, None, None, None
    case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))),)):
      return name, dt, int(hi), int(lo), None, None
      return name, dt, int(hi), int(lo), None, None, None
    case UOp(Ops.BITCAST, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
      return name, dtypes.uint64, None, None, idx, None
      return name, dtypes.uint64, None, None, idx, None, None
    case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)),)) if lhs.src[0].src[1] is lhs.src[0].src[2]:
      return name, dt, None, None, idx, None
      return name, dt, None, None, idx, None, None
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
      return name, dtypes.uint32, int(hi), int(lo), None, None
      return name, dtypes.uint32, int(hi), int(lo), None, None, None
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, idx), _)) if lhs.src[1] is lhs.src[2]:
      # Check if this is array element access (variable is a vector type)
      var_dtype = ctx.decls.get(name)
      if var_dtype is not None and var_dtype.count > 1:
        return name, var_dtype.scalar(), None, None, None, int(idx)
      return name, dtypes.uint32, int(idx), int(idx), None, None
        return name, var_dtype.scalar(), None, None, None, int(idx), None
      return name, dtypes.uint32, int(idx), int(idx), None, None, None
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.CONST, _, _, hi), UOp(Ops.CONST, _, _, lo))):
      return name, dtypes.uint32, int(hi), int(lo), None, None
      return name, dtypes.uint32, int(hi), int(lo), None, None, None
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
      return name, dt, None, None, idx, None
      return name, dt, None, None, idx, None, None
    # Handle arr[i] where i is a variable - check if it's array element or bit index
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)), UOp(Ops.DEFINE_VAR, _, _, (idx, None, None)), idx2)) if lhs.src[1] is lhs.src[2]:
      var_dtype = ctx.decls.get(name)
      if var_dtype is not None and var_dtype.count > 1:
        # Array element access with variable index
        return name, var_dtype.scalar(), None, None, None, idx  # Return idx as variable name for array_idx
      return name, dtypes.uint32, None, None, idx, None
        return name, var_dtype.scalar(), None, None, None, idx, None  # Return idx as variable name for array_idx
      return name, dtypes.uint32, None, None, idx, None, None
    # Handle D0.u32[expr] where expr is a complex expression (dynamic bit index)
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.BITCAST, dt, (UOp(Ops.DEFINE_VAR, _, _, (name, None, None)),)), idx_expr, idx_expr2)) if lhs.src[1] is lhs.src[2]:
      return name, dt, None, None, None, None, idx_expr  # Return expression as dynamic_idx_uop
    # Handle VGPR[lane][reg] = value (nested CUSTOMI for VGPR write)
    case UOp(Ops.CUSTOMI, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, ('VGPR', None, None)), lane_expr, _)), reg_expr, _)):
      return 'VGPR', dtypes.uint32, None, None, None, None, (lane_expr, reg_expr)  # Return tuple for VGPR write
    # Handle VGPR[laneId][addr].type = value (with BITCAST wrapper)
    case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, ('VGPR', None, None)), lane_expr, _)), reg_expr, _)),)):
      return 'VGPR', dt, None, None, None, None, (lane_expr, reg_expr)  # Return tuple for VGPR write
    # Handle SGPR[addr].type = value (scalar register write with BITCAST)
    case UOp(Ops.BITCAST, dt, (UOp(Ops.CUSTOMI, _, (UOp(Ops.DEFINE_VAR, _, _, ('SGPR', None, None)), reg_expr, _)),)):
      return 'SGPR', dt, None, None, None, None, reg_expr  # Return expr for SGPR write
    case UOp(Ops.DEFINE_VAR, _, _, (name, None, None)):
      # If the variable already exists, use its dtype; otherwise default to uint32
      existing = ctx.vars.get(name)
      dtype = existing.dtype if existing is not None else dtypes.uint32
      return name, dtype, None, None, None, None
      return name, dtype, None, None, None, None, None
  raise ValueError(f"Cannot parse LHS: {lhs}")

def _stmt(stmt, ctx: Ctx):
@@ -511,9 +584,47 @@ def _stmt(stmt, ctx: Ctx):
        offset += bits
      return

  var, dtype, hi, lo, idx_var, array_idx = _get_lhs_info(lhs, ctx)
  var, dtype, hi, lo, idx_var, array_idx, dynamic_idx = _get_lhs_info(lhs, ctx)
  out_vars = ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC', 'SDATA', 'VDATA', 'RETURN_DATA')

  # Handle VGPR write: VGPR[lane][reg] = value
  if var == 'VGPR' and isinstance(dynamic_idx, tuple):
    lane_expr, reg_expr = dynamic_idx
    lane_uop = _expr(lane_expr, ctx, dtypes.uint32)
    reg_uop = _expr(reg_expr, ctx, dtypes.uint32)
    val_uop = _expr(rhs, ctx, dtype)
    # Compute flat index: lane * 256 + reg
    idx = UOp(Ops.ADD, dtypes.uint32, (UOp(Ops.MUL, dtypes.uint32, (lane_uop, UOp.const(dtypes.uint32, 256))), reg_uop))
    ctx.outputs.append(('VGPR_WRITE', UOp(Ops.CUSTOM, dtypes.uint32, (idx, _cast(val_uop, dtypes.uint32)), arg='vgpr_write'), dtypes.uint32))
    return

  # Handle SGPR write: SGPR[reg] = value
  if var == 'SGPR' and dynamic_idx is not None and not isinstance(dynamic_idx, tuple):
    reg_uop = _expr(dynamic_idx, ctx, dtypes.uint32)
    val_uop = _expr(rhs, ctx, dtype)
    ctx.outputs.append(('SGPR_WRITE', UOp(Ops.CUSTOM, dtypes.uint32, (reg_uop, _cast(val_uop, dtypes.uint32)), arg='sgpr_write'), dtypes.uint32))
    return

  # Handle dynamic bit index: D0.u32[expr] = value (where expr is a runtime expression)
  if dynamic_idx is not None and not isinstance(dynamic_idx, tuple):
    idx_uop = _expr(dynamic_idx, ctx, dtypes.uint32)
    rhs_uop = _expr(rhs, ctx, dtypes.uint32)
    op_dt = dtype if dtype.itemsize >= 4 else dtypes.uint32
    if dtype.itemsize == 8: op_dt = dtypes.uint64
    base = ctx.vars.get(var, UOp.const(op_dt, 0))
    if base.dtype != op_dt: base = _cast(base, op_dt)
    # Set single bit at dynamic index: base = (base & ~(1 << idx)) | ((val & 1) << idx)
    one = UOp.const(op_dt, 1)
    bit_mask = UOp(Ops.SHL, op_dt, (one, _cast(idx_uop, op_dt)))
    inv_mask = UOp(Ops.XOR, op_dt, (bit_mask, UOp.const(op_dt, -1)))
    val_bit = UOp(Ops.SHL, op_dt, (UOp(Ops.AND, op_dt, (_cast(rhs_uop, op_dt), one)), _cast(idx_uop, op_dt)))
    result = UOp(Ops.OR, op_dt, (UOp(Ops.AND, op_dt, (base, inv_mask)), val_bit))
    ctx.vars[var] = result
    if var in out_vars:
      ctx.outputs = [(n, u, d) for n, u, d in ctx.outputs if n != var]
      ctx.outputs.append((var, result, op_dt))
    return

  # Handle array element assignment: arr[idx] = value
  if array_idx is not None:
    var_dtype = ctx.decls.get(var)
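
# --- Editor's note (worked example of the bit-insert identity built above):
# base = (base & ~(1 << idx)) | ((val & 1) << idx) writes one bit at a dynamic index.
def set_bit(base: int, idx: int, val: int) -> int:
  return (base & ~(1 << idx)) | ((val & 1) << idx)

assert set_bit(0b1010, 0, 1) == 0b1011
assert set_bit(0b1010, 1, 0) == 0b1000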
@@ -717,6 +828,28 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
      shifted = (TWO_OVER_PI_1201 << int(shift)) >> (1201 - 53)
      mantissa = shifted & 0x1fffffffffffff
      return float(mantissa)
    if u.op == Ops.CUSTOM and u.arg == 's_ff1_i32_b32':
      # Find first 1 bit (count trailing zeros) in 32-bit value
      v = _eval_uop(u.src[0])
      if v is None: return None
      v = int(v) & 0xffffffff
      if v == 0: return 32
      n = 0
      while (v & 1) == 0: v >>= 1; n += 1
      return n
    if u.op == Ops.CUSTOM and u.arg == 's_ff1_i32_b64':
      # Find first 1 bit (count trailing zeros) in 64-bit value
      v = _eval_uop(u.src[0])
      if v is None: return None
      v = int(v) & 0xffffffffffffffff
      if v == 0: return 64
      n = 0
      while (v & 1) == 0: v >>= 1; n += 1
      return n
    if u.op == Ops.CUSTOM and u.arg == 'vgpr_read':
      # VGPR read - returns CUSTOM that will be resolved with VGPR data at runtime
      # This can't be evaluated statically - needs VGPR substitution
      return None
    return None

  def _extract_results(s, MEM=None):
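
# --- Editor's note (equivalent of the trailing-zero loops above in stdlib
# one-liner form; returns the bit width when the input is zero, as the loops do):
def ctz(v: int, width: int) -> int:
  v &= (1 << width) - 1
  return width if v == 0 else (v & -v).bit_length() - 1

assert ctz(0b1000, 32) == 3 and ctz(0, 64) == 64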
@@ -790,8 +923,20 @@ def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[s
      input_vars['SIMM16']: UOp.const(dtypes.int32, simm16), input_vars['SIMM32']: UOp.const(dtypes.uint32, literal or 0),
      input_vars['PC']: UOp.const(dtypes.uint64, pc or 0),
      input_vars['OPSEL']: UOp.const(dtypes.uint32, opsel), input_vars['OPSEL_HI']: UOp.const(dtypes.uint32, opsel_hi),
      input_vars['SRC0']: UOp.const(dtypes.uint32, src0_idx),
    }
    return _extract_results(sink.substitute(dvars).simplify())
    s1_sub = sink.substitute(dvars).simplify()
    # Handle VGPR reads - substitute vgpr_read CUSTOM ops with actual values
    if VGPR is not None:
      vgpr_subs = {}
      for u in s1_sub.toposort():
        if u.op == Ops.CUSTOM and u.arg == 'vgpr_read':
          idx = _eval_uop(u.src[0])
          if idx is not None:
            lane, reg = int(idx) // 256, int(idx) % 256
            vgpr_subs[u] = UOp.const(dtypes.uint32, VGPR[lane][reg] if lane < len(VGPR) and reg < len(VGPR[lane]) else 0)
      if vgpr_subs: s1_sub = s1_sub.substitute(vgpr_subs).simplify()
    return _extract_results(s1_sub)
  return fn

# Ops that need Python exec features (inline conditionals, complex PDF fixes) - fall back to pcode.py