Files
tinygrad/test/mockgpu/amd/pcode.py
qazal cf6a429aaa mypy emulator pre-commit passing (#15379)
* fix dict stuff

* add type: ignores

* fix pcode to put uops not ints
2026-03-20 14:44:09 +09:00

1338 lines
68 KiB
Python

# Tokenizer-based expression parser for AMD pcode
from typing import Any, Callable
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, UOp
from tinygrad.uop.decompositions import f2f
# Type alias for the values stored in the parser's vars dict:
#   - a UOp for an ordinary variable
#   - a (str, list[str], str) tuple for a lambda definition
VarVal = UOp | tuple[str, list[str], str]
def _const(dt, v):
  """Build a constant UOp of dtype `dt` holding the value `v`."""
  return UOp.const(dt, v)
def _u32(v):
  """Shorthand for a uint32 constant UOp with value `v`."""
  return _const(dtypes.uint32, v)
def _u64(v):
  """Shorthand for a uint64 constant UOp with value `v`."""
  return _const(dtypes.uint64, v)
def _to_u32(v):
  """Return `v` as uint32: unchanged if already uint32, bitcast if it is 4 bytes wide, value-cast otherwise."""
  if v.dtype == dtypes.uint32: return v
  if v.dtype.itemsize == 4: return v.bitcast(dtypes.uint32)
  return v.cast(dtypes.uint32)
def _to_bool(v):
  """Return `v` as a bool UOp: bools pass through, anything else is compared `!= 0` in its own dtype."""
  if v.dtype == dtypes.bool: return v
  zero = _const(v.dtype, 0)
  return v.ne(zero)
def _cast_to(v, dt):
  """Convert `v` to dtype `dt`.

  - same dtype: returned unchanged
  - target is half: value-cast to uint16, then reinterpret those 16 bits as half
  - same byte width: bit reinterpretation (bitcast)
  - different byte width: value cast
  """
  src_dt = v.dtype
  if src_dt == dt: return v
  if dt == dtypes.half:
    return v.cast(dtypes.uint16).bitcast(dtypes.half)
  same_width = dt.itemsize == src_dt.itemsize
  return v.bitcast(dt) if same_width else v.cast(dt)
# Float bit extraction: returns (bits, exp_mask, mant_mask, quiet_bit, exp_shift), chosen from the operand's dtype
def _float_info(v: UOp) -> tuple[UOp, UOp, UOp, UOp, int]:
  """Return the raw bits of `v` together with the field masks of its float format.

  float64/uint64 operands use the 64-bit layout (shift 52), half/uint16 the 16-bit layout (shift 10),
  everything else the 32-bit layout (shift 23). For the 16-bit layout the bits and masks are uint32.
  """
  src_dt = v.dtype
  if src_dt in (dtypes.float64, dtypes.uint64):
    raw64 = v.bitcast(dtypes.uint64) if src_dt == dtypes.float64 else v.cast(dtypes.uint64)
    return raw64, _u64(0x7FF0000000000000), _u64(0x000FFFFFFFFFFFFF), _u64(0x0008000000000000), 52
  if src_dt in (dtypes.half, dtypes.uint16):
    if src_dt == dtypes.half: raw16 = v.bitcast(dtypes.uint16)
    else: raw16 = (v & _u32(0xFFFF)).cast(dtypes.uint16)
    # widen to uint32 so the bits match the dtype of the masks below
    return raw16.cast(dtypes.uint32), _u32(0x7C00), _u32(0x03FF), _u32(0x0200), 10
  raw32 = v.bitcast(dtypes.uint32) if src_dt == dtypes.float32 else v.cast(dtypes.uint32)
  return raw32, _u32(0x7F800000), _u32(0x007FFFFF), _u32(0x00400000), 23
def _isnan(v: UOp) -> UOp:
  """Bool UOp that is true when `v` is a NaN (exponent all ones, mantissa non-zero). Half inputs are widened to float32 first."""
  operand = v.cast(dtypes.float32) if v.dtype == dtypes.half else v
  bits, exp_m, mant_m, _, _ = _float_info(operand)
  exp_all_ones = (bits & exp_m).eq(exp_m)
  mant_nonzero = (bits & mant_m).ne(_const(bits.dtype, 0))
  return exp_all_ones & mant_nonzero
def _bitreverse(v: UOp, bits: int) -> UOp:
  """Reverse the bit order of `v` with mask-and-swap passes.

  `bits == 64` reverses a uint64; any other value reverses a full uint32 word.
  """
  if bits == 64:
    dt, half_width = dtypes.uint64, 32
    passes = [(0x5555555555555555, 1), (0x3333333333333333, 2), (0x0F0F0F0F0F0F0F0F, 4), (0x00FF00FF00FF00FF, 8), (0x0000FFFF0000FFFF, 16)]
  else:
    dt, half_width = dtypes.uint32, 16
    passes = [(0x55555555, 1), (0x33333333, 2), (0x0F0F0F0F, 4), (0x00FF00FF, 8)]
  if v.dtype != dt: v = v.cast(dt)
  # swap neighbouring groups of 1, 2, 4, ... bits
  for mask, dist in passes:
    m, s = _const(dt, mask), _const(dt, dist)
    v = ((v >> s) & m) | ((v & m) << s)
  # final pass: exchange the two halves of the word
  swap = _const(dt, half_width)
  return (v >> swap) | (v << swap)
def _extract_bits(val: UOp, hi: int, lo: int) -> UOp:
  """Extract the bit field val[hi:lo] (inclusive) as an unsigned value.

  The shift and mask are done in uint64 for 64-bit inputs and uint32 otherwise. The result is then cast to
  the unsigned dtype matching the field width so that brace-concat { hi, lo } computes the right output dtype.
  """
  is_wide = val.dtype in (dtypes.uint64, dtypes.int64)
  dt = dtypes.uint64 if is_wide else dtypes.uint32
  width = hi - lo + 1
  # shift operands must share a dtype, so bring the value to dt first
  src = val if val.dtype == dt else val.cast(dt)
  shifted = (src >> _const(dt, lo)) if lo > 0 else src
  field = shifted & _const(dt, (1 << width) - 1)
  out_dt = _BITS_DT.get(width) or (dtypes.uint32 if width <= 32 else dtypes.uint64 if width <= 64 else dt)
  return field if field.dtype == out_dt else field.cast(out_dt)
def _set_bit(old, pos, val):
  """Return `old` (uint32) with bit `pos` replaced by the least significant bit of `val`."""
  one = _u32(1)
  mask = one << pos
  cleared = old & (mask ^ _u32(0xFFFFFFFF))
  new_bit = (val.cast(dtypes.uint32) & one) << pos
  return cleared | new_bit
def _val_to_bits(val):
  """Return the raw bit pattern of `val` as an unsigned integer UOp.

  half -> uint16 bits widened to uint32, float32 -> uint32, float64 -> uint64.
  Non-float values are returned as uint32 (value-cast unless already uint32).
  """
  dt = val.dtype
  if dt == dtypes.half:
    raw16 = val.bitcast(dtypes.uint16)
    return raw16.cast(dtypes.uint32)
  for float_dt, bits_dt in ((dtypes.float32, dtypes.uint32), (dtypes.float64, dtypes.uint64)):
    if dt == float_dt: return val.bitcast(bits_dt)
  if dt == dtypes.uint32: return val
  return val.cast(dtypes.uint32)
def _floor(x):
  """floor(x) built from TRUNC: negative non-integers truncate toward zero, so subtract one for those."""
  truncated = UOp(Ops.TRUNC, x.dtype, (x,))
  needs_step_down = (x < _const(x.dtype, 0)) & x.ne(truncated)
  return needs_step_down.where(truncated - _const(x.dtype, 1), truncated)
def _f16_extract(v):
  """For a uint32 operand, reinterpret its low 16 bits as a half; any other dtype is returned unchanged."""
  if v.dtype != dtypes.uint32: return v
  low16 = (v & _u32(0xFFFF)).cast(dtypes.uint16)
  return low16.bitcast(dtypes.half)
# ═════ FP8 (E4M3) and BF8 (E5M2) conversion helpers ═════
# f32→fp8/bf8 uses f2f decomposition directly. fp8/bf8→f32 wraps f2f with subnormal handling
# (f2f flushes denormals to zero, but AMD V_CVT_F32_FP8/BF8 preserves subnormals).
def _fp8_to_f32(v: UOp) -> UOp:
  """Convert the low byte of `v`, read as FP8 (E4M3), to float32, keeping subnormal inputs exact."""
  b = (v.cast(dtypes.uint32) & _u32(0xFF)).cast(dtypes.uint8)
  bu = b.cast(dtypes.uint32)
  # field split: 1 sign bit (moved to the float32 sign position), 4 exponent bits, 3 mantissa bits
  sign = (bu >> _u32(7)) << _u32(31)
  exp = (bu >> _u32(3)) & _u32(0xF)
  mant = bu & _u32(0x7)
  # E4M3 subnormal: exp==0, mant!=0 -> (-1)^sign * 2^(1-7) * (mant/8) = (-1)^sign * mant * 2^(-9)
  is_sub = exp.eq(_u32(0)) & mant.ne(_u32(0))
  sub_magnitude = mant.cast(dtypes.float32) * _const(dtypes.float32, 1.0/512.0)
  sub_f32 = sub_magnitude.bitcast(dtypes.uint32) | sign
  # every other input goes through the generic f2f decomposition
  normal = f2f(b, dtypes.fp8e4m3, dtypes.float32)
  return is_sub.where(sub_f32.bitcast(dtypes.float32), normal)
def _bf8_to_f32(v: UOp) -> UOp:
  """Convert the low byte of `v`, read as BF8 (E5M2), to float32, keeping subnormal inputs exact."""
  b = (v.cast(dtypes.uint32) & _u32(0xFF)).cast(dtypes.uint8)
  bu = b.cast(dtypes.uint32)
  # field split: 1 sign bit (moved to the float32 sign position), 5 exponent bits, 2 mantissa bits
  sign = (bu >> _u32(7)) << _u32(31)
  exp = (bu >> _u32(2)) & _u32(0x1F)
  mant = bu & _u32(0x3)
  # E5M2 subnormal: exp==0, mant!=0 -> (-1)^sign * 2^(1-15) * (mant/4) = (-1)^sign * mant * 2^(-16)
  is_sub = exp.eq(_u32(0)) & mant.ne(_u32(0))
  sub_magnitude = mant.cast(dtypes.float32) * _const(dtypes.float32, 1.0/65536.0)
  sub_f32 = sub_magnitude.bitcast(dtypes.uint32) | sign
  # every other input goes through the generic f2f decomposition
  normal = f2f(b, dtypes.fp8e5m2, dtypes.float32)
  return is_sub.where(sub_f32.bitcast(dtypes.float32), normal)
def _f32_to_fp8(v: UOp) -> UOp:
  """Convert a float32 value (or its raw bits) to FP8 (E4M3) through the f2f decomposition."""
  as_f32 = v if v.dtype == dtypes.float32 else v.bitcast(dtypes.float32)
  return f2f(as_f32.bitcast(dtypes.uint32), dtypes.float32, dtypes.fp8e4m3)
def _f32_to_bf8(v: UOp) -> UOp:
  """Convert a float32 value (or its raw bits) to BF8 (E5M2) through the f2f decomposition."""
  as_f32 = v if v.dtype == dtypes.float32 else v.bitcast(dtypes.float32)
  return f2f(as_f32.bitcast(dtypes.uint32), dtypes.float32, dtypes.fp8e5m2)
def _f32_to_bf16(v: UOp) -> UOp:
  """Convert f32 to bf16 with round-to-nearest-even. BF16 is the upper 16 bits of F32 with rounding.

  `v` may be float32 or its raw bit pattern; the result is the bf16 bit pattern as uint16.
  NaN inputs bypass the rounding add and are returned as a quiet NaN with the same sign and upper payload.
  """
  bits = (v.bitcast(dtypes.float32) if v.dtype != dtypes.float32 else v).bitcast(dtypes.uint32)
  # Round-to-nearest-even: add rounding bias. If the bit just below the truncation point is 1 and the rest are 0, round to even.
  round_bit = (bits >> _u32(16)) & _u32(1)  # bit 16 (LSB of kept part)
  rounding = _u32(0x7FFF) + round_bit  # 0x7FFF + bit16: rounds to even
  rounded = bits + rounding
  # The bias must not be applied to NaNs: it can carry a NaN into Inf (0x7F800001 -> 0x7F80)
  # or wrap it around to -0 (0x7FFFFFFF -> 0x8000). Force the quiet bit instead.
  is_nan = (bits & _u32(0x7F800000)).eq(_u32(0x7F800000)) & (bits & _u32(0x007FFFFF)).ne(_u32(0))
  quiet_nan = bits | _u32(0x00400000)
  return (is_nan.where(quiet_nan, rounded) >> _u32(16)).cast(dtypes.uint16)
def _f32_to_bf16_sr(v: UOp, stoch: UOp) -> UOp:
  """Convert f32 to bf16 with stochastic rounding.

  `v` may be float32 or its raw bit pattern; the low 16 bits of `stoch` supply the random increment.
  The result is the bf16 bit pattern as uint16. NaN inputs bypass the add and are returned as a quiet NaN.
  """
  bits = (v.bitcast(dtypes.float32) if v.dtype != dtypes.float32 else v).bitcast(dtypes.uint32)
  # Stochastic rounding: add lower 16 bits of stochastic value to lower 16 bits of f32
  rounded = bits + (stoch & _u32(0xFFFF))
  # The add must not be applied to NaNs: the carry can turn a NaN into Inf or wrap it to -0
  # (0x7FFFFFFF + 0xFFFF overflows to 0x8000FFFE). Force the quiet bit instead.
  is_nan = (bits & _u32(0x7F800000)).eq(_u32(0x7F800000)) & (bits & _u32(0x007FFFFF)).ne(_u32(0))
  quiet_nan = bits | _u32(0x00400000)
  return (is_nan.where(quiet_nan, rounded) >> _u32(16)).cast(dtypes.uint16)
def _check_nan(v: UOp, quiet: bool) -> UOp:
  """Bool UOp testing for a quiet NaN (`quiet=True`) or a signaling NaN (`quiet=False`).

  A cast to float64 is looked through so the test runs on the bits of the operand being cast.
  """
  if v.op == Ops.CAST and v.dtype == dtypes.float64: v = v.src[0]
  bits, exp_m, mant_m, qb, _ = _float_info(v)
  zero = _const(bits.dtype, 0)
  exp_all_ones = (bits & exp_m).eq(exp_m)
  quiet_set = (bits & qb).ne(zero)
  if quiet: return exp_all_ones & quiet_set
  # signaling NaN: NaN exponent, some mantissa bit set, quiet bit clear
  mant_nonzero = (bits & mant_m).ne(zero)
  return exp_all_ones & mant_nonzero & quiet_set.logical_not()
def _minmax_reduce(is_max: bool, dt, *args: UOp) -> UOp:
  """Fold `args` left to right into their maximum (`is_max`) or minimum, computed in dtype `dt`.

  Unsigned dtypes use compare-and-select; other dtypes use maximum()/minimum(). For float32 a NaN
  operand is skipped in favour of the other operand.
  """
  is_unsigned = dt in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)

  def to_dt(v: UOp) -> UOp:
    # raw uint32 register bits are reinterpreted as float32, everything else is value-cast
    if dt == dtypes.float32 and v.dtype == dtypes.uint32: return v.bitcast(dt)
    return v.cast(dt)

  def pick(a: UOp, b: UOp) -> UOp:
    if is_unsigned:
      return (a > b).where(a, b) if is_max else (a < b).where(a, b)
    return a.maximum(b) if is_max else a.minimum(b)

  operands = [to_dt(a) for a in args]
  acc = operands[0]
  for nxt in operands[1:]:
    if dt == dtypes.float32:
      acc = _isnan(acc).where(nxt, _isnan(nxt).where(acc, pick(acc, nxt)))
    else:
      acc = pick(acc, nxt)
  return acc
