# Tokenizer-based expression parser for AMD pcode from typing import Any, Callable from tinygrad.dtype import dtypes from tinygrad.uop.ops import Ops, UOp # Type alias for vars dict: stores UOps for variables and tuples for lambda definitions VarVal = UOp | tuple[str, list[str], str] def _const(dt, v): return UOp.const(dt, v) def _u32(v): return _const(dtypes.uint32, v) def _u64(v): return _const(dtypes.uint64, v) def _to_u32(v): return v if v.dtype == dtypes.uint32 else v.bitcast(dtypes.uint32) if v.dtype.itemsize == 4 else v.cast(dtypes.uint32) def _to_bool(v): return v if v.dtype == dtypes.bool else v.ne(_const(v.dtype, 0)) def _cast_to(v, dt): if v.dtype == dt: return v if dt == dtypes.half: return v.cast(dtypes.uint16).bitcast(dtypes.half) return v.cast(dt) if dt.itemsize != v.dtype.itemsize else v.bitcast(dt) # Float bit extraction - returns (bits, exp_mask, mant_mask, quiet_bit, exp_shift) based on float type def _float_info(v: UOp) -> tuple[UOp, UOp, UOp, UOp, int]: if v.dtype in (dtypes.float64, dtypes.uint64): bits = v.bitcast(dtypes.uint64) if v.dtype == dtypes.float64 else v.cast(dtypes.uint64) return bits, _u64(0x7FF0000000000000), _u64(0x000FFFFFFFFFFFFF), _u64(0x0008000000000000), 52 if v.dtype in (dtypes.half, dtypes.uint16): bits = (v.bitcast(dtypes.uint16) if v.dtype == dtypes.half else (v & _u32(0xFFFF)).cast(dtypes.uint16)).cast(dtypes.uint32) return bits, _u32(0x7C00), _u32(0x03FF), _u32(0x0200), 10 bits = v.bitcast(dtypes.uint32) if v.dtype == dtypes.float32 else v.cast(dtypes.uint32) return bits, _u32(0x7F800000), _u32(0x007FFFFF), _u32(0x00400000), 23 def _isnan(v: UOp) -> UOp: bits, exp_m, mant_m, _, _ = _float_info(v.cast(dtypes.float32) if v.dtype == dtypes.half else v) return (bits & exp_m).eq(exp_m) & (bits & mant_m).ne(_const(bits.dtype, 0)) def _bitreverse(v: UOp, bits: int) -> UOp: dt, masks = (dtypes.uint64, [(0x5555555555555555,1),(0x3333333333333333,2),(0x0F0F0F0F0F0F0F0F,4),(0x00FF00FF00FF00FF,8),(0x0000FFFF0000FFFF,16)]) \ if 
bits == 64 else (dtypes.uint32, [(0x55555555,1),(0x33333333,2),(0x0F0F0F0F,4),(0x00FF00FF,8)]) v = v.cast(dt) if v.dtype != dt else v for m, s in masks: v = ((v >> _const(dt, s)) & _const(dt, m)) | ((v & _const(dt, m)) << _const(dt, s)) return (v >> _const(dt, 32 if bits == 64 else 16)) | (v << _const(dt, 32 if bits == 64 else 16)) def _extract_bits(val: UOp, hi: int, lo: int) -> UOp: dt = dtypes.uint64 if val.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 return ((val >> _const(dt, lo)) if lo > 0 else val) & _const(val.dtype, (1 << (hi - lo + 1)) - 1) def _set_bit(old, pos, val): mask = _u32(1) << pos return (old & (mask ^ _u32(0xFFFFFFFF))) | ((val.cast(dtypes.uint32) & _u32(1)) << pos) def _val_to_bits(val): if val.dtype == dtypes.half: return val.bitcast(dtypes.uint16).cast(dtypes.uint32) if val.dtype == dtypes.float32: return val.bitcast(dtypes.uint32) if val.dtype == dtypes.float64: return val.bitcast(dtypes.uint64) return val if val.dtype == dtypes.uint32 else val.cast(dtypes.uint32) def _floor(x): t = UOp(Ops.TRUNC, x.dtype, (x,)); return ((x < _const(x.dtype, 0)) & x.ne(t)).where(t - _const(x.dtype, 1), t) def _f16_extract(v): return (v & _u32(0xFFFF)).cast(dtypes.uint16).bitcast(dtypes.half) if v.dtype == dtypes.uint32 else v def _check_nan(v: UOp, quiet: bool) -> UOp: if v.op == Ops.CAST and v.dtype == dtypes.float64: v = v.src[0] bits, exp_m, mant_m, qb, _ = _float_info(v) is_nan_exp, has_mant, is_q = (bits & exp_m).eq(exp_m), (bits & mant_m).ne(_const(bits.dtype, 0)), (bits & qb).ne(_const(bits.dtype, 0)) return (is_nan_exp & is_q) if quiet else (is_nan_exp & has_mant & is_q.logical_not()) def _minmax_reduce(is_max: bool, dt, *args: UOp) -> UOp: def cast(v: UOp) -> UOp: return v.bitcast(dt) if dt == dtypes.float32 and v.dtype == dtypes.uint32 else v.cast(dt) def minmax(a: UOp, b: UOp) -> UOp: if dt in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64): return (a > b).where(a, b) if is_max else (a < b).where(a, b) return a.maximum(b) 
if is_max else a.minimum(b) result = cast(args[0]) for a in args[1:]: b = cast(a) if dt == dtypes.float32: result = _isnan(result).where(b, _isnan(b).where(result, minmax(result, b))) else: result = minmax(result, b) return result def _find_two_pi_mul(x): if x.op != Ops.MUL or len(x.src) != 2: return None for i, s in enumerate(x.src): if s.op == Ops.CONST and abs(s.arg - 6.283185307179586) < 1e-5: return (x.src[1-i], 6.283185307179586) if s.op == Ops.MUL and len(s.src) == 2: vals = [ss.arg for ss in s.src if ss.op == Ops.CONST] + [ss.src[0].arg for ss in s.src if ss.op == Ops.CAST and ss.src[0].op == Ops.CONST] if len(vals) == 2 and abs(vals[0] * vals[1] - 6.283185307179586) < 1e-5: return (x.src[1-i], vals[0] * vals[1]) return None def _trig_reduce(x, phase=0.0): match = _find_two_pi_mul(x) if match is not None: turns, two_pi = match if phase: turns = turns + _const(turns.dtype, phase) n = _floor(turns + _const(turns.dtype, 0.5)) return UOp(Ops.SIN, turns.dtype, ((turns - n) * _const(turns.dtype, two_pi),)) if phase: x = x + _const(x.dtype, phase * 6.283185307179586) n = _floor(x * _const(x.dtype, 0.15915494309189535) + _const(x.dtype, 0.5)) return UOp(Ops.SIN, x.dtype, (x - n * _const(x.dtype, 6.283185307179586),)) def _signext(val: UOp) -> UOp: for bits, mask, ext in [(4, 0xF, 0xFFFFFFF0), (8, 0xFF, 0xFFFFFF00), (16, 0xFFFF, 0xFFFF0000)]: if (val.op == Ops.AND and len(val.src) == 2 and val.src[1].op == Ops.CONST and val.src[1].arg == mask) or val.dtype.itemsize == bits // 8: v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val sb = (v32 >> _u32(bits - 1)) & _u32(1) return sb.ne(_u32(0)).where(v32 | _u32(ext), v32).cast(dtypes.int) return val.cast(dtypes.int64) if val.dtype in (dtypes.int, dtypes.int32) else val def _signext_4bit(val: UOp) -> UOp: """Sign extend a 4-bit value to 32-bit signed integer.""" v32 = val.cast(dtypes.uint32) if val.dtype != dtypes.uint32 else val sb = (v32 >> _u32(3)) & _u32(1) # sign bit at position 3 return 
sb.ne(_u32(0)).where(v32 | _u32(0xFFFFFFF0), v32).bitcast(dtypes.int) def _abs(val: UOp) -> UOp: if val.dtype not in (dtypes.float32, dtypes.float64, dtypes.half): return val _, _, _, _, shift = _float_info(val) sign_mask = {10: 0x7FFF, 23: 0x7FFFFFFF, 52: 0x7FFFFFFFFFFFFFFF}[shift] bt, ft = {10: (dtypes.uint16, dtypes.half), 23: (dtypes.uint32, dtypes.float32), 52: (dtypes.uint64, dtypes.float64)}[shift] return (val.bitcast(bt) & _const(bt, sign_mask)).bitcast(ft) def _f_to_u(f, dt): return UOp(Ops.TRUNC, f.dtype, ((f < _const(f.dtype, 0.0)).where(_const(f.dtype, 0.0), f),)).cast(dt) def _cvt_quiet(val: UOp) -> UOp: bits, _, _, qb, _ = _float_info(val) bt, ft = (dtypes.uint64, dtypes.float64) if val.dtype == dtypes.float64 else (dtypes.uint16, dtypes.half) if val.dtype == dtypes.half else (dtypes.uint32, dtypes.float32) return (val.bitcast(bt) | qb).bitcast(ft) def _is_denorm(val: UOp) -> UOp: bits, exp_m, mant_m, _, _ = _float_info(val) return (bits & exp_m).eq(_const(bits.dtype, 0)) & (bits & mant_m).ne(_const(bits.dtype, 0)) _EXP_BITS = {10: 0x1F, 23: 0xFF, 52: 0x7FF} def _get_exp(bits: UOp, shift: int) -> UOp: return ((bits >> _const(bits.dtype, shift)) & _const(bits.dtype, _EXP_BITS[shift])).cast(dtypes.int) def _exponent(val: UOp) -> UOp: bits, _, _, _, shift = _float_info(val) return _get_exp(bits, shift) def _div_would_be_denorm(a: UOp, b: UOp) -> UOp: bits_n, _, _, _, shift = _float_info(a) bits_d, _, _, _, _ = _float_info(b) min_exp = {10: -14, 23: -126, 52: -1022}[shift] return (_get_exp(bits_n, shift) - _get_exp(bits_d, shift)) < _const(dtypes.int, min_exp) def _sign(val: UOp) -> UOp: bits, _, _, _, shift = _float_info(val) sign_shift = {10: 15, 23: 31, 52: 63}[shift] return ((bits >> _const(bits.dtype, sign_shift)) & _const(bits.dtype, 1)).cast(dtypes.uint32) def _signext_from_bit(val: UOp, w: UOp) -> UOp: is_64bit = val.dtype in (dtypes.uint64, dtypes.int64) dt = dtypes.uint64 if is_64bit else dtypes.uint32 mask_all = _const(dt, 0xFFFFFFFFFFFFFFFF if 
is_64bit else 0xFFFFFFFF) one = _const(dt, 1) val_u = val.cast(dt) if val.dtype != dt else val w_val = w.cast(dt) if w.dtype != dt else w sign_bit = (val_u >> (w_val - one)) & one ext_mask = ((one << w_val) - one) ^ mask_all return sign_bit.ne(_const(dt, 0)).where(val_u | ext_mask, val_u) def _ldexp(val: UOp, exp: UOp) -> UOp: if val.dtype == dtypes.uint32: val = val.bitcast(dtypes.float32) elif val.dtype == dtypes.uint64: val = val.bitcast(dtypes.float64) if exp.dtype in (dtypes.uint32, dtypes.uint64): exp = exp.cast(dtypes.int if exp.dtype == dtypes.uint32 else dtypes.int64) return val * UOp(Ops.EXP2, val.dtype, (exp.cast(val.dtype),)) def _frexp_mant(val: UOp) -> UOp: val = val.bitcast(dtypes.float32) if val.dtype == dtypes.uint32 else val.bitcast(dtypes.float64) if val.dtype == dtypes.uint64 else val if val.dtype == dtypes.float32: return ((val.bitcast(dtypes.uint32) & _u32(0x807FFFFF)) | _u32(0x3f000000)).bitcast(dtypes.float32) return ((val.bitcast(dtypes.uint64) & _const(dtypes.uint64, 0x800FFFFFFFFFFFFF)) | _const(dtypes.uint64, 0x3fe0000000000000)).bitcast(dtypes.float64) def _frexp_exp(val: UOp) -> UOp: val = val.bitcast(dtypes.float32) if val.dtype == dtypes.uint32 else val.bitcast(dtypes.float64) if val.dtype == dtypes.uint64 else val if val.dtype == dtypes.float32: return ((val.bitcast(dtypes.uint32) >> _u32(23)) & _u32(0xFF)).cast(dtypes.int) - _const(dtypes.int, 126) return ((val.bitcast(dtypes.uint64) >> _const(dtypes.uint64, 52)) & _const(dtypes.uint64, 0x7FF)).cast(dtypes.int) - _const(dtypes.int, 1022) TWO_OVER_PI = 0x0145f306dc9c882a53f84eafa3ea69bb81b6c52b3278872083fca2c757bd778ac36e48dc74849ba5c00c925dd413a32439fc3bd63962534e7dd1046bea5d768909d338e04d68befc827323ac7306a673e93908bf177bf250763ff12fffbc0b301fde5e2316b414da3eda6cfd9e4f96136e9e8c7ecd3cbfd45aea4f758fd7cbe2f67a0e73ef14a525d4d7f6bf623f1aba10ac06608df8f6 # TWO_OVER_PI as 19 u64 words for trig_preop_result (word[0] = bits 0-63, word[18] = bits 1152-1200) _PREOP_WORDS = 
tuple((TWO_OVER_PI >> (64 * i)) & 0xFFFFFFFFFFFFFFFF for i in range(19)) def _trig_preop(val: UOp) -> UOp: # Extract 53 bits from position (1148 - shift) in the 1201-bit 2/PI constant # Using word-based selection: 19 conditions instead of 1149 shift = val.cast(dtypes.uint32) bit_pos = _u32(1148) - shift # starting bit position from LSB word_idx = bit_pos >> _u32(6) # // 64 bit_off = bit_pos & _u32(63) # % 64 # Select lo_word and hi_word using shared conditions lo_word, hi_word = _u64(_PREOP_WORDS[18]), _u64(0) for i in range(17, -1, -1): cond = word_idx.eq(_u32(i)) lo_word = cond.where(_u64(_PREOP_WORDS[i]), lo_word) hi_word = cond.where(_u64(_PREOP_WORDS[i + 1]), hi_word) # Combine and extract 53 bits: ((lo >> bit_off) | (hi << (64 - bit_off))) & mask bit_off_64 = bit_off.cast(dtypes.uint64) result = ((lo_word >> bit_off_64) | (hi_word << (_u64(64) - bit_off_64))) & _u64(0x1fffffffffffff) return result.cast(dtypes.float64) def _ff1(val: UOp, bits: int) -> UOp: dt = dtypes.uint64 if bits == 64 else dtypes.uint32 val = val.cast(dt) if val.dtype != dt else val result = _const(dtypes.int, -1) for i in range(bits): cond = ((val >> _const(dt, i)) & _const(dt, 1)).ne(_const(dt, 0)) & result.eq(_const(dtypes.int, -1)) result = cond.where(_const(dtypes.int, i), result) return result def _sad_u8(a: UOp, b: UOp, acc: UOp, masked: bool = False) -> UOp: """Sum of absolute differences of 4 unsigned bytes + accumulator. 
If masked, skips bytes where a == 0.""" a, b, acc = a.cast(dtypes.uint32), b.cast(dtypes.uint32), acc.cast(dtypes.uint32) result = acc for i in range(4): a_byte = (a >> _u32(i * 8)) & _u32(0xFF) b_byte = (b >> _u32(i * 8)) & _u32(0xFF) diff = (a_byte > b_byte).where(a_byte - b_byte, b_byte - a_byte) result = result + (a_byte.ne(_u32(0)).where(diff, _u32(0)) if masked else diff) return result _FUNCS: dict[str, Callable[..., UOp]] = { 'sqrt': lambda a: UOp(Ops.SQRT, a.dtype, (a,)), 'trunc': lambda a: UOp(Ops.TRUNC, a.dtype, (a,)), 'log2': lambda a: UOp(Ops.LOG2, a.dtype, (a,)), 'sin': lambda a: _trig_reduce(a), 'cos': lambda a: _trig_reduce(a, 0.25), 'floor': _floor, 'fract': lambda a: a - _floor(a), 'signext': _signext, 'abs': _abs, 'isEven': lambda a: (UOp(Ops.TRUNC, a.dtype, (a,)).cast(dtypes.int) & _const(dtypes.int, 1)).eq(_const(dtypes.int, 0)), 'max': lambda a, b: UOp(Ops.MAX, a.dtype, (a, b)), 'min': lambda a, b: UOp(Ops.MAX, a.dtype, (a.neg(), b.neg())).neg(), 'pow': lambda a, b: UOp(Ops.EXP2, dtypes.float32, (b.bitcast(dtypes.float32),)), 'fma': lambda a, b, c: a * b + c, 'i32_to_f32': lambda a: a.cast(dtypes.int).cast(dtypes.float32), 'u32_to_f32': lambda a: a.cast(dtypes.uint32).cast(dtypes.float32), 'f32_to_i32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int), 'f32_to_u32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint32), 'f64_to_i32': lambda a: UOp(Ops.TRUNC, dtypes.float64, (a.bitcast(dtypes.float64),)).cast(dtypes.int), 'f64_to_u32': lambda a: _f_to_u(a.bitcast(dtypes.float64), dtypes.uint32), 'f16_to_f32': lambda a: _f16_extract(a).cast(dtypes.float32), 'f32_to_f16': lambda a: a.cast(dtypes.half), 'f32_to_f64': lambda a: a.bitcast(dtypes.float32).cast(dtypes.float64), 'f64_to_f32': lambda a: a.bitcast(dtypes.float64).cast(dtypes.float32), 'i32_to_f64': lambda a: a.cast(dtypes.int).cast(dtypes.float64), 'u32_to_f64': lambda a: a.cast(dtypes.uint32).cast(dtypes.float64), 'f16_to_i16': lambda a: 
UOp(Ops.TRUNC, dtypes.half, (_f16_extract(a),)).cast(dtypes.int16), 'f16_to_u16': lambda a: UOp(Ops.TRUNC, dtypes.half, (_f16_extract(a),)).cast(dtypes.uint16), 'i16_to_f16': lambda a: a.cast(dtypes.int16).cast(dtypes.half), 'u16_to_f16': lambda a: a.cast(dtypes.uint16).cast(dtypes.half), 'bf16_to_f32': lambda a: (((a.cast(dtypes.uint32) if a.dtype != dtypes.uint32 else a) & _u32(0xFFFF)) << _u32(16)).bitcast(dtypes.float32), 'isNAN': _isnan, 'isSignalNAN': lambda a: _check_nan(a, False), 'isQuietNAN': lambda a: _check_nan(a, True), 'cvtToQuietNAN': _cvt_quiet, 'isDENORM': _is_denorm, 'exponent': _exponent, 'divWouldBeDenorm': _div_would_be_denorm, 'sign': _sign, 'signext_from_bit': _signext_from_bit, 'ldexp': _ldexp, 'frexp_mant': _frexp_mant, 'mantissa': _frexp_mant, 'frexp_exp': _frexp_exp, 'trig_preop_result': _trig_preop, 's_ff1_i32_b32': lambda a: _ff1(a, 32), 's_ff1_i32_b64': lambda a: _ff1(a, 64), # Normalization conversions: map [-1,1] or [0,1] to integer range # Use floor(x + 0.5) for round-to-nearest # SNORM: round(value * 32767), range is [-32767, 32767] (hardware behavior) 'f16_to_snorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16), 'f16_to_unorm': lambda a: _floor(_f16_extract(a).cast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16), 'f32_to_snorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 32767) + _const(dtypes.float32, 0.5)).cast(dtypes.int).cast(dtypes.int16), 'f32_to_unorm': lambda a: _floor(a.bitcast(dtypes.float32) * _const(dtypes.float32, 65535) + _const(dtypes.float32, 0.5)).cast(dtypes.uint16), 'f32_to_u8': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint8), # Integer truncation conversions 'i32_to_i16': lambda a: a.cast(dtypes.int).cast(dtypes.int16), 'u32_to_u16': lambda a: a.cast(dtypes.uint32).cast(dtypes.uint16), 'u16_to_u32': lambda a: 
(a.cast(dtypes.uint32) & _u32(0xFFFF)), 'u8_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xFF)), 'u4_to_u32': lambda a: (a.cast(dtypes.uint32) & _u32(0xF)), # Signed extraction with sign extension for dot products 'i16_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFFFF)), 'i8_to_i32': lambda a: _signext(a.cast(dtypes.uint32) & _u32(0xFF)), 'i4_to_i32': lambda a: _signext_4bit(a.cast(dtypes.uint32) & _u32(0xF)), # Float to int16 conversions 'v_cvt_i16_f32': lambda a: UOp(Ops.TRUNC, dtypes.float32, (a.bitcast(dtypes.float32),)).cast(dtypes.int16), 'v_cvt_u16_f32': lambda a: _f_to_u(a.bitcast(dtypes.float32), dtypes.uint16), # SAD (Sum of Absolute Differences) - sum |a_i - b_i| for 4 bytes + accumulator 'v_sad_u8': lambda a, b, c: _sad_u8(a, b, c), 'v_msad_u8': lambda a, b, c: _sad_u8(a, b, c, masked=True), # System NOPs - these are scheduling hints, no effect on emulation 'MIN': lambda a, b: (a < b).where(a, b), 's_nop': lambda a: _u32(0), # Address calculation for memory operations 'CalcDsAddr': lambda a, o, *r: a.cast(dtypes.uint32) + o.cast(dtypes.uint32), 'CalcGlobalAddr': lambda v, s, *r: v.cast(dtypes.uint64) + s.cast(dtypes.uint64), } for is_max, name in [(False, 'min'), (True, 'max')]: for dt, sfx in [(dtypes.float32, 'f32'), (dtypes.int, 'i32'), (dtypes.uint32, 'u32'), (dtypes.int16, 'i16'), (dtypes.uint16, 'u16')]: _FUNCS[f'v_{name}_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a) _FUNCS[f'v_{name}3_{sfx}'] = lambda *a, im=is_max, d=dt: _minmax_reduce(im, d, *a) # f16 min/max/min3/max3/med3 for is_max, name in [(False, 'min'), (True, 'max')]: _FUNCS[f'v_{name}_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) _FUNCS[f'v_{name}3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) _FUNCS[f'v_{name}_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) _FUNCS[f'v_{name}_num_f32'] = lambda *a, im=is_max: 
_minmax_reduce(im, dtypes.float32, *a) _FUNCS[f'v_{name}3_num_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) _FUNCS[f'v_{name}3_num_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) _FUNCS[f'v_{name}imum_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) _FUNCS[f'v_{name}imum_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) _FUNCS[f'v_{name}imum3_f16'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.half, *[_f16_extract(x) for x in a]) _FUNCS[f'v_{name}imum3_f32'] = lambda *a, im=is_max: _minmax_reduce(im, dtypes.float32, *a) # ═══════════════════════════════════════════════════════════════════════════════ # TOKENIZER/PARSER # ═══════════════════════════════════════════════════════════════════════════════ DTYPES = {'u32': dtypes.uint32, 'i32': dtypes.int, 'f32': dtypes.float32, 'b32': dtypes.uint32, 'u64': dtypes.uint64, 'i64': dtypes.int64, 'f64': dtypes.float64, 'b64': dtypes.uint64, 'u16': dtypes.uint16, 'i16': dtypes.short, 'f16': dtypes.half, 'b16': dtypes.uint16, 'u8': dtypes.uint8, 'i8': dtypes.int8, 'b8': dtypes.uint8, 'u4': dtypes.uint8, 'i4': dtypes.int8, 'u1': dtypes.uint32} _BITS_DT = {8: dtypes.uint8, 16: dtypes.uint16, 32: dtypes.uint32, 64: dtypes.uint64} _NUM_SUFFIXES = ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f') def _strip_suffix(num: str) -> tuple[str, str]: for sfx in _NUM_SUFFIXES: if num.endswith(sfx): return sfx, num[:-len(sfx)] return '', num _SINGLE_CHAR = {'(': 'LPAREN', ')': 'RPAREN', '[': 'LBRACKET', ']': 'RBRACKET', '{': 'LBRACE', '}': 'RBRACE', ':': 'COLON', ',': 'COMMA', '?': 'QUESTION', '.': 'DOT', '=': 'EQUALS', "'": 'QUOTE'} class Token: __slots__ = ('type', 'val') def __init__(self, type: str, val: str): self.type, self.val = type, val def __repr__(self): return f'{self.type}:{self.val}' def tokenize(s: str) -> list[Token]: tokens, i, n = [], 0, len(s) while i < n: c = s[i] if c.isspace(): i += 1; continue if 
i + 1 < n and s[i:i+2] in ('+=', '-='): tokens.append(Token('ASSIGN_OP', s[i:i+2])); i += 2; continue if i + 1 < n and s[i:i+2] in ('||', '&&', '>=', '<=', '==', '!=', '<>', '>>', '<<', '**', '+:', '-:'): tokens.append(Token('OP', s[i:i+2])); i += 2; continue if c in '|^&><+-*/~!%': tokens.append(Token('OP', c)); i += 1; continue if (t := _SINGLE_CHAR.get(c)): tokens.append(Token(t, c)); i += 1; continue if c == ';': i += 1; continue if c.isdigit() or (c == '-' and i + 1 < n and s[i+1].isdigit()): start = i if c == '-': i += 1 if i + 1 < n and s[i] == '0' and s[i+1] in 'xX': i += 2 while i < n and s[i] in '0123456789abcdefABCDEF': i += 1 else: while i < n and s[i].isdigit(): i += 1 if i < n and s[i] == '.' and i + 1 < n and s[i+1].isdigit(): i += 1 while i < n and s[i].isdigit(): i += 1 for sfx in ('ULL', 'LL', 'UL', 'U', 'L', 'F', 'f'): if s[i:i+len(sfx)] == sfx: i += len(sfx); break tokens.append(Token('NUM', s[start:i])); continue if c.isalpha() or c == '_': start = i while i < n and (s[i].isalnum() or s[i] == '_'): i += 1 tokens.append(Token('IDENT', s[start:i])); continue raise RuntimeError(f"unexpected char '{c}' at pos {i} in: {s}") tokens.append(Token('EOF', '')) return tokens class Parser: def __init__(self, tokens: list[Token], vars: dict, funcs: dict | None = None): self.tokens, self.vars, self.funcs, self.pos = tokens, vars, funcs if funcs is not None else _FUNCS, 0 def peek(self, offset=0) -> Token: return self.tokens[min(self.pos + offset, len(self.tokens) - 1)] def at(self, *types) -> bool: return self.peek().type in types def _advance(self) -> Token: tok = self.tokens[self.pos]; self.pos += 1; return tok def eat(self, type: str) -> Token: if self.peek().type != type: raise RuntimeError(f"expected {type}, got {self.peek()}") return self._advance() def try_eat(self, type: str) -> Token | None: return self._advance() if self.peek().type == type else None def try_eat_val(self, val: str, type: str) -> Token | None: return self._advance() if 
self.peek().type == type and self.peek().val == val else None def eat_val(self, val: str, type: str) -> Token: if self.peek().type != type or self.peek().val != val: raise RuntimeError(f"expected {type}:{val}, got {self.peek()}") return self._advance() def parse(self) -> UOp: cond = self.binop(0) if self.try_eat('QUESTION'): then_val = self.parse() self.eat('COLON') return _to_bool(cond).where(then_val, self.parse()) return cond def _apply_binop(self, left, right, op): if op in ('||', '&&', '|', '^', '&'): left, right = self._coerce_bitwise(left, right) elif op in ('>=', '<=', '>', '<', '==', '!=', '<>', '>>', '<<'): left, right = self._coerce_cmp(left, right) elif left.dtype != right.dtype: right = right.cast(left.dtype) match op: case '||' | '|': return left | right case '&&' | '&': return left & right case '^': return left ^ right case '==' | '<>': return left.eq(right) if op == '==' else left.ne(right) case '!=' : return left.ne(right) case '>=' | '<=' | '>' | '<': return self._cmp_nan(left, right, {'>=':(lambda a,b:a>=b),'<=':(lambda a,b:a<=b),'>':(lambda a,b:a>b),'<':(lambda a,b:a>' | '<<': return (left >> right) if op == '>>' else (left << right) case '+' | '-': if op == '-' and left.op == Ops.CONST and right.op == Ops.CONST: return _const(left.dtype, left.arg - right.arg) return (left + right) if op == '+' else (left - right) case '*' | '/': return (left * right) if op == '*' else (left / right) case '**': return UOp(Ops.EXP2, left.dtype, (right.cast(left.dtype),)) if left.op == Ops.CONST and left.arg == 2.0 else left _PREC = [('||',), ('&&',), ('|',), ('^',), ('&',), ('==', '!=', '<>'), ('>=', '<=', '>', '<'), ('>>', '<<'), ('+', '-'), ('*', '/'), ('**',)] def binop(self, prec: int) -> UOp: if prec >= len(self._PREC): return self.unary() left = self.binop(prec + 1) ops = self._PREC[prec] while self.at('OP') and self.peek().val in ops: op = self.eat('OP').val left = self._apply_binop(left, self.binop(prec + 1), op) return left def unary(self) -> UOp: if 
self.try_eat_val('~', 'OP'): inner = self.unary() return inner ^ _const(inner.dtype, (1 << (inner.dtype.itemsize * 8)) - 1) if self.try_eat_val('!', 'OP'): inner = self.unary() return inner.eq(_const(inner.dtype, 0)) if self.try_eat_val('-', 'OP'): inner = self.unary() if inner.op == Ops.CONST: return _const(dtypes.int if inner.dtype == dtypes.uint32 else inner.dtype, -inner.arg) return inner.neg() if self.try_eat_val('+', 'OP'): return self.unary() return self.postfix() def postfix(self) -> UOp: base = self.primary() while True: if self.try_eat('DOT'): field = self.eat('IDENT').val base = self._handle_dot(base, field) elif self.at('LBRACKET'): base = self._handle_bracket(base) elif self.at('LBRACE'): base = self._handle_brace_index(base) else: break return base def primary(self) -> UOp: if self.try_eat('LPAREN'): e = self.parse() self.eat('RPAREN') return e if self.try_eat('LBRACE'): hi = self.parse() self.eat('COMMA') lo = self.parse() self.eat('RBRACE') return (hi.cast(dt:=_BITS_DT.get((s:=lo.dtype.bitsize) * 2, dtypes.uint64)) << _const(dt, s)) | lo.cast(dt) if self.at('NUM'): num = self.eat('NUM').val if self.try_eat('QUOTE'): return self._sized_literal(int(num.rstrip('ULlf'))) return self._parse_number(num) if self.at('IDENT'): name = self.eat('IDENT').val if name == 'MEM': self.eat('LBRACKET') addr = self.parse() self.eat('RBRACKET') self.eat('DOT') dt_name = self.eat('IDENT').val return self._handle_mem_load(addr, DTYPES.get(dt_name, dtypes.uint32)) if name == 'VGPR' and self.at('LBRACKET'): self.eat('LBRACKET') lane = self.parse() self.eat('RBRACKET') self.eat('LBRACKET') reg = self.parse() self.eat('RBRACKET') vgpr = self.vars.get('_vgpr') if vgpr is None: return _u32(0) return vgpr.index(_to_u32(reg) * _u32(32) + _to_u32(lane), ptr=True).load() if self.try_eat('LPAREN'): args = self._parse_args() self.eat('RPAREN') return self._call_func(name, args) if name == 'PI': return _const(dtypes.float32, 3.141592653589793) if name == 'INF': return 
_const(dtypes.float64, float('inf')) if name == 'NAN': return _const(dtypes.uint32, 0x7FC00000).bitcast(dtypes.float32) if name == 'UNDERFLOW_F32': return _const(dtypes.uint32, 1).bitcast(dtypes.float32) if name == 'OVERFLOW_F32': return _const(dtypes.uint32, 0x7F7FFFFF).bitcast(dtypes.float32) if name == 'UNDERFLOW_F64': return _const(dtypes.uint64, 1).bitcast(dtypes.float64) if name == 'OVERFLOW_F64': return _const(dtypes.uint64, 0x7FEFFFFFFFFFFFFF).bitcast(dtypes.float64) if name == 'WAVE32': return _const(dtypes.bool, True) if name == 'WAVE64': return _const(dtypes.bool, False) if name == 'WAVE_MODE' and self.try_eat('DOT') and self.try_eat_val('IEEE', 'IDENT'): return _u32(1) if self.try_eat('LBRACE'): idx = self.eat('NUM').val self.eat('RBRACE') # Handle VGPR{lane}[reg] - 2D array access after loop unrolling if name == 'VGPR' and self.at('LBRACKET'): self.eat('LBRACKET') reg = self.parse() self.eat('RBRACKET') vgpr = self.vars.get('_vgpr') if vgpr is None: return _u32(0) return vgpr.index(_to_u32(reg) * _u32(32) + _u32(int(idx)), ptr=True).load() elem = self.vars.get(f'{name}@{idx}', self.vars.get(f'{name}{idx}')) if elem is None: # Extract bit idx from base variable (like var[idx]) base = self.vars.get(name) assert isinstance(base, UOp), f"unknown variable: {name}{idx}" dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 elem = (base.cast(dt) >> _const(dt, int(idx))) & _const(dt, 1) if self.try_eat('DOT'): dt_name = self.eat('IDENT').val return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32)) if self.at('LBRACKET'): return self._handle_bracket(elem, name + idx) return elem if self.at('LBRACKET') and name not in self.vars: self.eat('LBRACKET') first = self.parse() return self._handle_bracket_rest(first, _u32(0), name) if name in self.vars: v = self.vars[name] assert isinstance(v, UOp), f"expected UOp for {name}, got {type(v)}" return v raise RuntimeError(f"unknown variable: {name}") raise RuntimeError(f"unexpected token in 
primary: {self.peek()}") def _handle_dot(self, base, field: str) -> UOp: assert isinstance(base, UOp), f"expected UOp for dot access, got {type(base)}" if field == 'u64' and self.at('LBRACKET') and self.peek(1).type == 'IDENT' and self.peek(1).val == 'laneId': self.eat('LBRACKET') self.eat_val('laneId', 'IDENT') self.eat('RBRACKET') result = (base >> _to_u32(self.vars['laneId'])) & _u32(1) if self.try_eat('DOT'): dt_name = self.eat('IDENT').val return result.cast(DTYPES.get(dt_name, dtypes.uint32)) return result dt = DTYPES.get(field) if dt is None: return base if dt == base.dtype: return base if dt.itemsize == 2 and base.dtype.itemsize == 4: return (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16) if dt == dtypes.uint16 else (base & _const(base.dtype, 0xFFFF)).cast(dtypes.uint16).bitcast(dt) if field == 'i4': return _signext_4bit(base) return _cast_to(base, dt) def _handle_bracket(self, base, var_name: str | None = None) -> UOp: self.eat('LBRACKET') return self._handle_bracket_rest(self.parse(), base, var_name) def _handle_bracket_rest(self, first: UOp, base: UOp, var_name: str | None = None) -> UOp: if self.at('OP') and self.peek().val in ('+:', '-:'): op = self.eat('OP').val width = self.parse() self.eat('RBRACKET') if width.op == Ops.CONST: w = int(width.arg) return (base >> _to_u32(first)) & _const(base.dtype, (1 << w) - 1) return base if self.try_eat('COLON'): second = self.parse() self.eat('RBRACKET') if first.op == Ops.CONST and second.op == Ops.CONST: a, b = int(first.arg), int(second.arg) if a < b: return _bitreverse(base, b - a + 1) hi, lo = a, b if lo >= base.dtype.itemsize * 8: vn = var_name or self._find_var_name(base) if vn and f'{vn}{lo // 32}' in self.vars: base = self.vars[f'{vn}{lo // 32}'] lo, hi = lo % 32, (hi % 32) + (lo % 32) return _extract_bits(base, hi, lo) # Dynamic bit slice: (base >> lo) & ((1 << (hi - lo + 1)) - 1) dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 hi, lo = first.cast(dt), 
second.cast(dt) width = hi - lo + _const(dt, 1) mask = (_const(dt, 1) << width) - _const(dt, 1) return (base.cast(dt) >> lo) & mask self.eat('RBRACKET') dt_suffix = None if self.try_eat('DOT'): dt_suffix = DTYPES.get(self.eat('IDENT').val, dtypes.uint32) if var_name is None: var_name = self._find_var_name(base) if first.op == Ops.CONST: idx = int(first.arg) # Check for array element (var@idx) if var_name and f'{var_name}@{idx}' in self.vars: v = self.vars[f'{var_name}@{idx}'] return _cast_to(v, dt_suffix) if dt_suffix else v # Bit extraction dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 base_cast = base.cast(dt) if base.dtype != dt else base result = ((base_cast >> _const(dt, idx)) & _const(dt, 1)) return _cast_to(result, dt_suffix) if dt_suffix else result if var_name: idx_u32 = _to_u32(first) elems = [(i, self.vars[f'{var_name}@{i}']) for i in range(256) if f'{var_name}@{i}' in self.vars] if elems: result = elems[-1][1] for ei, ev in reversed(elems[:-1]): if ev.dtype != result.dtype and ev.dtype.itemsize == result.dtype.itemsize: result = result.cast(ev.dtype) elif ev.dtype != result.dtype: ev = ev.cast(result.dtype) result = idx_u32.eq(_u32(ei)).where(ev, result) return result dt = dtypes.uint64 if base.dtype in (dtypes.uint64, dtypes.int64) else dtypes.uint32 base_cast = base.cast(dt) if base.dtype != dt else base result = (base_cast >> first.cast(dt)) & _const(dt, 1) return _cast_to(result, dt_suffix) if dt_suffix else result def _handle_brace_index(self, base) -> UOp: self.eat('LBRACE') idx = self.eat('NUM').val self.eat('RBRACE') var_name = self._find_var_name(base) if var_name: elem = self.vars.get(f'{var_name}@{idx}', _u32(0)) # use @ to avoid collision with temps like A4 if self.try_eat('DOT'): dt_name = self.eat('IDENT').val return _cast_to(elem, DTYPES.get(dt_name, dtypes.uint32)) if self.at('LBRACKET'): return self._handle_bracket(elem) return elem return _u32(0) def _find_var_name(self, base: UOp) -> str | None: 
if base.op == Ops.DEFINE_VAR and base.arg: return base.arg[0] for name, v in self.vars.items(): if isinstance(v, UOp) and v is base: return name return None def _sized_literal(self, bits: int) -> UOp: if self.at('IDENT') and self.peek().val in ('U', 'I', 'F', 'B'): type_char = self.eat('IDENT').val self.eat('LPAREN') inner = self.parse() self.eat('RPAREN') dt = {('U',32): dtypes.uint32, ('U',64): dtypes.uint64, ('I',32): dtypes.int, ('I',64): dtypes.int64, ('F',16): dtypes.half, ('F',32): dtypes.float32, ('F',64): dtypes.float64, ('B',32): dtypes.uint32, ('B',64): dtypes.uint64}.get((type_char, bits), dtypes.uint64 if bits > 32 else dtypes.uint32) if type_char == 'F' and inner.dtype in (dtypes.uint32, dtypes.uint64, dtypes.ulong, dtypes.int, dtypes.int64): if inner.dtype.itemsize != dt.itemsize: inner = inner.cast(dtypes.uint32 if dt.itemsize == 4 else dtypes.uint64) return inner.bitcast(dt) return inner.cast(dt) if self.at('IDENT'): ident = self.peek().val fmt = ident[0].lower() if fmt in ('h', 'b', 'd'): self.eat('IDENT') if len(ident) > 1: num = ident[1:] elif self.at('NUM'): num = self.eat('NUM').val elif self.at('IDENT'): num = self.eat('IDENT').val else: raise RuntimeError(f"expected number after {bits}'{fmt}") if fmt == 'h': val = int(num, 16) elif fmt == 'b': val = int(num, 2) else: val = int(num) return _const(_BITS_DT.get(bits, dtypes.uint32), val) if self.at('NUM') and self.peek().val.startswith('0x'): num = self.eat('NUM').val return _const(_BITS_DT.get(bits, dtypes.uint32), int(num, 16)) if self.at('NUM') or (self.at('OP') and self.peek().val == '-'): neg = self.try_eat_val('-', 'OP') is not None suffix, num = _strip_suffix(self.eat('NUM').val) if num.startswith('0x'): val = int(num, 16) if neg: val = -val elif '.' 
in num: fval = float(num) if neg: fval = -fval return _const({16: dtypes.half, 32: dtypes.float32, 64: dtypes.float64}.get(bits, dtypes.float32), fval) else: val = int(num) if neg: val = -val dt = {1: dtypes.uint32, 8: dtypes.uint8, 16: dtypes.int16 if 'U' not in suffix else dtypes.uint16, 32: dtypes.int if 'U' not in suffix else dtypes.uint32, 64: dtypes.int64 if 'U' not in suffix else dtypes.uint64}.get(bits, dtypes.uint32) return _const(dt, val) raise RuntimeError(f"unexpected token after {bits}': {self.peek()}") def _parse_number(self, num: str) -> UOp: if num.startswith('0x') or num.startswith('0X'): is_u64 = num.upper().endswith('ULL') or num.upper().endswith('LL') or num.upper().endswith('UL') return _const(dtypes.uint64 if is_u64 else dtypes.uint32, int(num.rstrip('ULul'), 16)) suffix, num_str = _strip_suffix(num) if '.' in num_str or suffix in ('F', 'f'): return _const(dtypes.float32 if suffix in ('F', 'f') else dtypes.float64, float(num_str)) val = int(num_str) if 'ULL' in suffix or 'LL' in suffix or 'L' in suffix: return _const(dtypes.uint64, val) if 'U' in suffix: return _const(dtypes.uint32, val) return _const(dtypes.int if val < 0 else dtypes.uint32, val) def _parse_args(self) -> list[UOp]: if self.at('RPAREN'): return [] args = [self.parse()] while self.try_eat('COMMA'): args.append(self.parse()) return args def _call_func(self, name: str, args: list[UOp]) -> UOp: if name in self.vars and isinstance(self.vars[name], tuple) and self.vars[name][0] == 'lambda': _, params, body = self.vars[name] lv = {**self.vars, **{p: a for p, a in zip(params, args)}} if ';' in body or '\n' in body or 'return' in body.lower(): lines = [l.strip() for l in body.replace(';', '\n').split('\n') if l.strip() and not l.strip().startswith('//')] _, _, result = parse_block(lines, 0, lv, self.funcs) assert result is not None, f"lambda {name} must return a value" return result return parse_expr(body, lv, self.funcs) if name in self.funcs: return self.funcs[name](*args) raise 
RuntimeError(f"unknown function: {name}") def _handle_mem_load(self, addr: UOp, dt) -> UOp: mem = self.vars.get('_vmem') if '_vmem' in self.vars else self.vars.get('_lds') assert mem is not None, "memory load requires _vmem or _lds" adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32 active = self.vars.get('_active') gate = (active,) if active is not None else () byte_mem = mem.dtype.base == dtypes.uint8 if byte_mem: idx = addr.cast(dtypes.int) if dt in (dtypes.uint64, dtypes.int64, dtypes.float64): val = _u32(0).cast(dtypes.uint64) for i in range(8): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint64) << _u64(i * 8)) elif dt in (dtypes.uint8, dtypes.int8): val = mem.index(idx, *gate, ptr=True).load().cast(dt) elif dt in (dtypes.uint16, dtypes.int16, dtypes.short): val = (mem.index(idx, *gate, ptr=True).load().cast(dtypes.uint32) | (mem.index(idx + _const(dtypes.int, 1), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(8))).cast(dt) else: val = _u32(0) for i in range(4): val = val | (mem.index(idx + _const(dtypes.int, i), *gate, ptr=True).load().cast(dtypes.uint32) << _u32(i * 8)) else: idx = (addr >> _const(addr.dtype, 2)).cast(dtypes.int) val = mem.index(idx, *gate) if dt in (dtypes.uint64, dtypes.int64, dtypes.float64): idx2 = ((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int) val = val.cast(dtypes.uint64) | (mem.index(idx2, *gate).cast(dtypes.uint64) << _u64(32)) elif dt in (dtypes.uint8, dtypes.int8): val = (val >> ((addr & _const(adt, 3)).cast(dtypes.uint32) * _u32(8))) & _u32(0xFF) elif dt in (dtypes.uint16, dtypes.int16): val = (val >> (((addr >> _const(adt, 1)) & _const(adt, 1)).cast(dtypes.uint32) * _u32(16))) & _u32(0xFFFF) return val def _coerce_cmp(self, l: UOp, r: UOp) -> tuple[UOp, UOp]: if l.dtype != r.dtype: if r.dtype == dtypes.int and r.op == Ops.CONST and r.arg < 0: l = l.cast(dtypes.int) else: r = r.cast(l.dtype) return l, r def _coerce_bitwise(self, l: UOp, r: UOp) -> 
def _match_bracket(toks: list[Token], start: int) -> tuple[int, list[Token]]:
  """Match the bracket opened at toks[start].

  Returns (end_idx, inner_tokens). end_idx is the index just past the matching RBRACKET, or
  len(toks) when the bracket is never closed. inner_tokens are the tokens strictly between
  the brackets with EOF tokens removed.
  """
  depth, end = 1, start + 1
  for pos in range(start + 1, len(toks)):
    kind = toks[pos].type
    if kind == 'LBRACKET': depth += 1
    elif kind == 'RBRACKET': depth -= 1
    end = pos + 1
    if depth == 0: break
  return end, [t for t in toks[start + 1:end - 1] if t.type != 'EOF']

def _tok_str(toks: list[Token]) -> str:
  """Join the token texts with single spaces, skipping EOF tokens."""
  return ' '.join(t.val for t in toks if t.type != 'EOF')

def parse_tokens(toks: list[Token], vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
  """Parse an already-tokenized expression with a fresh Parser and return its UOp."""
  return Parser(toks, vars, funcs).parse()

# Unified block parser for pcode
def _subst_loop_var(line: str, loop_var: str, val: int) -> str:
  """Substitute loop variable with its value.

  The line is re-tokenized, so only IDENT tokens equal to `loop_var` are replaced (never a
  substring of another identifier); the result is the token texts joined by single spaces.
  """
  out = []
  for t in tokenize(line):
    if t.type == 'EOF': continue
    out.append(str(val) if t.type == 'IDENT' and t.val == loop_var else t.val)
  return ' '.join(out)

def _set_bits(old: UOp, val: UOp, width: int, offset: int) -> UOp:
  """Set bits [offset:offset+width) in old to val, masking and shifting appropriately.

  The arithmetic is done in uint32 unless `old` is 64-bit or the field reaches past bit 31,
  in which case it is done in uint64 (and `old` is widened to uint64 if needed). A uint32-only
  mask cannot represent fields at offset >= 32 and would also clear the upper half of a
  64-bit destination. For fields inside a 32-bit `old` the expression built is unchanged.
  """
  needs_64 = old.dtype in (dtypes.uint64, dtypes.int64) or offset + width > 32
  dt = dtypes.uint64 if needs_64 else dtypes.uint32
  if needs_64 and old.dtype != dtypes.uint64: old = old.cast(dtypes.uint64)
  all_ones = (1 << (dt.itemsize * 8)) - 1
  field = (1 << width) - 1
  # clamp to the working width so the constants are always representable in dt
  mask = _const(dt, (field << offset) & all_ones)
  v = (val.cast(dt) if val.dtype != dt else val) & _const(dt, field & all_ones)
  return (old & (mask ^ _const(dt, all_ones))) | (v << _const(dt, offset))

def _find_paren_end(s: str, start: int = 0, open_ch: str = '(', close_ch: str = ')') -> int:
  """Find the index of the close paren matching the open paren at s[start].

  Scanning begins at `start` (the opening character itself) and tracks nesting depth.
  Returns len(s) when no matching close is found.
  """
  depth = 0
  for pos, ch in enumerate(s[start:], start):
    if ch == open_ch:
      depth += 1
    elif ch == close_ch:
      depth -= 1
      if depth == 0: return pos
  return len(s)
def parse_block(lines: list[str], start: int, vars: dict[str, VarVal], funcs: dict | None = None,
                assigns: list | None = None) -> tuple[int, dict[str, VarVal], UOp | None]:
  """Parse a block of pcode. Returns (next_line, block_assigns, return_value).

  If assigns list is provided, side effects (MEM/VGPR writes) are appended to it.

  Each line is dispatched on its leading tokens (for / declare / lambda / MEM / VGPR / compound
  destination / bit slice / array element / compound assignment / typed element / dynamic bit /
  if / regular assignment). Lines that start with neither an identifier nor '{' are skipped.
  Parsing stops, without consuming the line, at elsif/else/endif/endfor, and stops with a
  value at `return`. `vars` is mutated in place; block_assigns records what this block assigned.
  """
  if funcs is None: funcs = _FUNCS
  block_assigns: dict[str, VarVal] = {}
  i = start
  while i < len(lines):
    line = lines[i]
    toks = tokenize(line)
    if toks[0].type != 'IDENT' and toks[0].type != 'LBRACE': i += 1; continue
    first = toks[0].val.lower() if toks[0].type == 'IDENT' else '{'

    # Block terminators
    if first in ('elsif', 'else', 'endif', 'endfor'): break

    # return expr (lambda bodies)
    if first == 'return':
      rest = line[line.lower().find('return') + 6:].strip()
      return i + 1, block_assigns, parse_expr(rest, vars, funcs)

    # for loop
    if first == 'for':
      # Parse: for VAR in [SIZE']START : [SIZE']END do
      p = Parser(toks, vars, funcs)
      p.eat_val('for', 'IDENT')
      loop_var = p.eat('IDENT').val
      p.eat_val('in', 'IDENT')
      def parse_bound():
        # drop an optional SIZE' prefix, then take a plain number or a constant-folding expression
        if p.at('NUM') and p.peek(1).type == 'QUOTE': p.eat('NUM'); p.eat('QUOTE')
        if p.at('NUM'): return int(p.eat('NUM').val.rstrip('UuLl'))
        expr = p.parse().simplify()
        assert expr.op == Ops.CONST, f"loop bound must be constant, got {expr}"
        return int(expr.arg)
      start_val = parse_bound()
      p.eat('COLON')
      end_val = parse_bound()
      # Collect body
      i += 1
      body_lines: list[str] = []
      depth = 1
      while i < len(lines) and depth > 0:
        btoks = tokenize(lines[i])
        if btoks[0].type == 'IDENT':
          if btoks[0].val.lower() == 'for': depth += 1
          elif btoks[0].val.lower() == 'endfor': depth -= 1
        if depth > 0: body_lines.append(lines[i])
        i += 1
      # Execute loop with break support
      has_break = any('break' in bl.lower() for bl in body_lines)
      found_var = f'_found_{id(body_lines)}' if has_break else None
      if found_var: vars[found_var] = block_assigns[found_var] = _const(dtypes.bool, False)
      # the loop is unrolled: the end bound is inclusive and the loop variable is substituted textually
      for loop_i in range(start_val, end_val + 1):
        subst_lines = [_subst_loop_var(bl, loop_var, loop_i) for bl in body_lines if not (has_break and bl.strip().lower() == 'break')]
        _, iter_assigns, _ = parse_block(subst_lines, 0, {**vars, **block_assigns}, funcs, assigns)
        if has_break:
          assert found_var is not None
          found = block_assigns.get(found_var, vars.get(found_var))
          assert isinstance(found, UOp)
          not_found = found.eq(_const(dtypes.bool, False))
          # an iteration's assignments only take effect while no earlier iteration has hit `break`
          for var, val in iter_assigns.items():
            if var != found_var and isinstance(val, UOp):
              old = block_assigns.get(var, vars.get(var, _u32(0)))
              if isinstance(old, UOp):
                block_assigns[var] = vars[var] = not_found.where(val, old.cast(val.dtype) if val.dtype != old.dtype and val.dtype.itemsize == old.dtype.itemsize else old)
          # the first `if ... then` line with a `break` somewhere after it supplies the break condition
          for j, bl in enumerate(body_lines):
            bl_l = bl.strip().lower()
            if bl_l.startswith('if ') and bl_l.endswith(' then'):
              if any(body_lines[k].strip().lower() == 'break' for k in range(j+1, len(body_lines))):
                cond_str = _subst_loop_var(bl.strip()[3:-5].strip(), loop_var, loop_i)
                cond = _to_bool(parse_expr(cond_str, vars, funcs))
                block_assigns[found_var] = vars[found_var] = not_found.where(cond, found)
                break
        else:
          block_assigns.update(iter_assigns); vars.update(iter_assigns)
      continue

    # declare
    if first == 'declare':
      # Initialize scalar declarations (skip arrays and vars already passed as srcs)
      if '[' not in line and len(toks) >= 2 and toks[1].type == 'IDENT': vars.setdefault(toks[1].val, _u32(0))
      i += 1; continue

    # lambda definition
    if first != '{' and '=' in line and 'lambda' in line and any(t.type == 'IDENT' and t.val == 'lambda' for t in toks):
      name = toks[0].val
      body_start = line[line.find('(', line.find('lambda')):]
      params_end = _find_paren_end(body_start) + 1
      params = [p.strip() for p in body_start[1:params_end-1].split(',') if p.strip()]
      rest = body_start[params_end:].strip()
      if rest.startswith('('):
        body_end = _find_paren_end(rest)
        if body_end < len(rest):  # found matching paren on same line
          body = rest[1:body_end].strip()
          i += 1
        else:  # multiline body
          body_lines_lst, depth = [rest[1:]], 1
          i += 1
          while i < len(lines) and depth > 0:
            for j, ch in enumerate(lines[i]):
              if ch == '(': depth += 1
              elif ch == ')':
                depth -= 1
                if depth == 0: body_lines_lst.append(lines[i][:j]); break
            else: body_lines_lst.append(lines[i])
            i += 1
          body = '\n'.join(body_lines_lst).strip()
        # stored as a tuple; Parser._call_func evaluates it on use
        vars[name] = ('lambda', params, body)
        continue

    # MEM assignment: MEM[addr].type (+|-)?= value
    if first == 'mem' and toks[1].type == 'LBRACKET':
      j, addr_toks = _match_bracket(toks, 1)
      addr = parse_tokens(addr_toks, vars, funcs)
      if j < len(toks) and toks[j].type == 'DOT': j += 1
      dt_name = toks[j].val if j < len(toks) and toks[j].type == 'IDENT' else 'u32'
      dt, j = DTYPES.get(dt_name, dtypes.uint32), j + 1
      compound_op = None
      if j < len(toks) and toks[j].type == 'ASSIGN_OP': compound_op = toks[j].val; j += 1
      elif j < len(toks) and toks[j].type == 'EQUALS': j += 1
      rhs = parse_tokens(toks[j:], vars, funcs)
      if compound_op:
        # read-modify-write: load the current value and fold it into rhs
        mem = vars.get('_vmem') if '_vmem' in vars else vars.get('_lds')
        if isinstance(mem, UOp):
          adt = dtypes.uint64 if addr.dtype == dtypes.uint64 else dtypes.uint32
          idx = (addr >> _const(adt, 2)).cast(dtypes.int)
          old = mem.index(idx)
          if dt in (dtypes.uint64, dtypes.int64, dtypes.float64):
            old = old.cast(dtypes.uint64) | (mem.index(((addr + _const(adt, 4)) >> _const(adt, 2)).cast(dtypes.int)).cast(dtypes.uint64) << _u64(32))
          rhs = (old + rhs) if compound_op == '+=' else (old - rhs)
      if assigns is not None: assigns.append((f'MEM[{_tok_str(addr_toks)}].{dt_name}', (addr, rhs)))
      i += 1; continue

    # VGPR assignment: VGPR[lane][reg] = value
    if first == 'vgpr' and toks[1].type == 'LBRACKET':
      j, lane_toks = _match_bracket(toks, 1)
      if j < len(toks) and toks[j].type == 'LBRACKET':
        j, reg_toks = _match_bracket(toks, j)
        if j < len(toks) and toks[j].type == 'DOT': j += 2  # skip .type suffix
        if j < len(toks) and toks[j].type == 'EQUALS': j += 1
        ln, rg, val = parse_tokens(lane_toks, vars, funcs), parse_tokens(reg_toks, vars, funcs), parse_tokens(toks[j:], vars, funcs)
        # the destination is recorded as the flat index reg * 32 + lane
        if assigns is not None: assigns.append((f'VGPR[{_tok_str(lane_toks)}][{_tok_str(reg_toks)}]', (_to_u32(rg) * _u32(32) + _to_u32(ln), val)))
        i += 1; continue

    # Compound destination: {hi.type, lo.type} = value
    if first == '{':
      j = 1
      if j+2 < len(toks) and toks[j].type == 'IDENT' and toks[j+1].type == 'DOT':
        hi_var, hi_type = toks[j].val, toks[j+2].val
        j += 3
        if j < len(toks) and toks[j].type == 'COMMA': j += 1
        if j+2 < len(toks) and toks[j].type == 'IDENT' and toks[j+1].type == 'DOT':
          lo_var, lo_type = toks[j].val, toks[j+2].val
          j += 3
          if j < len(toks) and toks[j].type == 'RBRACE': j += 1
          if j < len(toks) and toks[j].type == 'EQUALS': j += 1
          val = parse_tokens(toks[j:], vars, funcs)
          lo_dt, hi_dt = DTYPES.get(lo_type, dtypes.uint64), DTYPES.get(hi_type, dtypes.uint32)
          lo_bits = 64 if lo_dt in (dtypes.uint64, dtypes.int64) else 32
          # low part: the value itself if it already fits, else its low lo_bits; high part: whatever is above
          lo_val = val.cast(lo_dt) if val.dtype.itemsize * 8 <= lo_bits else (val & _const(val.dtype, (1 << lo_bits) - 1)).cast(lo_dt)
          hi_val = (val >> _const(val.dtype, lo_bits)).cast(hi_dt)
          block_assigns[lo_var] = vars[lo_var] = lo_val
          block_assigns[hi_var] = vars[hi_var] = hi_val
          if assigns is not None: assigns.extend([(f'{lo_var}.{lo_type}', lo_val), (f'{hi_var}.{hi_type}', hi_val)])
          i += 1; continue

    # Bit slice/index: var[hi:lo] = value, var.type[hi:lo] = value, or var[expr] = value
    if len(toks) >= 5 and toks[0].type == 'IDENT' and (toks[1].type == 'LBRACKET' or (toks[1].type == 'DOT' and toks[3].type == 'LBRACKET')):
      bracket_start = 2 if toks[1].type == 'LBRACKET' else 4
      j = bracket_start
      colon_pos = None
      while j < len(toks) and toks[j].type != 'RBRACKET':
        if toks[j].type == 'COLON': colon_pos = j
        j += 1
      var = toks[0].val
      if colon_pos is not None:  # bit slice: var[hi:lo]
        hi_str = ' '.join(t.val for t in toks[bracket_start:colon_pos] if t.type != 'EOF')
        lo_str = ' '.join(t.val for t in toks[colon_pos+1:j] if t.type != 'EOF')
        try:
          # NOTE(review): eval() runs the bound text as Python, so pcode must come from a trusted source.
          hi_val, lo_val = int(eval(hi_str)), int(eval(lo_str))
          hi, lo = max(hi_val, lo_val), min(hi_val, lo_val)
          j += 1
          if j < len(toks) and toks[j].type == 'DOT': j += 2
          if j < len(toks) and toks[j].type == 'EQUALS': j += 1
          val = parse_tokens(toks[j:], vars, funcs)
          dt_suffix = toks[2].val if toks[1].type == 'DOT' else None
          if assigns is not None: assigns.append((f'{var}[{hi}:{lo}]' + (f'.{dt_suffix}' if dt_suffix else ''), val))
          if var not in vars: vars[var] = _const(dtypes.uint64 if hi >= 32 else dtypes.uint32, 0)
          old = block_assigns.get(var, vars.get(var))
          block_assigns[var] = vars[var] = _set_bits(old, _val_to_bits(val), hi - lo + 1, lo)
          i += 1; continue
        except Exception:
          # Best effort: if the slice cannot be handled here (e.g. bounds that are not plain Python
          # arithmetic), fall through to the handlers below. Exception, not a bare except, so that
          # KeyboardInterrupt/SystemExit still propagate.
          pass
      elif toks[1].type == 'LBRACKET':  # bit index: var[expr] (only for var[...], not var.type[...])
        existing = block_assigns.get(var, vars.get(var))
        if existing is not None and isinstance(existing, UOp) and not any(f'{var}{k}' in vars or f'{var}{k}' in block_assigns for k in range(8)):
          bit_toks = toks[2:j]
          j += 1
          while j < len(toks) and toks[j].type != 'EQUALS': j += 1
          if j < len(toks):
            block_assigns[var] = vars[var] = _set_bit(existing, _to_u32(parse_tokens(bit_toks, vars, funcs)), parse_tokens(toks[j+1:], vars, funcs))
            i += 1; continue

    # Array element: var[idx] = value (static index) or var[expr] = value (dynamic)
    if len(toks) >= 4 and toks[0].type == 'IDENT' and toks[1].type == 'LBRACKET':
      var = toks[0].val
      j, idx_toks = _match_bracket(toks, 1)
      if j < len(toks) and toks[j].type == 'EQUALS':
        # Static index: var[NUM] = value
        if len(idx_toks) == 1 and idx_toks[0].type == 'NUM':
          idx = int(idx_toks[0].val.rstrip('UuLl'))
          val = parse_tokens(toks[j+1:], vars, funcs)
          existing = block_assigns.get(var, vars.get(var))
          # an existing scalar gets bit idx set; otherwise this defines array element var@idx
          if existing is not None and isinstance(existing, UOp): block_assigns[var] = vars[var] = _set_bit(existing, _u32(idx), val)
          else: block_assigns[f'{var}@{idx}'] = vars[f'{var}@{idx}'] = val
          i += 1; continue
        # Dynamic index: var[expr] = value where var has @-elements
        elems = [(k.split('@')[1], v) for k, v in {**vars, **block_assigns}.items() if k.startswith(f'{var}@') and isinstance(v, UOp)]
        if elems:
          idx_expr = parse_tokens(idx_toks, vars, funcs)
          val = parse_tokens(toks[j+1:], vars, funcs)
          # every known element becomes "new value if the index selects me, else my old value"
          for elem_idx_str, old_elem in elems:
            elem_idx = int(elem_idx_str)
            cond = _to_u32(idx_expr).eq(_u32(elem_idx))
            new_val = cond.where(val.cast(old_elem.dtype) if val.dtype != old_elem.dtype else val, old_elem)
            block_assigns[f'{var}@{elem_idx}'] = vars[f'{var}@{elem_idx}'] = new_val
          i += 1; continue

    # Compound assignment: var += or var -=
    assign_op = next((j for j, t in enumerate(toks) if t.type == 'ASSIGN_OP'), None)
    if assign_op is not None:
      var = toks[0].val
      old = block_assigns.get(var, vars.get(var, _u32(0)))
      rhs = parse_tokens(toks[assign_op+1:], vars, funcs)
      if rhs.dtype != old.dtype: rhs = rhs.cast(old.dtype)
      block_assigns[var] = vars[var] = (old + rhs) if toks[assign_op].val == '+=' else (old - rhs)
      i += 1; continue

    # Typed element: var.type[idx] = value
    if len(toks) >= 7 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET' and toks[4].type == 'NUM':
      var, dt_name, idx = toks[0].val, toks[2].val, int(toks[4].val)
      dt = DTYPES.get(dt_name, dtypes.uint32)
      j = 6
      while j < len(toks) and toks[j].type != 'EQUALS': j += 1
      if j < len(toks):
        val, old = parse_tokens(toks[j+1:], vars, funcs), block_assigns.get(var, vars.get(var, _u32(0)))
        bw = dt.itemsize * 8
        # element idx of width bw occupies bits [idx*bw, (idx+1)*bw)
        block_assigns[var] = vars[var] = _set_bits(old, val, bw, idx * bw)
        if assigns is not None: assigns.append((f'{var}.{dt_name}[{idx}]', val))
        i += 1; continue

    # Dynamic bit: var.type[expr_with_brackets] = value
    if len(toks) >= 5 and toks[0].type == 'IDENT' and toks[1].type == 'DOT' and toks[2].type == 'IDENT' and toks[3].type == 'LBRACKET':
      j, depth, has_inner = 4, 1, False
      while j < len(toks) and depth > 0:
        if toks[j].type == 'LBRACKET': depth += 1; has_inner = True
        elif toks[j].type == 'RBRACKET': depth -= 1
        j += 1
      if has_inner:
        var = toks[0].val
        bit_pos = _to_u32(parse_tokens(toks[4:j-1], vars, funcs))
        while j < len(toks) and toks[j].type != 'EQUALS': j += 1
        if j < len(toks):
          val = parse_tokens(toks[j+1:], vars, funcs)
          old = block_assigns.get(var, vars.get(var, _u32(0)))
          block_assigns[var] = vars[var] = _set_bit(old, bit_pos, val)
          i += 1; continue

    # If/elsif/else - skip branches with statically false conditions (WAVE32/WAVE64)
    if first == 'if':
      def parse_cond(s, kw):
        # the condition is the text between the keyword and the last 'then'
        ll = s.lower()
        return _to_bool(parse_expr(s[ll.find(kw) + len(kw):ll.rfind('then')].strip(), vars, funcs))
      def is_const(c, v): return c.op == Ops.CONST and c.arg is v
      cond = parse_cond(line, 'if')
      conditions: list[tuple[UOp, UOp | dict[str, VarVal] | None]] = [(cond, None)] if not is_const(cond, False) else []
      else_branch: tuple[UOp | None, dict[str, VarVal]] = (None, {})
      vars_snap = dict(vars)
      static_true = is_const(cond, True)  # track if any condition is statically true
      i += 1
      # branches that can never run are still parsed (to advance i) but with assigns=None
      i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not is_const(cond, False) else None)
      if conditions: conditions[0] = (cond, ret if ret is not None else branch)
      # every branch starts from the same pre-if variable state
      vars.clear(); vars.update(vars_snap)
      while i < len(lines):
        ltoks = tokenize(lines[i])
        if ltoks[0].type != 'IDENT': break
        lf = ltoks[0].val.lower()
        if lf == 'elsif':
          c = parse_cond(lines[i], 'elsif')
          take = not static_true and not is_const(c, False)
          i += 1; i, branch, ret = parse_block(lines, i, vars, funcs, assigns if take else None)
          if take:
            conditions.append((c, ret if ret is not None else branch))
            if is_const(c, True): static_true = True
          vars.clear(); vars.update(vars_snap)
        elif lf == 'else':
          i += 1
          i, branch, ret = parse_block(lines, i, vars, funcs, assigns if not static_true else None)
          if not static_true: else_branch = (ret, branch)
          vars.clear(); vars.update(vars_snap)
        elif lf == 'endif': i += 1; break
        else: break
      # Check if any branch returned a value (lambda-style)
      if any(isinstance(br, UOp) for _, br in conditions):
        result = else_branch[0]
        for c, rv in reversed(conditions):
          if isinstance(rv, UOp) and isinstance(result, UOp):
            if rv.dtype != result.dtype and rv.dtype.itemsize == result.dtype.itemsize: result = result.cast(rv.dtype)
            result = c.where(rv, result)
        return i, block_assigns, result
      # If statically true, use that branch directly; otherwise merge with WHERE
      if static_true:
        ba = next((b for c, b in conditions if is_const(c, True) and isinstance(b, dict)), {})
        block_assigns.update(ba); vars.update(ba)
      else:
        else_assigns = else_branch[1]
        all_vars = set().union(*[ba.keys() for _, ba in conditions if isinstance(ba, dict)], else_assigns.keys())
        for var in all_vars:
          # start from the else value (or the pre-if value) and wrap it in the conditions, last one innermost
          res: Any = else_assigns.get(var, block_assigns.get(var, vars.get(var, _u32(0))))
          for cond, ba in reversed(conditions):
            if isinstance(ba, dict) and var in ba:
              tv = ba[var]
              if isinstance(tv, UOp) and isinstance(res, UOp):
                res = cond.where(tv, res.cast(tv.dtype) if tv.dtype != res.dtype and tv.dtype.itemsize == res.dtype.itemsize else res)
          block_assigns[var] = vars[var] = res
      continue

    # Regular assignment: var = value
    for j, t in enumerate(toks):
      if t.type == 'EQUALS':
        # An EQUALS token preceded by a '<', '>', '!' or '=' operator is not treated as an
        # assignment (presumably the tail of a comparison - confirm against the tokenizer).
        if not any(toks[k].type == 'OP' and toks[k].val in ('<', '>', '!', '=') for k in range(j)):
          base_var = toks[0].val
          block_assigns[base_var] = vars[base_var] = parse_tokens(toks[j+1:], vars, funcs)
        break
    # Always advance, including when the line was rejected above; leaving i unchanged there
    # would make the enclosing while loop re-parse the same line forever.
    i += 1
  return i, block_assigns, None

def parse_expr(expr: str, vars: dict[str, VarVal], funcs: dict | None = None) -> UOp:
  """Tokenize and parse a single pcode expression (surrounding whitespace and trailing ';' removed)."""
  return parse_tokens(tokenize(expr.strip().rstrip(';')), vars, funcs)