def _find_two_pi_mul(x):
  """Detect `x == t * 2*PI` and return `(t, two_pi_value)`, or None if `x` has no such shape.

  The 2*PI factor may be a single constant or a nested product of two constants
  (either bare or wrapped in a CAST) whose product is within 1e-5 of 2*PI.
  """
  if x.op != Ops.MUL or len(x.src) != 2: return None
  for i, factor in enumerate(x.src):
    other = x.src[1-i]
    if factor.op == Ops.CONST and abs(factor.arg - 6.283185307179586) < 1e-5:
      return (other, 6.283185307179586)
    if factor.op == Ops.MUL and len(factor.src) == 2:
      plain = [ss.arg for ss in factor.src if ss.op == Ops.CONST]
      casted = [ss.src[0].arg for ss in factor.src if ss.op == Ops.CAST and ss.src[0].op == Ops.CONST]
      vals = plain + casted
      if len(vals) == 2 and abs(vals[0] * vals[1] - 6.283185307179586) < 1e-5:
        return (other, vals[0] * vals[1])
  return None
def _trig_reduce(x, phase=0.0):
  """Build SIN of `x` after range reduction by whole turns; `phase` is an offset in turns (0.25 gives cosine).

  When `x` is recognisably `turns * 2*PI` the reduction is done on `turns` directly, otherwise on `x / (2*PI)`.
  """
  found = _find_two_pi_mul(x)
  if found is None:
    if phase: x = x + _const(x.dtype, phase * 6.283185307179586)
    # nearest whole number of turns in x
    n = _floor(x * _const(x.dtype, 0.15915494309189535) + _const(x.dtype, 0.5))
    return UOp(Ops.SIN, x.dtype, (x - n * _const(x.dtype, 6.283185307179586),))
  turns, two_pi = found
  if phase: turns = turns + _const(turns.dtype, phase)
  n = _floor(turns + _const(turns.dtype, 0.5))
  return UOp(Ops.SIN, turns.dtype, ((turns - n) * _const(turns.dtype, two_pi),))
def _signext(val: UOp) -> UOp:
  """Sign-extend `val`, inferring the source width from its shape.

  A value masked with 0xF / 0xFF / 0xFFFF, or whose dtype is 1 or 2 bytes wide, is sign-extended from
  4 / 8 / 16 bits to a 32-bit int. Otherwise a 32-bit int is widened to int64 and anything else is returned unchanged.
  """
  for bits, mask, ext in [(4, 0xF, 0xFFFFFFF0), (8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]:
    is_masked = val.op == Ops.AND and len(val.src) == 2 and val.src[1].op == Ops.CONST and val.src[1].arg == mask
    if not (is_masked or val.dtype.itemsize == bits // 8): continue
    v32 = val if val.dtype == dtypes.uint32 else val.cast(dtypes.uint32)
    sign_bit = (v32 >> _u32(bits - 1)) & _u32(1)
    return sign_bit.ne(_u32(0)).where(v32 | _u32(ext), v32).cast(dtypes.int)
  if val.dtype in (dtypes.int, dtypes.int32): return val.cast(dtypes.int64)
  return val
def _signext_4bit(val: UOp) -> UOp:
  """Sign extend a 4-bit value to a 32-bit signed integer."""
  v32 = val if val.dtype == dtypes.uint32 else val.cast(dtypes.uint32)
  # bit 3 is the sign bit of a 4-bit quantity
  is_negative = ((v32 >> _u32(3)) & _u32(1)).ne(_u32(0))
  extended = is_negative.where(v32 | _u32(0xFFFFFFF0), v32)
  return extended.bitcast(dtypes.int)
def _abs(val: UOp) -> UOp:
  """Absolute value of a half/float32/float64 by clearing the sign bit; other dtypes are returned unchanged."""
  layouts = ((dtypes.half, dtypes.uint16, 0x7FFF),
             (dtypes.float32, dtypes.uint32, 0x7FFFFFFF),
             (dtypes.float64, dtypes.uint64, 0x7FFFFFFFFFFFFFFF))
  for ft, bt, magnitude_mask in layouts:
    if val.dtype == ft:
      return (val.bitcast(bt) & _const(bt, magnitude_mask)).bitcast(ft)
  return val
def _f_to_u(f, dt):
  """Convert float `f` to unsigned dtype `dt` with saturation: negatives give 0, values >= 2**bits give dt.max."""
  zero = _const(f.dtype, 0.0)
  non_negative = (f < zero).where(zero, f)
  truncated = UOp(Ops.TRUNC, f.dtype, (non_negative,))
  too_big = truncated >= _const(f.dtype, 2**(dt.itemsize*8))
  return too_big.where(_const(dt, dt.max), truncated.cast(dt))
def _cvt_quiet(val: UOp) -> UOp:
  """Set the quiet-NaN bit of `val` and return it as a float.

  float64 and half operands keep their dtype; any other operand is treated as a 32-bit pattern and
  the result is float32.
  """
  _, _, _, qb, _ = _float_info(val)
  if val.dtype == dtypes.float64: bt, ft = dtypes.uint64, dtypes.float64
  elif val.dtype == dtypes.half: bt, ft = dtypes.uint16, dtypes.half
  else: bt, ft = dtypes.uint32, dtypes.float32
  # _float_info returns uint32 constants for the 16-bit layout, but the OR below runs in uint16 for half
  # operands; bring the quiet bit to the working integer dtype so both operands match.
  quiet_bit = qb if qb.dtype == bt else qb.cast(bt)
  return (val.bitcast(bt) | quiet_bit).bitcast(ft)
def _is_denorm(val: UOp) -> UOp:
  """Bool UOp that is true when `val` is a denormal: exponent field zero and mantissa non-zero."""
  bits, exp_m, mant_m, _, _ = _float_info(val)
  zero = _const(bits.dtype, 0)
  exp_is_zero = (bits & exp_m).eq(zero)
  mant_nonzero = (bits & mant_m).ne(zero)
  return exp_is_zero & mant_nonzero
# Exponent field mask (after shifting), keyed by the exponent shift of each float layout: half, float32, float64
_EXP_BITS = {10: 0x1F, 23: 0xFF, 52: 0x7FF}
def _get_exp(bits: UOp, shift: int) -> UOp:
  """Return the biased exponent field of raw float `bits` as an int UOp; `shift` selects the layout."""
  field = bits >> _const(bits.dtype, shift)
  return (field & _const(bits.dtype, _EXP_BITS[shift])).cast(dtypes.int)
def _exponent(val: UOp) -> UOp:
  """Return the biased exponent field of `val` as an int UOp."""
  info = _float_info(val)
  raw_bits, exp_shift = info[0], info[4]
  return _get_exp(raw_bits, exp_shift)
def _div_would_be_denorm(a: UOp, b: UOp) -> UOp:
  """Bool UOp: true when exponent(a) - exponent(b) is below the minimum normal exponent of a's float layout."""
  bits_n, _, _, _, shift = _float_info(a)
  bits_d = _float_info(b)[0]
  min_exp = {10: -14, 23: -126, 52: -1022}[shift]
  exp_diff = _get_exp(bits_n, shift) - _get_exp(bits_d, shift)
  return exp_diff < _const(dtypes.int, min_exp)
def _sign(val: UOp) -> UOp:
  """Return the sign bit of `val` (0 or 1) as a uint32 UOp."""
  bits, _, _, _, shift = _float_info(val)
  # position of the sign bit for the half / float32 / float64 layouts
  sign_pos = {10: 15, 23: 31, 52: 63}[shift]
  top_bit = (bits >> _const(bits.dtype, sign_pos)) & _const(bits.dtype, 1)
  return top_bit.cast(dtypes.uint32)
def _signext_from_bit(val: UOp, w: UOp) -> UOp:
  """Sign-extend the low `w` bits of `val`, where `w` is a runtime width.

  The work is done in uint64 for 64-bit inputs and uint32 otherwise; the result has that unsigned dtype.
  """
  if val.dtype in (dtypes.uint64, dtypes.int64): dt, ones = dtypes.uint64, 0xFFFFFFFFFFFFFFFF
  else: dt, ones = dtypes.uint32, 0xFFFFFFFF
  mask_all, one = _const(dt, ones), _const(dt, 1)
  val_u = val if val.dtype == dt else val.cast(dt)
  w_val = w if w.dtype == dt else w.cast(dt)
  # bit (w-1) is the sign bit of a w-bit field
  sign_bit = (val_u >> (w_val - one)) & one
  # all bits at position w and above
  ext_mask = ((one << w_val) - one) ^ mask_all
  return sign_bit.ne(_const(dt, 0)).where(val_u | ext_mask, val_u)
def _ldexp(val: UOp, exp: UOp) -> UOp:
  """Compute val * 2**exp. Raw uint32/uint64 `val` bits are read as float32/float64; unsigned `exp` is made signed first."""
  if val.dtype == dtypes.uint32:
    val = val.bitcast(dtypes.float32)
  elif val.dtype == dtypes.uint64:
    val = val.bitcast(dtypes.float64)
  if exp.dtype == dtypes.uint32:
    exp = exp.cast(dtypes.int)
  elif exp.dtype == dtypes.uint64:
    exp = exp.cast(dtypes.int64)
  scale = UOp(Ops.EXP2, val.dtype, (exp.cast(val.dtype),))
  return val * scale
def _frexp_mant(val: UOp) -> UOp:
  """Return `val` with its exponent field replaced by that of 0.5, keeping sign and mantissa bits.

  Raw uint32/uint64 inputs are read as float32/float64. Any operand that is not float32 after that
  step goes down the float64 path.
  """
  if val.dtype == dtypes.uint32: val = val.bitcast(dtypes.float32)
  elif val.dtype == dtypes.uint64: val = val.bitcast(dtypes.float64)
  if val.dtype == dtypes.float32:
    kept = val.bitcast(dtypes.uint32) & _u32(0x807FFFFF)
    return (kept | _u32(0x3f000000)).bitcast(dtypes.float32)
  kept = val.bitcast(dtypes.uint64) & _const(dtypes.uint64, 0x800FFFFFFFFFFFFF)
  return (kept | _const(dtypes.uint64, 0x3fe0000000000000)).bitcast(dtypes.float64)
def _frexp_exp(val: UOp) -> UOp:
val = val.bitcast(dtypes.float32) if val.dtype == dtypes.uint32 else val.bitcast(dtypes.float64) if val.dtype == dtypes.uint64 else val
if val.dtype == dtypes.float32: return ((val.bitcast(dtypes.uint32) >> _u32(23)) & _u32(0xFF)).cast(dtypes.int) - _const(dtypes.int, 126)
return ((val.bitcast(dtypes.uint64) >> _const(dtypes.uint64, 52)) & _const(dtypes.uint64, 0x7FF)).cast(dtypes.int) - _const(dtypes.int, 1022)
# 2/PI as a 1201-bit integer, written as big-endian hex digits
TWO_OVER_PI = int("".join((
  "0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd",
  "63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414",
  "da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6",
)), 16)
# TWO_OVER_PI as 19 u64 words for trig_preop_result (word[0] = bits 0-63, word[18] = bits 1152-1200)
_PREOP_WORDS = tuple((TWO_OVER_PI >> lsb) & ((1 << 64) - 1) for lsb in range(0, 19 * 64, 64))
def _trig_preop(val: UOp) -> UOp:
  """Select a 53-bit window of the 2/PI constant and return it converted to float64.

  `val` is the shift amount; the window starts at bit (1148 - shift), counted from the LSB of TWO_OVER_PI.
  """
  # Extract 53 bits from position (1148 - shift) in the 1201-bit 2/PI constant
  # Using word-based selection: 19 conditions instead of 1149
  shift = val.cast(dtypes.uint32)
  bit_pos = _u32(1148) - shift  # starting bit position from LSB
  word_idx = bit_pos >> _u32(6)  # // 64
  bit_off = bit_pos & _u32(63)  # % 64
  # Select lo_word and hi_word using shared conditions
  lo_word, hi_word = _u64(_PREOP_WORDS[18]), _u64(0)
  for i in range(17, -1, -1):
    cond = word_idx.eq(_u32(i))
    lo_word = cond.where(_u64(_PREOP_WORDS[i]), lo_word)
    hi_word = cond.where(_u64(_PREOP_WORDS[i + 1]), hi_word)
  # Combine and extract 53 bits: ((lo >> bit_off) | (hi << (64 - bit_off))) & mask
  bit_off_64 = bit_off.cast(dtypes.uint64)
  # A 64-bit shift by 64 is undefined (and commonly behaves like a shift by 0), which would OR the whole
  # hi word into the result when bit_off == 0. Splitting the shift keeps both counts in 0..63 and makes
  # the hi contribution exactly zero in that case.
  hi_part = (hi_word << (_u64(63) - bit_off_64)) << _u64(1)
  result = ((lo_word >> bit_off_64) | hi_part) & _u64(0x1fffffffffffff)
  return result.cast(dtypes.float64)
def _ff1(val: UOp, bits: int) -> UOp:
  """Return the index of the lowest set bit of `val` as an int UOp, or -1 when no bit is set.

  `bits == 64` scans a uint64; any other value scans a uint32 but still walks `bits` positions.
  """
  dt = dtypes.uint64 if bits == 64 else dtypes.uint32
  if val.dtype != dt: val = val.cast(dt)
  not_found = _const(dtypes.int, -1)
  found = not_found
  for pos in range(bits):
    bit_is_set = ((val >> _const(dt, pos)) & _const(dt, 1)).ne(_const(dt, 0))
    # only the first hit is recorded: later positions see found != -1
    found = (bit_is_set & found.eq(not_found)).where(_const(dtypes.int, pos), found)
  return found
def _sad_u8(a: UOp, b: UOp, acc: UOp, masked: bool = False) -> UOp:
  """Sum of absolute differences of 4 unsigned bytes + accumulator. If masked, skips bytes where a == 0."""
  a, b, total = a.cast(dtypes.uint32), b.cast(dtypes.uint32), acc.cast(dtypes.uint32)
  byte_mask = _u32(0xFF)
  for lane in range(4):
    offset = _u32(lane * 8)
    a_byte = (a >> offset) & byte_mask
    b_byte = (b >> offset) & byte_mask
    # unsigned |a - b| without underflow
    diff = (a_byte > b_byte).where(a_byte - b_byte, b_byte - a_byte)
    if masked: diff = a_byte.ne(_u32(0)).where(diff, _u32(0))
    total = total + diff
  return total
def _pcode_min(a: UOp, b: UOp) -> UOp:
  """Two-operand minimum for the pcode `min` function.

  Integers use compare-and-select: the negate/MAX/negate identity is wrong for unsigned values
  (min(0, 5) would give 5, because -0 is the smallest negated value) and for the most negative signed value.
  Other dtypes keep the negate/MAX/negate form.
  """
  int_dtypes = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64, dtypes.int8, dtypes.int16, dtypes.int, dtypes.int64)
  if a.dtype in int_dtypes: return (a < b).where(a, b)
  return UOp(Ops.MAX, a.dtype, (a.neg(), b.neg())).neg()

# Dispatch table for pcode function calls: maps the function name as written in pcode to a builder
# that takes the already-parsed argument UOps and returns the result UOp.
_FUNCS: dict[str, Callable[..., UOp]] = {
  'sqrt': lambda a: UOp(Ops.SQRT, a.dtype, (a,)), 'trunc': lambda a: UOp(Ops.TRUNC, a.dtype, (a,)),
  'log2': lambda a: UOp(Ops.LOG2, a.dtype, (a,)), 'sin': lambda a: _trig_reduce(a),
  'cos': lambda a: _trig_reduce(a, 0.25), 'floor': _floor, 'fract': lambda a: a - _floor(a),
  'signext': _signext, 'abs': _abs,
  'isEven': lambda a: (UOp(Ops.TRUNC, a.dtype, (a,)).cast(dtypes.int) & _const(dtypes.int, 1)).eq(_const(dtypes.int, 0)),
  'max': lambda a, b: UOp(Ops.MAX, a.dtype, (a, b)),
  'min': _pcode_min,
  # NOTE(review): the base `a` is ignored and the result is always 2**b -- presumably pcode only calls pow(2.0, x); confirm
  'pow': lambda a, b: UOp(Ops.EXP2, dtypes.float32, (b.bitcast(dtypes.float32),)),
  'fma': lambda a, b, c: a * b + c,
  'i32_to_f32': lambda a: a.cast(dtypes.int).cast(dtypes.float32),
  'u32_to_f32': lambda a: a.cast(dtypes.uint32).cast(dtypes.float32),
  'f32_to_i32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int),
  'f32_to_u32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint32),
  'f64_to_i32': lambda a: UOp(Ops.TRUNC, dtypes.float64, (a.bitcast(dtypes.float64),)).cast(dtypes.int),
  'f64_to_u32': lambda a: _f_to_u(a.bitcast(dtypes.float64), dtypes.uint32),
  'f16_to_f32': lambda a: _f16_extract(a).cast(dtypes.float32),
  'f32_to_f16': lambda a: a.cast(dtypes.half),
  'f32_to_f64': lambda a: a.bitcast(dtypes.float32).cast(dtypes.float64),
  'f64_to_f32': lambda a: a.bitcast(dtypes.float64).cast(dtypes.float32),
  'i32_to_f64': lambda a: a.cast(dtypes.int).cast(dtypes.float64),
  'u32_to_f64': lambda a: a.cast(dtypes.uint32).cast(dtypes.float64),
  'f16_to_i16': lambda a: UOp(Ops.TRUNC, dtypes.half, (_f16_extract(a),)).cast(dtypes.int16),
  'f16_to_u16': lambda a: UOp(Ops.TRUNC, dtypes.half, (_f16_extract(a),)).cast(dtypes.uint16),
  'i16_to_f16': lambda a: a.cast(dtypes.int16).cast(dtypes.half),
  'u16_to_f16': lambda a: a.cast(dtypes.uint16).cast(dtypes.half),
  'bf16_to_f32': lambda a: (((a.cast(dtypes.uint32) if a.dtype != dtypes.uint32 else a) & _u32(0xFFFF)) << _u32(16)).bitcast(dtypes.float32),
  'isNAN': _isnan, 'isSignalNAN': lambda a: _check_nan(a, False),
  'isQuietNAN': lambda a: _check_nan(a, True), 'cvtToQuietNAN': _cvt_quiet,
  'isDENORM': _is_denorm, 'exponent': _exponent, 'divWouldBeDenorm': _div_would_be_denorm, 'sign': _sign,
  'signext_from_bit': _signext_from_bit, 'ldexp': _ldexp, 'frexp_mant': _frexp_mant, 'mantissa': _frexp_mant,
  'frexp_exp': _frexp_exp, 'trig_preop_result': _trig_preop,
  's_ff1_i32_b32': lambda a: _ff1(a, 32), 's_ff1_i32_b64': lambda a: _ff1(a, 64),
  # Normalization conversions: map [-1,1] or [0,1] to integer range
  # Use floor(x + 0.5) for round-to-nearest
  # SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior)
  'f16_to_snorm': lambda a: _floor(
    _f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
  'f16_to_unorm': lambda a: _floor(
    _f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
  'f32_to_snorm': lambda a: _floor(
    a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16),
  'f32_to_unorm': lambda a: _floor(
    a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16),
  'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8),
  # Integer truncation conversions
  'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16),
  'u32_to_u16': lambda a: a.cast(dtypes.uint32).cast(dtypes.uint16),
  'u16_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFFFF)),
  'u8_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFF)),
  'u4_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xF)),
  # Signed extraction with sign extension for dot products
  'i16_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFFFF)),
  'i8_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFF)),
  'i4_to_i32': lambda a: _signext_4bit(a.cast(dtypes.uint32) & _u32(0xF)),
  # Float to int16 conversions
  'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16),
  'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16),
  # SAD (Sum of Absolute Differences) - sum |a_i - b_i| for 4 bytes + accumulator
  'v_sad_u8': lambda a, b, c: _sad_u8(a, b, c),
  'v_msad_u8': lambda a, b, c: _sad_u8(a, b, c, masked=True),
  # Upper-case MIN: plain compare-and-select minimum
  'MIN': lambda a, b: (a < b).where(a, b),
  # System NOPs - these are scheduling hints, no effect on emulation
  's_nop': lambda a: _u32(0),
  # Address calculation for memory operations (arguments after the first two are ignored)
  'CalcDsAddr': lambda a, o, *r: a.cast(dtypes.uint32) + o.cast(dtypes.uint32),
  'CalcGlobalAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
  'CalcScratchAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64),
  # FP8/BF8/BF16 conversion functions
  'fp8_to_f32': _fp8_to_f32, 'bf8_to_f32': _bf8_to_f32, 'f32_to_fp8': _f32_to_fp8, 'f32_to_bf8': _f32_to_bf8,
  'f32_to_bf16': _f32_to_bf16, 'f32_to_bf16_SR': _f32_to_bf16_sr, 'f32_to_bf16_sr': _f32_to_bf16_sr,
}
# v_min/v_max and v_min3/v_max3 for f32 and the 16/32-bit integer types: the two- and three-operand
# forms share one variadic reducer, so each pair of names gets the same kind of builder.
for is_max, name in [(False, 'min'), (True, 'max')]:
  for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]:
    _FUNCS.update({
      f'v_{name}{arity}_{sfx}': (lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a))
      for arity in ('', '3')})
# f16 min/max/min3/max3 plus the *_num_* and *imum* variants for f16 and f32.
# f16 variants first pull the half out of each operand with _f16_extract.
for is_max, name in [(False, 'min'), (True, 'max')]:
  _FUNCS.update({
    f'v_{name}{stem}': (
      (lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a])) if stem.endswith('f16')
      else (lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a)))
    for stem in ('_f16', '3_f16', '_num_f16', '_num_f32', '3_num_f16', '3_num_f32',
                 'imum_f16', 'imum_f32', 'imum3_f16', 'imum3_f32')})
# ═══════════════════════════════════════════════════════════════════════════════
# TOKENIZER/PARSER
# ═══════════════════════════════════════════════════════════════════════════════
# pcode type suffix -> dtype used to represent it
DTYPES = {
  # 32-bit
  'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32,
  # 64-bit
  'u64': dtypes.uint64, 'i64': dtypes.int64, 'f64': dtypes.float64, 'b64': dtypes.uint64,
  # 16-bit
  'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16,
  # 8-bit
  'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8,
  # sub-byte and 8-bit float suffixes, stored in a wider integer dtype
  'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32,
  'fp8': dtypes.uint8, 'bf8': dtypes.uint8, 'b3': dtypes.uint8, 'b2': dtypes.uint8,
}
# bit width -> unsigned dtype of exactly that width
_BITS_DT = {
  8: dtypes.uint8,
  16: dtypes.uint16,
  32: dtypes.uint32,
  64: dtypes.uint64,
}
# Literal suffixes, longest first so e.g. 'ULL' is matched before 'LL' or 'L'
_NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f')
def _strip_suffix(num: str) -> tuple[str, str]:
  """Split a numeric literal into `(suffix, digits)`; suffix is '' when there is none.

  For hex literals a trailing 'F'/'f' is a digit, not a float suffix, so only the integer
  suffixes are stripped from them ('0xFF' stays '0xFF', '0xFFU' gives ('U', '0xFF')).
  """
  is_hex = num.lstrip('+-')[:2].lower() == '0x'
  for sfx in _NUM_SUFFIXES:
    if is_hex and sfx in ('F', 'f'): continue
    if num.endswith(sfx): return sfx, num[:-len(sfx)]
  return '', num
# Single-character punctuation -> token type emitted by tokenize()
_SINGLE_CHAR = {
  '(': 'LPAREN', ')': 'RPAREN',
  '[': 'LBRACKET', ']': 'RBRACKET',
  '{': 'LBRACE', '}': 'RBRACE',
  ':': 'COLON',
  ',': 'COMMA',
  '?': 'QUESTION',
  '.': 'DOT',
  '=': 'EQUALS',
  "'": 'QUOTE',
}
class Token:
  """A lexical token: `type` is the token kind (e.g. 'NUM', 'IDENT', 'OP') and `val` its source text."""
  __slots__ = ('type', 'val')

  def __init__(self, kind: str, val: str):
    self.type = kind
    self.val = val

  def __repr__(self):
    return '{}:{}'.format(self.type, self.val)
def tokenize(s: str) -> list[Token]:
  """Split a pcode expression string into tokens, ending with an 'EOF' token.

  Token kinds: 'ASSIGN_OP' (+=, -=), 'OP' (one- and two-character operators), the punctuation kinds
  from _SINGLE_CHAR, 'NUM' (decimal, hex or decimal-with-fraction, including any literal suffix) and
  'IDENT'. Whitespace and ';' are skipped.

  A leading '-' is always emitted as an 'OP' token, never folded into a 'NUM'.

  Raises:
    RuntimeError: on a character that starts no known token.
  """
  tokens: list[Token] = []
  i, n = 0, len(s)
  while i < n:
    c = s[i]
    # whitespace and statement terminators produce no token
    if c.isspace() or c == ';':
      i += 1
      continue
    # two-character tokens take priority over their one-character prefixes
    pair = s[i:i+2]
    if pair in ('+=', '-='):
      tokens.append(Token('ASSIGN_OP', pair))
      i += 2
      continue
    if pair in ('||', '&&', '>=', '<=', '==', '!=', '<>', '>>', '<<', '**', '+:', '-:'):
      tokens.append(Token('OP', pair))
      i += 2
      continue
    if c in '|^&><+-*/~!%':
      tokens.append(Token('OP', c))
      i += 1
      continue
    if (t := _SINGLE_CHAR.get(c)):
      tokens.append(Token(t, c))
      i += 1
      continue
    if c.isdigit():
      start = i
      if s.startswith(('0x', '0X'), i):
        i += 2
        while i < n and s[i] in '0123456789abcdefABCDEF': i += 1
      else:
        while i < n and s[i].isdigit(): i += 1
        # a '.' only belongs to the number when a digit follows; otherwise it is a DOT token
        if i + 1 < n and s[i] == '.' and s[i+1].isdigit():
          i += 1
          while i < n and s[i].isdigit(): i += 1
      # at most one literal suffix; _NUM_SUFFIXES lists longer suffixes first
      for sfx in _NUM_SUFFIXES:
        if s.startswith(sfx, i):
          i += len(sfx)
          break
      tokens.append(Token('NUM', s[start:i]))
      continue
    if c.isalpha() or c == '_':
      start = i
      while i < n and (s[i].isalnum() or s[i] == '_'): i += 1
      tokens.append(Token('IDENT', s[start:i]))
      continue
    raise RuntimeError(f"unexpected char '{c}' at pos {i} in: {s}")
  tokens.append(Token('EOF', ''))
  return tokens
class Parser:
def __init__(self, tokens: list[Token], env: dict, funcs: dict | None = None):
self.tokens, self.vars, self.funcs, self.pos = tokens, env, funcs if funcs is not None else _FUNCS, 0
def peek(self, offset=0) -> Token: return self.tokens[min(self.pos + offset, len(self.tokens) - 1)]
def at(self, *types) -> bool: return self.peek().type in types
def _advance(self) -> Token:
tok = self.tokens[self.pos]
self.pos += 1
return tok
def eat(self, kind: str) -> Token:
if self.peek().type != kind: raise RuntimeError(f"expected {kind}, got {self.peek()}")
return self._advance()
def try_eat(self, kind: str) -> Token | None: return self._advance() if self.peek().type == kind else None
def try_eat_val(self, val: str, kind: str) -> Token | None:
return self._advance() if self.peek().type == kind and self.peek().val == val else None
def eat_val(self, val: str, kind: str) -> Token:
if self.peek().type != kind or self.peek().val != val: raise RuntimeError(f"expected {kind}:{val}, got {self.peek()}")
return self._advance()
def parse(self) -> UOp:
  """Parse a full expression, including the ternary `cond ? a : b` form (both branches are full expressions)."""
  cond = self.binop(0)
  if not self.try_eat('QUESTION'): return cond
  then_val = self.parse()
  self.eat('COLON')
  else_val = self.parse()
  return _to_bool(cond).where(then_val, else_val)
def _apply_binop(self, left, right, op):
  """Build the UOp for `left op right` after bringing the operands to compatible dtypes.

  Logical and bitwise operators are coerced with _coerce_bitwise, comparisons and shifts with
  _coerce_cmp; for the remaining operators the right operand is cast to the left operand's dtype.
  Returns None for an operator it does not know.
  """
  if op in ('||', '&&', '|', '^', '&'):
    left, right = self._coerce_bitwise(left, right)
  elif op in ('>=', '<=', '>', '<', '==', '!=', '<>', '>>', '<<'):
    left, right = self._coerce_cmp(left, right)
  elif left.dtype != right.dtype:
    right = right.cast(left.dtype)

  # logical and bitwise forms map onto the same UOp operators
  if op in ('||', '|'): return left | right
  if op in ('&&', '&'): return left & right
  if op == '^': return left ^ right
  if op == '==': return left.eq(right)
  if op == '!=': return left.ne(right)
  if op in ('>=', '<=', '>', '<', '<>'):
    compare = {
      '>=': (lambda a, b: a >= b),
      '<=': (lambda a, b: a <= b),
      '>': (lambda a, b: a > b),
      '<': (lambda a, b: a < b),
      '<>': (lambda a, b: a.ne(b)),
    }[op]
    return self._cmp_nan(left, right, compare)
  if op == '>>': return left >> right
  if op == '<<': return left << right
  if op == '+': return left + right
  if op == '-':
    # constant - constant is folded in Python
    if left.op == Ops.CONST and right.op == Ops.CONST: return _const(left.dtype, left.arg - right.arg)
    return left - right
  if op == '/': return left / right
  if op == '*':
    # Integer promotion: promote 16-bit integers to 32-bit before multiply to avoid overflow
    # (e.g. SOPP branch offset: SIMM16.i16 * 16'4 can exceed int16 range)
    if left.dtype.itemsize == 2 and left.dtype in (dtypes.int16, dtypes.short, dtypes.uint16, dtypes.ushort):
      pdt = dtypes.int if left.dtype in (dtypes.int16, dtypes.short) else dtypes.uint
      left, right = left.cast(pdt), right.cast(pdt)
    return left * right
  if op == '**':
    # only a constant base of 2.0 is turned into EXP2; any other base is returned as-is
    if left.op == Ops.CONST and left.arg == 2.0:
      return UOp(Ops.EXP2, left.dtype, (right.cast(left.dtype),))
    return left
  return None
# Binary operator groups from lowest to highest precedence; operators in the same tuple share a level
_PREC = [('||',), ('&&',), ('|',), ('^',), ('&',), ('==', '!=', '<>'), ('>=', '<=', '>', '<'), ('>>', '<<'), ('+', '-'), ('*', '/'), ('**',)]
def binop(self, prec: int) -> UOp:
  """Parse a left-associative chain of binary operators at precedence level `prec` (an index into _PREC).

  Levels past the end of the table fall through to unary().
  """
  levels = self._PREC
  if prec >= len(levels): return self.unary()
  accepted = levels[prec]
  acc = self.binop(prec + 1)
  while True:
    tok = self.peek()
    if tok.type != 'OP' or tok.val not in accepted: return acc
    op = self._advance().val
    acc = self._apply_binop(acc, self.binop(prec + 1), op)
def unary(self) -> UOp:
  """Parse a prefix operator (`~`, `!`, `-`, `+`) applied to a unary expression, or fall through to postfix()."""
  tok = self.peek()
  if tok.type != 'OP' or tok.val not in ('~', '!', '-', '+'): return self.postfix()
  self._advance()
  operand = self.unary()
  if tok.val == '+': return operand
  if tok.val == '~':
    # bitwise NOT: XOR with an all-ones constant of the operand's width
    all_ones = (1 << (operand.dtype.itemsize * 8)) - 1
    return operand ^ _const(operand.dtype, all_ones)
  if tok.val == '!':
    return operand.eq(_const(operand.dtype, 0))
  # negation: constants are folded, and a uint32 constant becomes a signed int constant
  if operand.op == Ops.CONST:
    return _const(dtypes.int if operand.dtype == dtypes.uint32 else operand.dtype, -operand.arg)
  return operand.neg()
def postfix(self) -> UOp:
  """Parse a primary expression followed by any number of `.field`, `[...]` and `{...}` suffixes."""
  node = self.primary()
  while self.at('DOT', 'LBRACKET', 'LBRACE'):
    if self.try_eat('DOT'):
      node = self._handle_dot(node, self.eat('IDENT').val)
    elif self.at('LBRACKET'):
      node = self._handle_bracket(node)
    else:
      node = self._handle_brace_index(node)
  return node
def primary(self) -> UOp:
  """Parse one primary expression and return its UOp.

  Forms handled, in the order they are tried:
    - `( expr )`
    - `{ hi, lo }` brace concatenation
    - a number, or a sized literal when the number is followed by a quote
    - an identifier: `MEM[addr].type`, `VGPR[lane][reg]`, a function call `name(...)`, a named constant,
      `name{idx}` element access, `name[...]` on a name that is not in vars, or a plain variable

  Raises:
    RuntimeError: for an unknown variable or a token that cannot start a primary expression.
  """
  if self.try_eat('LPAREN'):
    e = self.parse()
    self.eat('RPAREN')
    return e
  if self.try_eat('LBRACE'):
    hi = self.parse()
    self.eat('COMMA')
    lo = self.parse()
    self.eat('RBRACE')
    # result dtype is twice the width of `lo` (uint64 if there is no such entry); `hi` lands above `lo`
    return (hi.cast(dt:=_BITS_DT.get((s:=lo.dtype.bitsize) * 2, dtypes.uint64)) << _const(dt, s)) | lo.cast(dt)
  if self.at('NUM'):
    num = self.eat('NUM').val
    if self.try_eat('QUOTE'):
      # NUM followed by a quote: the number is passed to _sized_literal with its suffix characters stripped
      return self._sized_literal(int(num.rstrip('ULlf')))
    return self._parse_number(num)
  if self.at('IDENT'):
    name = self.eat('IDENT').val
    if name == 'MEM':
      # MEM[addr].type -- an unknown type name falls back to uint32
      self.eat('LBRACKET')
      addr = self.parse()
      self.eat('RBRACKET')
      self.eat('DOT')
      dt_name = self.eat('IDENT').val
      return self._handle_mem_load(addr, DTYPES.get(dt_name, dtypes.uint32))
    if name == 'VGPR' and self.at('LBRACKET'):
      # VGPR[lane][reg]: load from the '_vgpr' buffer at index reg * wave_size + lane
      self.eat('LBRACKET')
      lane = self.parse()
      self.eat('RBRACKET')
      self.eat('LBRACKET')
      reg = self.parse()
      self.eat('RBRACKET')
      vgpr = self.vars.get('_vgpr')
      # without a '_vgpr' entry in vars the access evaluates to constant 0
      if vgpr is None: return _u32(0)
      ws = self.vars.get('_wave_size', 32)
      return vgpr.index(_to_u32(reg) * _u32(ws) + _to_u32(lane), ptr=True).load()
    if self.try_eat('LPAREN'):
      # function call: name(args...)
      args = self._parse_args()
      self.eat('RPAREN')
      return self._call_func(name, args)
    # named constants
    if name == 'PI': return _const(dtypes.float32, 3.141592653589793)
    if name == 'INF': return _const(dtypes.float64, float('inf'))
    if name == 'NAN': return _const(dtypes.uint32, 0x7FC00000).bitcast(dtypes.float32)
    if name == 'UNDERFLOW_F32': return _const(dtypes.uint32, 1).bitcast(dtypes.float32)
    if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32)
    if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64)
    if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64)
    if name == 'WAVE32': return _const(dtypes.bool, self.vars.get('_wave_size', 32) <= 32)
    if name == 'WAVE64': return _const(dtypes.bool, self.vars.get('_wave_size', 32) > 32)
    # NOTE(review): if a DOT follows WAVE_MODE but the field is not IEEE, the DOT has already been consumed
    # when control falls through to the checks below -- confirm this is intended
    if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1)
    if self.try_eat('LBRACE'):
      idx = self.eat('NUM').val
      self.eat('RBRACE')
      # Handle VGPR{lane}[reg] - 2D array access after loop unrolling
      if name == 'VGPR' and self.at('LBRACKET'):
        self.eat('LBRACKET')
        reg = self.parse()
        self.eat('RBRACKET')
        vgpr = self.vars.get('_vgpr')
        if vgpr is None: return _u32(0)
        ws = self.vars.get('_wave_size', 32)
        return vgpr.index(_to_u32(reg) * _u32(ws) + _u32(int(idx)), ptr=True).load()
      # name{idx}: look for a variable stored as 'name@idx', then as 'nameidx'
      elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}'))
      if elem is None:
        # Extract bit idx from base variable (like var[idx])
        base = self.vars.get(name)
        assert isinstance(base, UOp), f"unknown variable: {name}{idx}"
        dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
        elem = (base.cast(dt) >> _const(dt, int(idx))) & _const(dt, 1)
      if self.try_eat('DOT'):
        # name{idx}.type -- an unknown type name falls back to uint32
        dt_name = self.eat('IDENT').val
        return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32))
      if self.at('LBRACKET'):
        return self._handle_bracket(elem, name + idx)
      return elem
    if self.at('LBRACKET') and name not in self.vars:
      # name[...] on a name that is not a variable: index with a constant-0 base and pass the name along
      self.eat('LBRACKET')
      first = self.parse()
      return self._handle_bracket_rest(first, _u32(0), name)
    if name in self.vars:
      v = self.vars[name]
      assert isinstance(v, UOp), f"expected UOp for {name}, got {type(v)}"
      return v
    raise RuntimeError(f"unknown variable: {name}")
  raise RuntimeError(f"unexpected token in primary: {self.peek()}")
def _handle_dot(self, base, field: str) -> UOp:
  """Apply a `.field` suffix to `base`.

  Two forms are handled:
    * `.u64[laneId]`: returns bit `laneId` of `base` (shift right by the `laneId` variable, mask with 1). A further
      `.type` suffix, if present, casts that bit to the named dtype (uint32 when the name is not in DTYPES).
    * `.type`: converts `base` to the dtype registered for `field` in DTYPES. An unregistered field, or one that
      already matches `base.dtype`, returns `base` untouched.
  """
  assert isinstance(base, UOp), f"expected UOp for dot access, got {type(base)}"
  if field == 'u64' and self.at('LBRACKET') and self.peek(1).type == 'IDENT' and self.peek(1).val == 'laneId':
    self.eat('LBRACKET')
    self.eat_val('laneId', 'IDENT')
    self.eat('RBRACKET')
    lane = self.vars['laneId']
    # the shift amount has to share base's dtype
    amount = _to_u32(lane) if base.dtype == dtypes.uint32 else lane.cast(base.dtype)
    lane_bit = (base >> amount) & _const(base.dtype, 1)
    if not self.try_eat('DOT'): return lane_bit
    return lane_bit.cast(DTYPES.get(self.eat('IDENT').val, dtypes.uint32))
  target = DTYPES.get(field)
  if target is None or target == base.dtype: return base
  if target.itemsize == 2 and base.dtype.itemsize == 4:
    # 4-byte value viewed as a 2-byte type: keep the low 16 bits, then reinterpret them if needed
    low_half = (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16)
    return low_half if target == dtypes.uint16 else low_half.bitcast(target)
  return _signext_4bit(base) if field == 'i4' else _cast_to(base, target)
def _handle_bracket(self, base, var_name: str | None = None) -> UOp:
self.eat('LBRACKET')
return self._handle_bracket_rest(self.parse(), base, var_name)
def _handle_bracket_rest(self, first: UOp, base: UOp, var_name: str | None = None) -> UOp:
  """Finish parsing a `[...]` suffix whose `[` and first inner expression were already consumed.

  Args:
    first: the already-parsed expression that follows `[`.
    base: the value being indexed or sliced.
    var_name: name used to look up `name@idx` array elements and `name<N>` variables in `self.vars`; when None it is
      recovered from `base` through `_find_var_name` where it is needed.

  Handles, in this order: `[first +: width]` / `[first -: width]`, `[first : second]` slices (constant or dynamic
  bounds), and a single index `[first]` with an optional `.type` suffix.
  """
  if self.at('OP') and self.peek().val in ('+:', '-:'):
    self.eat('OP')
    width = self.parse()
    self.eat('RBRACKET')
    if width.op == Ops.CONST:
      w = int(width.arg)
      # NOTE(review): '+:' and '-:' take the same path (shift by `first`, keep `w` bits) - confirm this is intended
      return (base >> _to_u32(first)) & _const(base.dtype, (1 << w) - 1)
    # non-constant width: no slicing is applied, base is returned as is
    return base
  if self.try_eat('COLON'):
    second = self.parse()
    self.eat('RBRACKET')
    if first.op == Ops.CONST and second.op == Ops.CONST:
      a, b = int(first.arg), int(second.arg)
      # ascending bounds [lo:hi] are handed to _bitreverse over the (b - a + 1)-bit range
      if a < b: return _bitreverse(base, b - a + 1)
      hi, lo = a, b
      if lo >= base.dtype.itemsize * 8:
        # The slice starts beyond base's width: switch to the variable named `<name><lo // 32>` if it exists
        vn = var_name or self._find_var_name(base)
        if vn and f'{vn}{lo // 32}' in self.vars:
          base = self.vars[f'{vn}{lo // 32}']
          # NOTE(review): the new hi equals the slice's top bit only when lo is a multiple of 32 - confirm callers
          lo, hi = lo % 32, (hi % 32) + (lo % 32)
      return _extract_bits(base, hi, lo)
    # Dynamic bit slice: (base >> lo) & ((1 << (hi - lo + 1)) - 1)
    dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
    hi, lo = first.cast(dt), second.cast(dt)
    width = hi - lo + _const(dt, 1)
    mask = (_const(dt, 1) << width) - _const(dt, 1)
    return (base.cast(dt) >> lo) & mask
  self.eat('RBRACKET')
  # Optional `.type` after the closing bracket; names missing from DTYPES fall back to uint32
  dt_suffix = None
  if self.try_eat('DOT'):
    dt_suffix = DTYPES.get(self.eat('IDENT').val, dtypes.uint32)
  if var_name is None:
    var_name = self._find_var_name(base)
  if first.op == Ops.CONST:
    idx = int(first.arg)
    # Check for array element (var@idx)
    if var_name and f'{var_name}@{idx}' in self.vars:
      v = self.vars[f'{var_name}@{idx}']
      return _cast_to(v, dt_suffix) if dt_suffix else v
    # Bit extraction
    dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
    base_cast = base.cast(dt) if base.dtype != dt else base
    result = ((base_cast >> _const(dt, idx)) & _const(dt, 1))
    return _cast_to(result, dt_suffix) if dt_suffix else result
  if var_name:
    # Dynamic index into `name@i` elements: build a chain of WHEREs that selects the element whose i equals `first`.
    # Only indices 0..255 are probed, and the highest registered element doubles as the fallthrough value.
    idx_u32 = _to_u32(first)
    elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars]
    if elems:
      result = elems[-1][1]
      for ei, ev in reversed(elems[:-1]):
        # both WHERE arms need one dtype: same-size mismatch converts the accumulated result, otherwise the element
        if ev.dtype != result.dtype and ev.dtype.itemsize == result.dtype.itemsize: result = result.cast(ev.dtype)
        elif ev.dtype != result.dtype: ev = ev.cast(result.dtype)
        result = idx_u32.eq(_u32(ei)).where(ev, result)
      # NOTE(review): dt_suffix is not applied on this path - confirm that is intended
      return result
  # No registered elements: treat `[first]` as a dynamic single-bit read of base
  dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32
  base_cast = base.cast(dt) if base.dtype != dt else base
  result = (base_cast >> first.cast(dt)) & _const(dt, 1)
  return _cast_to(result, dt_suffix) if dt_suffix else result
def _handle_brace_index(self, base) -> UOp:
  """Parse a `{NUM}` suffix on `base` and return the `name@NUM` element registered in `self.vars`.

  The variable name is recovered from `base` with `_find_var_name`; when no name is found, or the element is not
  registered, a uint32 zero is used. A trailing `.type` casts the element (uint32 for names missing from DTYPES)
  and a trailing `[` is forwarded to `_handle_bracket`.
  """
  self.eat('LBRACE')
  idx = self.eat('NUM').val
  self.eat('RBRACE')
  owner = self._find_var_name(base)
  if not owner: return _u32(0)
  elem = self.vars.get(f'{owner}@{idx}', _u32(0))  # use @ to avoid collision with temps like A4
  if self.try_eat('DOT'):
    return _cast_to(elem, DTYPES.get(self.eat('IDENT').val, dtypes.uint32))
  return self._handle_bracket(elem) if self.at('LBRACKET') else elem
def _find_var_name(self, base: UOp) -> str | None:
  """Return the name under which `base` is known, or None.

  A DEFINE_VAR with a non-empty arg reports `arg[0]`; otherwise `self.vars` is searched for the entry that is the
  very same object as `base` (identity, not equality) and the first matching key is returned.
  """
  if base.op == Ops.DEFINE_VAR and base.arg: return base.arg[0]
  return next((key for key, candidate in self.vars.items() if isinstance(candidate, UOp) and candidate is base), None)
def _sized_literal(self, bits: int) -> UOp:
  """Parse what follows the `<bits>'` prefix of a sized literal or sized conversion.

  Accepted forms, tried in this order:
    * `U(expr)`, `I(expr)`, `F(expr)`, `B(expr)`, `BF(expr)`: convert `expr` to the dtype selected by the letter and
      `bits` (unlisted combinations use uint64 when bits > 32, else uint32). For `F` with an integer-typed operand the
      operand's bits are reinterpreted (bitcast) instead of value-converted.
    * `h<hex>`, `b<binary>`, `d<decimal>`: based literal; the digits may be glued to the letter inside one IDENT token
      or follow as a separate NUM/IDENT token.
    * `0x...`: hex literal.
    * an optionally negated number: hex, decimal integer, or a float when it contains a `.`.

  Args:
    bits: the bit width written before the quote.

  Returns:
    The constant or converted expression as a UOp.

  Raises:
    RuntimeError: when no digits follow a base letter, or the next token matches none of the forms above.
  """
  if self.at('IDENT') and self.peek().val in ('U', 'I', 'F', 'B', 'BF'):
    type_char = self.eat('IDENT').val
    self.eat('LPAREN')
    inner = self.parse()
    self.eat('RPAREN')
    dt = {('U',32): dtypes.uint32, ('U',64): dtypes.uint64, ('I',32): dtypes.int, ('I',64): dtypes.int64,
          ('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64,
          ('BF',16): dtypes.bfloat16,
          ('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32)
    if type_char == 'F' and inner.dtype in (dtypes.uint32, dtypes.uint64, dtypes.ulong, dtypes.int, dtypes.int64):
      # Integer bits reinterpreted as a float: first bring the operand to an unsigned integer of the float's own
      # width so the bitcast is size-preserving. A 2-byte float goes through uint16 (it used to go through uint64,
      # which made the following bitcast an 8-byte -> 2-byte mismatch).
      carrier = {2: dtypes.uint16, 4: dtypes.uint32}.get(dt.itemsize, dtypes.uint64)
      if inner.dtype.itemsize != dt.itemsize: inner = inner.cast(carrier)
      return inner.bitcast(dt)
    return inner.cast(dt)
  if self.at('IDENT'):
    ident = self.peek().val
    fmt = ident[0].lower()
    if fmt in ('h', 'b', 'd'):
      self.eat('IDENT')
      # digits are either the remainder of this identifier or the next token
      if len(ident) > 1: num = ident[1:]
      elif self.at('NUM'): num = self.eat('NUM').val
      elif self.at('IDENT'): num = self.eat('IDENT').val
      else: raise RuntimeError(f"expected number after {bits}'{fmt}")
      if fmt == 'h': val = int(num, 16)
      elif fmt == 'b': val = int(num, 2)
      else: val = int(num)
      return _const(_BITS_DT.get(bits, dtypes.uint32), val)
  if self.at('NUM') and self.peek().val.startswith('0x'):
    num = self.eat('NUM').val
    return _const(_BITS_DT.get(bits, dtypes.uint32), int(num, 16))
  if self.at('NUM') or (self.at('OP') and self.peek().val == '-'):
    neg = self.try_eat_val('-', 'OP') is not None
    suffix, num = _strip_suffix(self.eat('NUM').val)
    if num.startswith('0x'):
      # only reachable for a negated hex literal; the un-negated case returned above
      val = int(num, 16)
      if neg: val = -val
    elif '.' in num:
      fval = float(num)
      if neg: fval = -fval
      return _const({16: dtypes.half, 32: dtypes.float32, 64: dtypes.float64}.get(bits, dtypes.float32), fval)
    else:
      val = int(num)
      if neg: val = -val
    # integers are signed unless the literal carried a 'U' suffix (widths 1 and 8 are always unsigned)
    dt = {1: dtypes.uint32, 8: dtypes.uint8, 16: dtypes.int16 if 'U' not in suffix else dtypes.uint16,
          32: dtypes.int if 'U' not in suffix else dtypes.uint32, 64: dtypes.int64 if 'U' not in suffix else dtypes.uint64}.get(bits, dtypes.uint32)
    return _const(dt, val)
  raise RuntimeError(f"unexpected token after {bits}': {self.peek()}")
def _parse_number(self, num: str) -> UOp:
  """Turn the text of a numeric token into a constant.

  * `0x`/`0X` literals: uint64 when the text ends in ULL, LL or UL (any case), otherwise uint32.
  * text with a `.` or an `F`/`f` suffix: float32 with the suffix, float64 without.
  * other integers: uint64 if the suffix contains `L`, uint32 if it contains `U`, else int for negative values and
    uint32 for the rest.
  """
  if num.startswith(('0x', '0X')):
    wide = num.upper().endswith(('ULL', 'LL', 'UL'))
    return _const(dtypes.uint64 if wide else dtypes.uint32, int(num.rstrip('ULul'), 16))
  suffix, digits = _strip_suffix(num)
  is_f32 = suffix in ('F', 'f')
  if is_f32 or '.' in digits:
    return _const(dtypes.float32 if is_f32 else dtypes.float64, float(digits))
  val = int(digits)
  if 'L' in suffix: dt = dtypes.uint64  # covers L, LL and ULL
  elif 'U' in suffix: dt = dtypes.uint32
  else: dt = dtypes.int if val < 0 else dtypes.uint32
  return _const(dt, val)
def _parse_args(self) -> list[UOp]:
  """Parse a comma-separated argument list up to (not including) the closing paren; empty when `)` comes first."""
  parsed: list = []
  if self.at('RPAREN'): return parsed
  parsed.append(self.parse())
  while self.try_eat('COMMA'): parsed.append(self.parse())
  return parsed
def _call_func(self, name: str, args: list[UOp]) -> UOp:
  """Evaluate a call to `name` with already-parsed `args`.

  A `('lambda', params, body)` tuple stored in `self.vars` takes precedence: the body is evaluated with the
  parameters bound to `args` on top of a copy of the current variables. Bodies that contain `;`, a newline or the
  word `return` are run as a statement block and must yield a return value; anything else is a single expression.
  Otherwise the callable registered in `self.funcs` is invoked.

  Raises:
    RuntimeError: when `name` is neither a lambda variable nor a registered function.
  """
  entry = self.vars.get(name)
  if isinstance(entry, tuple) and entry[0] == 'lambda':
    _, params, body = entry
    scope = {**self.vars, **dict(zip(params, args))}
    if not (';' in body or '\n' in body or 'return' in body.lower()):
      return parse_expr(body, scope, self.funcs)
    # statement body: one statement per line, blank lines and `//` comment lines dropped
    stmts = []
    for raw in body.replace(';', '\n').split('\n'):
      stmt = raw.strip()
      if stmt and not stmt.startswith('//'): stmts.append(stmt)
    _, _, result = parse_block(stmts, 0, scope, self.funcs)
    assert result is not None, f"lambda {name} must return a value"
    return result
  if name in self.funcs:
    return self.funcs[name](*args)
  raise RuntimeError(f"unknown function: {name}")
def _handle_mem_load(self, addr: UOp, dt) -> UOp:
  """Build the expression that reads a value of type `dt` at byte address `addr`.

  The buffer is `self.vars['_vmem']` when that key exists, otherwise `self.vars['_lds']`. When `_active` is set it is
  passed as an extra argument to every `index` call. Two buffer layouts are supported: element base dtype uint8
  (the address is used directly as the element index and the value is assembled byte by byte) and anything else
  (the element index is `addr >> 2`; presumably 32-bit elements - TODO confirm).

  Args:
    addr: byte address.
    dt: requested value dtype; selects how many bytes are combined.
  """
  mem = self.vars.get('_vmem') if '_vmem' in self.vars else self.vars.get('_lds')
  assert mem is not None, "memory load requires _vmem or _lds"
  # address arithmetic is done in uint64 only when the address itself is uint64
  adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
  active = self.vars.get('_active')
  gate = (active,) if active is not None else ()
  byte_mem = mem.dtype.base == dtypes.uint8
  if byte_mem:
    idx = addr.cast(dtypes.int)
    if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
      # 8 bytes OR-ed together, byte i placed at bit i*8 (lowest address = least significant byte)
      val = _u32(0).cast(dtypes.uint64)
      for i in range(8): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint64) << _u64(i * 8))
    elif dt in (dtypes.uint8, dtypes.int8):
      val = mem.index(idx, *gate, ptr=True).load().cast(dt)
    elif dt in (dtypes.uint16, dtypes.int16, dtypes.short):
      lo = mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32)
      hi = mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32)
      val = (lo | (hi << _u32(8))).cast(dt)
    else:
      # every other dt: 4 bytes combined into a uint32
      val = _u32(0)
      for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8))
  else:
    idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int)
    val = mem.index(idx, *gate)
    if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
      # second element (address + 4) supplies the upper 32 bits
      idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)
      val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32))
    # byte: shift by (addr & 3) * 8 and keep 8 bits
    elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF)
    elif dt in (dtypes.uint16, dtypes.int16):
      # halfword: bit 1 of the address picks the lower or upper 16 bits
      val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF)
    else:
      # Handle unaligned 32-bit loads: combine two consecutive dwords and shift.
      # To avoid OOB at buffer boundaries for aligned loads, clamp idx_hi to idx (safe).
      # Use int64 for the WHERE to avoid 32-bit int overflow in C pointer arithmetic (addr can be >8GB).
      byte_off = (addr & _const(adt, 3)).cast(dtypes.uint32)
      is_unaligned = byte_off.ne(_u32(0))
      idx_native = (addr >> _const(adt, 2)).cast(dtypes.int64)
      idx_hi_native = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int64)
      safe_idx_hi = is_unaligned.where(idx_hi_native, idx_native)
      hi = mem.index(safe_idx_hi, *gate)
      combined = val.cast(dtypes.uint64) | (hi.cast(dtypes.uint64) << UOp.const(dtypes.uint64, 32))
      val = is_unaligned.where((combined >> (byte_off.cast(dtypes.uint64) * UOp.const(dtypes.uint64, 8))).cast(dtypes.uint32), val)
  return val
def _coerce_cmp(self, l: UOp, r: UOp) -> tuple[UOp, UOp]:
  """Give both comparison operands one dtype.

  The right operand is cast to the left operand's dtype, except when it is a negative `int` constant: then the left
  operand is cast to `int` instead.
  """
  if l.dtype == r.dtype: return l, r
  negative_int_const = r.dtype == dtypes.int and r.op == Ops.CONST and r.arg < 0
  return (l.cast(dtypes.int), r) if negative_int_const else (l, r.cast(l.dtype))
def _coerce_bitwise(self, l: UOp, r: UOp) -> tuple[UOp, UOp]:
  """Give both operands of a bitwise operator one dtype.

  Operands of equal byte size are both bitcast to the unsigned integer of that size (uint32 for 4 bytes, uint64 for
  8, the left dtype for any other size); operands of different sizes get the right one cast to the left dtype.
  """
  if l.dtype == r.dtype: return l, r
  if l.dtype.itemsize != r.dtype.itemsize: return l, r.cast(l.dtype)
  carrier = {4: dtypes.uint32, 8: dtypes.uint64}.get(l.dtype.itemsize, l.dtype)
  return l.bitcast(carrier), r.bitcast(carrier)
def _cmp_nan(self, l: UOp, r: UOp, fn) -> UOp:
  """Apply comparison `fn` to `l` and `r`; when `l` is float32/float64/half the result is forced false if either is NaN."""
  outcome = fn(l, r)
  if l.dtype not in (dtypes.float32, dtypes.float64, dtypes.half): return outcome
  return outcome & _isnan(l).logical_not() & _isnan(r).logical_not()
def _match_bracket(toks: list[Token], start: int) -> tuple[int, list[Token]]:
  """Walk from the LBRACKET at `toks[start]` to its matching RBRACKET.

  Returns `(end, inner)`: `end` is the index just past the matching RBRACKET (or `len(toks)` when the bracket is
  never closed) and `inner` holds the tokens in `toks[start+1:end-1]` with EOF tokens removed.
  """
  step = {'LBRACKET': 1, 'RBRACKET': -1}
  nesting, pos = 1, start + 1
  while nesting > 0 and pos < len(toks):
    nesting += step.get(toks[pos].type, 0)
    pos += 1
  inner = [tok for tok in toks[start + 1:pos - 1] if tok.type != 'EOF']
  return pos, inner
def _tok_str(toks: list[Token]) -> str:
  """Join the `val` of every non-EOF token with single spaces."""
  words = [tok.val for tok in toks if tok.type != 'EOF']
  return ' '.join(words)
def parse_tokens(toks: list[Token], env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
  """Run a fresh Parser over `toks` with variables `env` and functions `funcs`, returning what its `parse()` yields."""
  parser = Parser(toks, env, funcs)
  return parser.parse()
# Unified block parser for pcode
def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
  """Re-emit `line` token by token (single-space separated) with each IDENT equal to `loop_var` replaced by `val`."""
  pieces = []
  for tok in tokenize(line):
    if tok.type == 'EOF': continue
    is_loop_var = tok.type == 'IDENT' and tok.val == loop_var
    pieces.append(str(val) if is_loop_var else tok.val)
  return ' '.join(pieces)
def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
  """Return `old` with the `width`-bit field that starts at bit `offset` replaced by the low `width` bits of `val`.

  half/float32 inputs are first converted with `_val_to_bits`. The work is done in uint64 when `old` is a 64-bit
  integer or the field reaches past bit 31, and with uint32 constants otherwise.
  """
  if old.dtype in (dtypes.half, dtypes.float32): old = _val_to_bits(old)
  field = (1 << width) - 1
  if old.dtype in (dtypes.uint64, dtypes.int64) or offset + width > 32:
    mk, word_dt, all_ones = _u64, dtypes.uint64, 0xFFFFFFFFFFFFFFFF
    if old.dtype != dtypes.uint64: old = old.cast(dtypes.uint64)
  else:
    # 32-bit path: old keeps whatever dtype it has, only the constants and val are uint32
    mk, word_dt, all_ones = _u32, dtypes.uint32, 0xFFFFFFFF
  if val.dtype != word_dt: val = val.cast(word_dt)
  # clear the field in old (mask inverted via XOR with all ones), then OR in the masked, shifted value
  kept = old & (mk(field << offset) ^ mk(all_ones))
  return kept | ((val & mk(field)) << mk(offset))
def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str = ')') -> int:
  """Return the index of the `close_ch` that brings the nesting depth back to zero, scanning `s` from `start`.

  Scanning begins at `start` itself, so an `open_ch` located there is counted. `len(s)` is returned when no such
  closing character exists.
  """
  nesting = 0
  for offset, ch in enumerate(s[start:]):
    if ch == open_ch:
      nesting += 1
      continue
    if ch != close_ch: continue
    nesting -= 1
    if nesting == 0: return start + offset
  return len(s)
def parse_block(lines: list[str], start: int, env: dict[str, VarVal], funcs: dict | None = None,
                assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
  """Parse a block of pcode. Returns (next_line, block_assigns, return_value).
  If assigns list is provided, side effects (MEM/VGPR writes) are appended to it.

  Parsing starts at lines[start] and stops at the end of `lines`, at a block terminator (elsif/else/endif/endfor,
  which is left unconsumed) or at a `return` statement. `env` is updated in place with every assignment made in the
  block; `block_assigns` holds just the names assigned here. Lines whose first token is neither an identifier nor `{`
  are skipped. Each statement kind below is tried in order and the first handler that matches consumes the line."""
  if funcs is None: funcs = _FUNCS
  block_assigns: dict[str, VarVal] = {}
  i = start
  while i < len(lines):
    line = lines[i]
    toks = tokenize(line)
    if toks[0].type != 'IDENT' and toks[0].type != 'LBRACE':
      i += 1
      continue
    first = toks[0].val.lower() if toks[0].type == 'IDENT' else '{'
    # Block terminators
    if first in ('elsif', 'else', 'endif', 'endfor'): break
    # return expr (lambda bodies)
    if first == 'return':
      rest = line[line.lower().find('return') + 6:].strip()
      return i + 1, block_assigns, parse_expr(rest, env, funcs)
    # for loop
    if first == 'for':
      # Parse: for VAR in [SIZE']START : [SIZE']END do
      p = Parser(toks, env, funcs)
      p.eat_val('for', 'IDENT')
      loop_var = p.eat('IDENT').val
      p.eat_val('in', 'IDENT')
      def parse_bound():
        # drop an optional SIZE' prefix, then take a plain number or any expression that simplifies to a constant
        if p.at('NUM') and p.peek(1).type == 'QUOTE':
          p.eat('NUM')
          p.eat('QUOTE')
        if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
        expr = p.parse().simplify()
        assert expr.op == Ops.CONST, f"loop bound must be constant, got {expr}"
        return int(expr.arg)
      start_val = parse_bound()
      p.eat('COLON')
      end_val = parse_bound()
      # Collect body
      i += 1
      body_lines: list[str] = []
      depth = 1
      while i < len(lines) and depth > 0:
        btoks = tokenize(lines[i])
        if btoks[0].type == 'IDENT':
          if btoks[0].val.lower() == 'for': depth += 1
          elif btoks[0].val.lower() == 'endfor': depth -= 1
        if depth > 0: body_lines.append(lines[i])
        i += 1
      # Execute loop with break support: the loop is fully unrolled (END is inclusive). A boolean `_found_*` variable
      # records whether an earlier iteration hit `break`; once set, later iterations' assignments are masked out.
      has_break = any('break' in bl.lower() for bl in body_lines)
      found_var = f'_found_{id(body_lines)}' if has_break else None
      if found_var: env[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
      for loop_i in range(start_val, end_val + 1):
        # bare `break` lines are removed before the body is parsed; their effect is modelled through found_var below
        subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
        _, iter_assigns, _ = parse_block(subst_lines, 0, {**env, **block_assigns}, funcs, assigns)
        if has_break:
          assert found_var is not None
          found = block_assigns.get(found_var, env.get(found_var))
          assert isinstance(found, UOp)
          not_found = found.eq(_const(dtypes.bool, False))
          for var, val in iter_assigns.items():
            if var != found_var and isinstance(val, UOp):
              old = block_assigns.get(var, env.get(var, _u32(0)))
              if isinstance(old, UOp):
                # keep this iteration's value only while no earlier iteration has broken out
                block_assigns[var] = env[var] = not_found.where(
                  val, old.cast(val.dtype) if val.dtype != old.dtype and val.dtype.itemsize == old.dtype.itemsize else old)
          # the break condition is the first `if ... then` line that has a bare `break` somewhere after it
          for j, bl in enumerate(body_lines):
            bl_l = bl.strip().lower()
            if bl_l.startswith('if ') and bl_l.endswith(' then'):
              if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
                cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
                cond = _to_bool(parse_expr(cond_str, env, funcs))
                block_assigns[found_var] = env[found_var] = not_found.where(cond, found)
                break
        else:
          block_assigns.update(iter_assigns)
          env.update(iter_assigns)
      continue
    # declare
    if first == 'declare':
      # Initialize scalar declarations (skip arrays and env already passed as srcs)
      if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT':
        env.setdefault(toks[1].val, _u32(0))
      i += 1
      continue
    # lambda definition
    if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
      name = toks[0].val
      body_start = line[line.find('(', line.find('lambda')):]
      params_end = _find_paren_end(body_start) + 1
      params = [p.strip() for p in body_start[1:params_end-1].split(',') if p.strip()]
      rest = body_start[params_end:].strip()
      if rest.startswith('('):
        body_end = _find_paren_end(rest)
        if body_end < len(rest):  # found matching paren on same line
          body = rest[1:body_end].strip()
          i += 1
        else:  # multiline body
          # NOTE(review): depth starts at 1 and parens inside rest[1:] are not counted - confirm first lines never nest
          body_lines_lst, depth = [rest[1:]], 1
          i += 1
          while i < len(lines) and depth > 0:
            for j, ch in enumerate(lines[i]):
              if ch == '(': depth += 1
              elif ch == ')':
                depth -= 1
                if depth == 0:
                  body_lines_lst.append(lines[i][:j])
                  break
            else: body_lines_lst.append(lines[i])
            i += 1
          body = '\n'.join(body_lines_lst).strip()
        # lambdas are stored unevaluated; Parser._call_func evaluates the body at each call
        env[name] = ('lambda', params, body)
        continue
    # MEM assignment: MEM[addr].type (+|-)?= value
    if first == 'mem' and toks[1].type == 'LBRACKET':
      j, addr_toks = _match_bracket(toks, 1)
      addr = parse_tokens(addr_toks, env, funcs)
      if j < len(toks) and toks[j].type == 'DOT': j += 1
      dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
      dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
      compound_op = None
      if j < len(toks) and toks[j].type == 'ASSIGN_OP':
        compound_op = toks[j].val
        j += 1
      elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
      rhs = parse_tokens(toks[j:], env, funcs)
      if compound_op:
        # read-modify-write: read the current value (two elements for 64-bit types) and fold the rhs into it
        mem = env.get('_vmem') if '_vmem' in env else env.get('_lds')
        if isinstance(mem, UOp):
          adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
          idx = (addr >> _const(adt, 2)).cast(dtypes.int)
          old = mem.index(idx)
          if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
            old = old.cast(dtypes.uint64) | (mem.index(((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)).cast(dtypes.uint64) << _u64(32))
          rhs = (old + rhs) if compound_op == '+=' else (old - rhs)
      if assigns is not None: assigns.append((f'MEM[{_tok_str(addr_toks)}].{dt_name}', (addr, rhs)))
      i += 1
      continue
    # VGPR assignment: VGPR[lane][reg] = value or VGPR[lane][reg][hi:lo].type = { ... }
    if first == 'vgpr' and toks[1].type == 'LBRACKET':
      j, lane_toks = _match_bracket(toks, 1)
      if j < len(toks) and toks[j].type == 'LBRACKET':
        j, reg_toks = _match_bracket(toks, j)
        # Check for bit-slice: VGPR[lane][reg][hi:lo].type = value (read-modify-write)
        if j < len(toks) and toks[j].type == 'LBRACKET':
          j, slice_toks = _match_bracket(toks, j)
          slice_str = _tok_str(slice_toks)
          hi_str, lo_str = slice_str.split(':')
          # SECURITY NOTE: eval() runs the bound text as Python; acceptable only while pcode comes from trusted sources
          hi_val, lo_val = int(eval(hi_str.strip())), int(eval(lo_str.strip()))
          if j < len(toks) and toks[j].type == 'DOT': j += 2  # skip .type suffix
          if j < len(toks) and toks[j].type == 'EQUALS': j += 1
          ln = parse_tokens(lane_toks, env, funcs)
          rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
          ws = env.get('_wave_size', 32)
          # flat register-file index: reg * wave_size + lane
          vgpr_idx = _to_u32(rg) * _u32(ws) + _to_u32(ln)
          if assigns is not None:
            assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}][{hi_val}:{lo_val}]', (vgpr_idx, val, _u32(hi_val), _u32(lo_val))))
          i += 1
          continue
        if j < len(toks) and toks[j].type == 'DOT': j += 2  # skip .type suffix
        if j < len(toks) and toks[j].type == 'EQUALS': j += 1
        ln = parse_tokens(lane_toks, env, funcs)
        rg, val = parse_tokens(reg_toks, env, funcs), parse_tokens(toks[j:], env, funcs)
        if assigns is not None:
          ws = env.get('_wave_size', 32)
          assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(ws) + _to_u32(ln), val)))
        i += 1
        continue
    # Compound destination: {hi.type, lo.type} = value
    if first == '{':
      j = 1
      if j+2 < len(toks) and toks[j].type == 'IDENT' and toks[j+1].type == 'DOT':
        hi_var, hi_type = toks[j].val, toks[j+2].val
        j += 3
        if j < len(toks) and toks[j].type == 'COMMA': j += 1
        if j+2 < len(toks) and toks[j].type == 'IDENT' and toks[j+1].type == 'DOT':
          lo_var, lo_type = toks[j].val, toks[j+2].val
          j += 3
          if j < len(toks) and toks[j].type == 'RBRACE': j += 1
          if j < len(toks) and toks[j].type == 'EQUALS': j += 1
          val = parse_tokens(toks[j:], env, funcs)
          lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
          lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
          # the low part takes the bottom lo_bits of the value, the high part whatever lies above them
          lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
          hi_val = (val >> _const(val.dtype, lo_bits)).cast(hi_dt)
          block_assigns[lo_var] = env[lo_var] = lo_val
          block_assigns[hi_var] = env[hi_var] = hi_val
          if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
          i += 1
          continue
    # Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
    if len(toks) >= 5 and toks[0].type == 'IDENT' and \
       (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
      bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
      j = bracket_start
      colon_pos = None
      while j < len(toks) and toks[j].type != 'RBRACKET':
        if toks[j].type == 'COLON': colon_pos = j
        j += 1
      var = toks[0].val
      if colon_pos is not None:  # bit slice: var[hi:lo]
        hi_str = ' '.join(t.val for t in toks[bracket_start:colon_pos] if t.type != 'EOF')
        lo_str = ' '.join(t.val for t in toks[colon_pos+1:j] if t.type != 'EOF')
        try:
          # SECURITY NOTE: eval() runs the bound text as Python; acceptable only while pcode comes from trusted sources
          hi_val, lo_val = int(eval(hi_str)), int(eval(lo_str))
          hi, lo = max(hi_val, lo_val), min(hi_val, lo_val)
          j += 1
          if j < len(toks) and toks[j].type == 'DOT': j += 2
          if j < len(toks) and toks[j].type == 'EQUALS': j += 1
          val = parse_tokens(toks[j:], env, funcs)
          dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
          if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
          if var not in env: env[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
          old = block_assigns.get(var, env.get(var))
          assert isinstance(old, UOp)
          block_assigns[var] = env[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
          i += 1
          continue
        # NOTE(review): best-effort fallthrough to the handlers below; this also hides errors raised while parsing
        # the right-hand side, possibly after `assigns` was already appended to - consider narrowing
        except Exception: pass
      elif toks[1].type == 'LBRACKET':  # bit index: var[expr] (only for var[...], not var.type[...])
        existing = block_assigns.get(var, env.get(var))
        # only treat as a bit write when var is a plain value with no numbered siblings (var0..var7)
        if existing is not None and isinstance(existing, UOp) and \
           not any(f'{var}{k}' in env or f'{var}{k}' in block_assigns for k in range(8)):
          bit_toks = toks[2:j]
          j += 1
          while j < len(toks) and toks[j].type != 'EQUALS': j += 1
          if j < len(toks):
            block_assigns[var] = env[var] = _set_bit(
              existing, _to_u32(parse_tokens(bit_toks, env, funcs)), parse_tokens(toks[j+1:], env, funcs))
            i += 1
            continue
    # Array element: var[idx] = value (static index) or var[expr] = value (dynamic)
    if len(toks) >= 4 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
      var = toks[0].val
      j, idx_toks = _match_bracket(toks, 1)
      if j < len(toks) and toks[j].type == 'EQUALS':
        # Static index: var[NUM] = value
        if len(idx_toks) == 1 and idx_toks[0].type == 'NUM':
          idx = int(idx_toks[0].val.rstrip('UuLl'))
          val = parse_tokens(toks[j+1:], env, funcs)
          existing = block_assigns.get(var, env.get(var))
          if existing is not None and isinstance(existing, UOp):
            block_assigns[var] = env[var] = _set_bit(existing, _u32(idx), val)
          else:
            # no value named `var` exists: store the element under the `var@idx` key
            block_assigns[f'{var}@{idx}'] = env[f'{var}@{idx}'] = val
          i += 1
          continue
        # Dynamic index: var[expr] = value where var has @-elements
        elems = [(k.split('@')[1], v) for k, v in {**env, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
        if elems:
          idx_expr = parse_tokens(idx_toks, env, funcs)
          val = parse_tokens(toks[j+1:], env, funcs)
          # every element becomes `index == k ? value : old element`
          for elem_idx_str, old_elem in elems:
            elem_idx = int(elem_idx_str)
            cond = _to_u32(idx_expr).eq(_u32(elem_idx))
            new_val = cond.where(val.cast(old_elem.dtype) if val.dtype != old_elem.dtype else val, old_elem)
            block_assigns[f'{var}@{elem_idx}'] = env[f'{var}@{elem_idx}'] = new_val
          i += 1
          continue
    # Compound assignment: var += or var -=
    assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
    if assign_op is not None:
      var = toks[0].val
      old = block_assigns.get(var, env.get(var, _u32(0)))
      rhs = parse_tokens(toks[assign_op+1:], env, funcs)
      if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
      block_assigns[var] = env[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
      i += 1
      continue
    # Typed element: var.type[idx] = value
    if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and \
       toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
      var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
      dt = DTYPES.get(dt_name, dtypes.uint32)
      j = 6
      while j < len(toks) and toks[j].type != 'EQUALS': j += 1
      if j < len(toks):
        val, old = parse_tokens(toks[j+1:], env, funcs), block_assigns.get(var, env.get(var, _u32(0)))
        # element idx of width bw occupies bits [idx*bw, idx*bw + bw)
        bw = dt.itemsize * 8
        block_assigns[var] = env[var] = _set_bits(old, val, bw, idx * bw)
        if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
        i += 1
        continue
    # Dynamic bit: var.type[expr_with_brackets] = value
    if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and \
       toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
      j, depth, has_inner = 4, 1, False
      while j < len(toks) and depth > 0:
        if toks[j].type == 'LBRACKET':
          depth += 1
          has_inner = True
        elif toks[j].type == 'RBRACKET': depth -= 1
        j += 1
      if has_inner:
        var = toks[0].val
        bit_pos = _to_u32(parse_tokens(toks[4:j-1], env, funcs))
        while j < len(toks) and toks[j].type != 'EQUALS': j += 1
        if j < len(toks):
          val = parse_tokens(toks[j+1:], env, funcs)
          old = block_assigns.get(var, env.get(var, _u32(0)))
          block_assigns[var] = env[var] = _set_bit(old, bit_pos, val)
          i += 1
          continue
    # If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
    if first == 'if':
      def parse_cond(s, kw):
        # the condition is the text between the keyword and the last 'then'
        ll = s.lower()
        return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), env, funcs))
      def is_const(c, v): return c.op == Ops.CONST and c.arg is v
      cond = parse_cond(line, 'if')
      conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
      branch_assigns: list[tuple[UOp, list]] = []  # (cond, assigns_list) for side-effect merging
      else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
      else_side_effects: list = []
      # every branch is parsed against the same starting env; it is restored from this snapshot after each one
      env_snap = dict(env)
      static_true = is_const(cond, True)  # track if any condition is statically true
      i += 1
      if_side: list = []
      i, branch, ret = parse_block(lines, i, env, funcs, if_side if assigns is not None and not is_const(cond, False) else None)
      if conditions: conditions[0] = (cond, ret if ret is not None else branch)
      if assigns is not None and not is_const(cond, False): branch_assigns.append((cond, if_side))
      env.clear()
      env.update(env_snap)
      while i < len(lines):
        ltoks = tokenize(lines[i])
        if ltoks[0].type != 'IDENT': break
        lf = ltoks[0].val.lower()
        if lf == 'elsif':
          c = parse_cond(lines[i], 'elsif')
          # a branch is kept only if no earlier condition was statically true and its own is not statically false
          take = not static_true and not is_const(c, False)
          i += 1
          br_side: list = []
          i, branch, ret = parse_block(lines, i, env, funcs, br_side if assigns is not None and take else None)
          if take:
            conditions.append((c, ret if ret is not None else branch))
            if is_const(c, True): static_true = True
            if assigns is not None: branch_assigns.append((c, br_side))
          env.clear()
          env.update(env_snap)
        elif lf == 'else':
          i += 1
          el_side: list = []
          i, branch, ret = parse_block(lines, i, env, funcs, el_side if assigns is not None and not static_true else None)
          if not static_true:
            else_branch = (ret, branch)
            if assigns is not None: else_side_effects = el_side
          env.clear()
          env.update(env_snap)
        elif lf == 'endif':
          i += 1
          break
        else: break
      # Check if any branch returned a value (lambda-style)
      if any(isinstance(br, UOp) for _, br in conditions):
        result = else_branch[0]
        # fold from the last condition to the first so earlier conditions take priority
        for c, rv in reversed(conditions):
          if isinstance(rv, UOp) and isinstance(result, UOp):
            if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
            result = c.where(rv, result)
        return i, block_assigns, result
      # If statically true, use that branch directly; otherwise merge with WHERE
      if static_true:
        ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
        block_assigns.update(ba)
        env.update(ba)
        # For static true, forward side effects unconditionally
        if assigns is not None:
          for bc, bse in branch_assigns:
            if is_const(bc, True): assigns.extend(bse)
      else:
        else_assigns = else_branch[1]
        all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
        for var in all_vars:
          # start from the else value (or the value before the if) and wrap one WHERE per branch that assigns var
          res: Any = else_assigns.get(var, block_assigns.get(var, env.get(var, _u32(0))))
          for cond, ba in reversed(conditions):  # type: ignore[assignment]
            if isinstance(ba, dict) and var in ba:
              tv = ba[var]
              if isinstance(tv, UOp) and isinstance(res, UOp):
                res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
          block_assigns[var] = env[var] = res
        # Merge side effects from branches with conditions
        if assigns is not None:
          def _cond_side_effect(cnd, dest, val):
            if isinstance(val, tuple) and len(val) == 4:  # VGPR bit-slice: (idx, rhs, hi, lo) -> add condition
              return (dest, (val[0], val[1], val[2], val[3], cnd))
            if isinstance(val, tuple) and len(val) == 2:  # VGPR/MEM write: (addr, rhs) -> condition rhs
              # NOTE(review): both WHERE arms are val[1], so the value does not depend on cnd - confirm intended
              return (dest, (val[0], cnd.where(val[1], val[1])))
            return (dest, val)
          # Build combined condition: each branch fires when its cond is true AND no earlier cond was true
          remaining = UOp.const(dtypes.bool, True)
          for bc, bse in branch_assigns:
            effective = remaining & bc if remaining.op != Ops.CONST else bc
            for dest, val in bse: assigns.append(_cond_side_effect(effective, dest, val))
            remaining = remaining & bc.logical_not() if remaining.op != Ops.CONST else bc.logical_not()
          for dest, val in else_side_effects: assigns.append(_cond_side_effect(remaining, dest, val))
      continue
    # Regular assignment: var = value
    for j, t in enumerate(toks):
      if t.type == 'EQUALS':
        # An '=' that follows a comparison-like OP token is not an assignment; such a line is skipped
        if not any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)):
          base_var = toks[0].val
          block_assigns[base_var] = env[base_var] = parse_tokens(toks[j+1:], env, funcs)
        break
    # Always move on to the next line, assigned or not; otherwise a rejected '=' line would be re-parsed forever
    i += 1
  return i, block_assigns, None
def parse_expr(expr: str, env: dict[str, VarVal], funcs: dict | None = None) -> UOp:
  """Tokenize `expr` (surrounding whitespace and trailing `;` characters removed first) and parse it as one expression."""
  source = expr.strip().rstrip(';')
  return parse_tokens(tokenize(source), env, funcs)