From 8e8ad423a744abe8be2874370ba19c8c568ccd7a Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 4 Jan 2026 09:16:00 -0800 Subject: [PATCH] post parser --- extra/assembly/amd/qcode.py | 4 +- extra/assembly/amd/ucode.py | 1782 +++++++++++++---------------------- 2 files changed, 667 insertions(+), 1119 deletions(-) diff --git a/extra/assembly/amd/qcode.py b/extra/assembly/amd/qcode.py index 49cee8a859..f3221aa0c1 100644 --- a/extra/assembly/amd/qcode.py +++ b/extra/assembly/amd/qcode.py @@ -84,6 +84,7 @@ def expr(s: str) -> Expr: if s.endswith('.') and not (len(s) > 1 and s[-2].isdigit()): s = s[:-1] s = s.strip() if not s: raise ValueError("Empty expression") + if s == '+INF': s = 'INF' if s[0] == '(' and (e := _match(s, 0, '(', ')')) == len(s)-1: return expr(s[1:e]) if s[0] == '{' and s[-1] == '}': return Pack(tuple(expr(a) for a in _split(s[1:-1]))) if m := re.match(r"^(\d+)'([IUFB])\(", s): @@ -132,9 +133,6 @@ def expr(s: str) -> Expr: if '.' in s: for i in range(len(s)-1, 0, -1): if s[i] == '.' and s[i+1:] in DTYPES: return Typed(expr(s[:i]), DTYPES[s[i+1:]]) - if s in ('INF', '+INF'): return Const(float('inf'), DType.F64) - if s == '-INF': return Const(float('-inf'), DType.F64) - if s == 'PI': return Const(3.141592653589793, DType.F64) if s[:5] == 'eval ': return Var(s) if ':' in s and '?' not in s and '[' not in s: p = s.split(':') diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index 940b913218..c351f377fd 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -1,43 +1,40 @@ # UOp-based pseudocode compiler for AMD GPU instruction emulation -# Transforms pseudocode -> UOps -> execution via simplify -# Designed for reversible transformation (UOps -> instruction selection) +# Transforms pseudocode -> qcode AST -> UOps -> execution via simplify -import re, functools, struct +import functools, struct, math from tinygrad.uop.ops import UOp, Ops from tinygrad.dtype import dtypes, DType +from extra.assembly.amd.qcode import parse, Const, Var, Typed, Slice, Index, Cast, Unary, Binary, Ternary, Call, Pack, Assign, Declare, If, For +from extra.assembly.amd.qcode import DType as QDType # ═══════════════════════════════════════════════════════════════════════════════ # TYPE MAPPING # ═══════════════════════════════════════════════════════════════════════════════ -DTYPE_MAP = { - 'f32': dtypes.float32, 'f16': dtypes.float16, 'f64': dtypes.float64, - 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u64': dtypes.uint64, - 'i32': dtypes.int32, 'i16': dtypes.int16, 'i64': dtypes.int64, - 'u24': dtypes.uint24, 'i24': dtypes.int24, - 'b32': dtypes.uint32, 'b16': dtypes.uint16, 'b64': dtypes.uint64, - 'u8': dtypes.uint8, 'i8': dtypes.int8, - 'u': dtypes.uint32, 'i': dtypes.int32, 'f': dtypes.float32, # shorthand types - 'u1': dtypes.uint32, 'i1': dtypes.int32, # 1-bit as 32-bit +QDTYPE_MAP = { + QDType.F64: dtypes.float64, QDType.F32: dtypes.float32, QDType.F16: dtypes.float16, + QDType.U64: dtypes.uint64, QDType.U32: dtypes.uint32, QDType.U24: dtypes.uint24, QDType.U16: dtypes.uint16, QDType.U8: dtypes.uint8, + QDType.I64: dtypes.int64, QDType.I32: dtypes.int32, QDType.I24: dtypes.int24, QDType.I16: dtypes.int16, QDType.I8: dtypes.int8, + QDType.B128: dtypes.uint64, QDType.B64: dtypes.uint64, QDType.B32: dtypes.uint32, QDType.B16: dtypes.uint16, QDType.B8: dtypes.uint8, + QDType.U1: dtypes.uint32, QDType.I1: dtypes.int32, QDType.U3: dtypes.uint32, QDType.U4: dtypes.uint32, QDType.I4: dtypes.int32, } def _is_float(dtype: DType) -> bool: return dtype in (dtypes.float16, dtypes.float32, dtypes.float64) +def _qdt(qd: QDType) -> DType: return QDTYPE_MAP.get(qd, dtypes.uint32) # ═══════════════════════════════════════════════════════════════════════════════ -# UOP GRAPH BUILDER (compile time) +# UOP GRAPH BUILDER # ═══════════════════════════════════════════════════════════════════════════════ class UOpBuilder: - """Builds a UOp graph from pseudocode expressions at compile time.""" - + """Builds a UOp graph from qcode AST at compile time.""" + def __init__(self): - # Create DEFINE_VAR placeholders for inputs self.input_vars = { 'S0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('S0', 0, 0xffffffff)), 'S1': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('S1', 0, 0xffffffff)), 'S2': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('S2', 0, 0xffffffff)), 'D0': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('D0', 0, 0xffffffff)), - # 64-bit variants for ops that need them 'S0_64': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('S0_64', 0, 0xffffffffffffffff)), 'S1_64': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('S1_64', 0, 0xffffffffffffffff)), 'S2_64': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('S2_64', 0, 0xffffffffffffffff)), @@ -46,888 +43,663 @@ class UOpBuilder: 'VCC': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('VCC', 0, 0xffffffffffffffff)), 'EXEC': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('EXEC', 0, 0xffffffffffffffff)), 'laneId': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('laneId', 0, 31)), - # Immediate constants for SOPK/literal instructions 'SIMM16': UOp(Ops.DEFINE_VAR, dtypes.int32, (), ('SIMM16', -32768, 32767)), 'SIMM32': UOp(Ops.DEFINE_VAR, dtypes.uint32, (), ('SIMM32', 0, 0xffffffff)), - # Program counter for branch instructions 'PC': UOp(Ops.DEFINE_VAR, dtypes.uint64, (), ('PC', 0, 0xffffffffffffffff)), } self.vars: dict[str, UOp] = dict(self.input_vars) - self.outputs: list[tuple[str, UOp, DType]] = [] # (name, uop, dtype) - - def const(self, val, dtype: DType) -> UOp: - return UOp(Ops.CONST, dtype, (), val) - + self.outputs: list[tuple[str, UOp, DType]] = [] + + def const(self, val, dtype: DType) -> UOp: return UOp(Ops.CONST, dtype, (), val) + def cast(self, x: UOp, dtype: DType) -> UOp: if x.dtype == dtype: return x - # BITCAST only works for same-size types, use CAST otherwise - if dtype.itemsize == x.dtype.itemsize: - return UOp(Ops.BITCAST, dtype, (x,)) - return UOp(Ops.CAST, dtype, (x,)) - - def parse_type(self, s: str) -> tuple[str, DType]: - if '.' in s: - var, typ = s.rsplit('.', 1) - return var.strip(), DTYPE_MAP.get(typ, dtypes.uint32) - return s.strip(), dtypes.uint32 - - def parse_expr(self, expr: str, dtype_hint: DType = None) -> tuple[UOp, DType]: - expr = expr.strip() - # Strip trailing punctuation (period used as sentence end in pseudocode) - if expr.endswith('.') and not expr[-2:-1].isdigit(): - expr = expr[:-1] + return UOp(Ops.BITCAST if dtype.itemsize == x.dtype.itemsize else Ops.CAST, dtype, (x,)) - # Handle parentheses - if expr.startswith('(') and expr.endswith(')'): - depth = 0 - for i, c in enumerate(expr): - if c == '(': depth += 1 - elif c == ')': depth -= 1 - if depth == 0 and i < len(expr) - 1: break - else: - return self.parse_expr(expr[1:-1], dtype_hint) - - # Handle type cast: 32'I(expr), 64'U(expr), 64'F(expr), 16'F(expr), 1'1U, 1'0U - # Only match if the cast spans the whole expression - if m := re.match(r"^(\d+)'([IUFB])\(", expr): - bits, typ = int(m.group(1)), m.group(2) - # Find matching closing paren - start = m.end() - 1 # position of opening paren - depth, end = 0, start - for i, c in enumerate(expr[start:], start): - if c == '(': depth += 1 - elif c == ')': depth -= 1 - if depth == 0: end = i; break - # Only use this if cast spans entire expression - if end == len(expr) - 1: - inner = expr[start+1:end] - dtype_map = { - (16, 'I'): dtypes.int16, (16, 'U'): dtypes.uint16, (16, 'F'): dtypes.float16, - (32, 'I'): dtypes.int32, (32, 'U'): dtypes.uint32, (32, 'F'): dtypes.float32, (32, 'B'): dtypes.uint32, - (64, 'I'): dtypes.int64, (64, 'U'): dtypes.uint64, (64, 'F'): dtypes.float64, (64, 'B'): dtypes.uint64, - } - dtype = dtype_map.get((bits, typ), dtypes.uint32) - inner_uop, inner_dt = self.parse_expr(inner, dtype) - # For float casts, use CAST for value conversion - if typ == 'F': - return UOp(Ops.CAST, dtype, (inner_uop,)), dtype - # If inner is already the right size integer, just use it (masking already done by .u24/.i24) - if inner_dt in (dtypes.uint32, dtypes.int32) and bits == 32: return inner_uop, dtype - if inner_dt in (dtypes.uint64, dtypes.int64) and bits == 64: return inner_uop, dtype - return self.cast(inner_uop, dtype), dtype - if m := re.match(r"^(\d+)'(\d+)([IU])$", expr): - # Constant like 1'1U or 1'0U - val = int(m.group(2)) - return self.const(val, dtypes.uint32), dtypes.uint32 - - # Handle signext(expr) - sign extension - # Only match if signext spans the whole expression - if m := re.match(r"^signext\(", expr): - # Find matching closing paren - start = m.end() - 1 - depth, end = 0, start - for i, c in enumerate(expr[start:], start): - if c == '(': depth += 1 - elif c == ')': depth -= 1 - if depth == 0: end = i; break - # Only use this if signext spans entire expression - if end == len(expr) - 1: - inner = expr[start+1:end] - inner_uop, inner_dt = self.parse_expr(inner) - # Sign extend to 64-bit for arithmetic - return self.cast(inner_uop, dtypes.int64), dtypes.int64 - - # Handle function calls: fma, trunc, floor, sqrt, isNAN, abs, etc. - if m := re.match(r"^(\w+)\(", expr): - fn_name = m.group(1) - start = m.end() - 1 - depth, end = 0, start - for i, c in enumerate(expr[start:], start): - if c == '(': depth += 1 - elif c == ')': depth -= 1 - if depth == 0: end = i; break - if end == len(expr) - 1: - inner = expr[start+1:end] - # Parse comma-separated arguments - args = [] - depth = 0 - last = 0 - for i, c in enumerate(inner): - if c in '([': depth += 1 - elif c in ')]': depth -= 1 - elif c == ',' and depth == 0: - args.append(inner[last:i].strip()) - last = i + 1 - args.append(inner[last:].strip()) - - if fn_name == 'fma' and len(args) == 3: - a, _ = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - c, dt = self.parse_expr(args[2], dtype_hint) - return UOp(Ops.MULACC, dt, (a, b, c)), dt - elif fn_name == 'trunc' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.TRUNC, dt, (inner_uop,)), dt - elif fn_name == 'floor' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # floor(x) = trunc(x) - (x < 0 and x != trunc(x) ? 1 : 0) - # For now, use a simpler approach - just trunc for positive, trunc-1 for negative non-integer - # Actually, Python's math.floor works correctly, so we can use it via constant folding - # But we need a FLOOR op - for now just use trunc (may be slightly wrong for negative) - return UOp(Ops.TRUNC, dt, (inner_uop,)), dt # TODO: proper floor - elif fn_name == 'sqrt' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.SQRT, dt, (inner_uop,)), dt - elif fn_name == 'abs' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # abs(x) = x < 0 ? -x : x - zero = self.const(0, dt) - neg = UOp(Ops.NEG, dt, (inner_uop,)) - cond = UOp(Ops.CMPLT, dtypes.bool, (inner_uop, zero)) - return UOp(Ops.WHERE, dt, (cond, neg, inner_uop)), dt - elif fn_name == 'isNAN' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # isNAN(x) = x != x - return UOp(Ops.CMPNE, dtypes.bool, (inner_uop, inner_uop)), dtypes.bool - elif fn_name == 'isQuietNAN' and len(args) == 1: - # For now, treat same as isNAN - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.CMPNE, dtypes.bool, (inner_uop, inner_uop)), dtypes.bool - elif fn_name == 'isSignalNAN' and len(args) == 1: - # For now, return false (signal NaN is rare) - return self.const(0, dtypes.bool), dtypes.bool - elif fn_name == 'cvtToQuietNAN' and len(args) == 1: - # Convert signaling NaN to quiet NaN - for our purposes, just return input - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return inner_uop, dt - elif fn_name == 'isINF' and len(args) == 1: - # isINF(x) - check if infinite - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # x == inf or x == -inf, but we need the constants - # For now, return a placeholder that will work for constant folding - inf = self.const(float('inf'), dt) - neg_inf = self.const(float('-inf'), dt) - is_pos_inf = UOp(Ops.CMPEQ, dtypes.bool, (inner_uop, inf)) - is_neg_inf = UOp(Ops.CMPEQ, dtypes.bool, (inner_uop, neg_inf)) - return UOp(Ops.OR, dtypes.bool, (is_pos_inf, is_neg_inf)), dtypes.bool - elif fn_name in ('u32_to_f32', 'i32_to_f32', 'f32_to_u32', 'f32_to_i32', 'f16_to_f32', 'f32_to_f16', - 'f32_to_u8', 'f32_to_i8', 'f32_to_u16', 'f32_to_i16', 'v_cvt_i16_f32', 'v_cvt_u16_f32', - 'f64_to_i32', 'f64_to_u32', 'i32_to_f64', 'u32_to_f64', 'f64_to_f32', 'f32_to_f64', - 'f16_to_snorm', 'f16_to_unorm', 'u16_to_f16', 'i16_to_f16', 'f16_to_u16', 'f16_to_i16'): - # These are VALUE conversions, not bit reinterpretations - always use CAST - inner_uop, inner_dt = self.parse_expr(args[0]) - if fn_name == 'u32_to_f32': - return UOp(Ops.CAST, dtypes.float32, (inner_uop,)), dtypes.float32 - elif fn_name == 'i32_to_f32': - return UOp(Ops.CAST, dtypes.float32, (inner_uop,)), dtypes.float32 - elif fn_name == 'f32_to_u32': - # Clamp negative to 0, then convert (AMD semantics) - zero = self.const(0.0, dtypes.float32) - clamped = UOp(Ops.WHERE, dtypes.float32, (UOp(Ops.CMPLT, dtypes.bool, (inner_uop, zero)), zero, inner_uop)) - return UOp(Ops.CAST, dtypes.uint32, (clamped,)), dtypes.uint32 - elif fn_name == 'f32_to_i32': - return UOp(Ops.CAST, dtypes.int32, (inner_uop,)), dtypes.int32 - elif fn_name == 'f16_to_f32': - # f16 -> f32 value conversion - return UOp(Ops.CAST, dtypes.float32, (inner_uop,)), dtypes.float32 - elif fn_name == 'f32_to_f16': - # f32 -> f16 value conversion - return UOp(Ops.CAST, dtypes.float16, (inner_uop,)), dtypes.float16 - elif fn_name == 'f32_to_u8': - return UOp(Ops.CAST, dtypes.uint8, (inner_uop,)), dtypes.uint8 - elif fn_name == 'f32_to_i8': - return UOp(Ops.CAST, dtypes.int8, (inner_uop,)), dtypes.int8 - elif fn_name in ('f32_to_u16', 'v_cvt_u16_f32'): - return UOp(Ops.CAST, dtypes.uint16, (inner_uop,)), dtypes.uint16 - elif fn_name in ('f32_to_i16', 'v_cvt_i16_f32'): - return UOp(Ops.CAST, dtypes.int16, (inner_uop,)), dtypes.int16 - elif fn_name == 'f64_to_i32': - return UOp(Ops.CAST, dtypes.int32, (inner_uop,)), dtypes.int32 - elif fn_name == 'f64_to_u32': - # Clamp negative to 0 - zero = self.const(0.0, dtypes.float64) - clamped = UOp(Ops.WHERE, dtypes.float64, (UOp(Ops.CMPLT, dtypes.bool, (inner_uop, zero)), zero, inner_uop)) - return UOp(Ops.CAST, dtypes.uint32, (clamped,)), dtypes.uint32 - elif fn_name == 'i32_to_f64': - return UOp(Ops.CAST, dtypes.float64, (inner_uop,)), dtypes.float64 - elif fn_name == 'u32_to_f64': - return UOp(Ops.CAST, dtypes.float64, (inner_uop,)), dtypes.float64 - elif fn_name == 'f64_to_f32': - return UOp(Ops.CAST, dtypes.float32, (inner_uop,)), dtypes.float32 - elif fn_name == 'f32_to_f64': - return UOp(Ops.CAST, dtypes.float64, (inner_uop,)), dtypes.float64 - elif fn_name == 'f16_to_snorm': - # Convert f16 to signed normalized i16 (-1.0 to 1.0 -> -32768 to 32767) - clamped = UOp(Ops.WHERE, inner_uop.dtype, (UOp(Ops.CMPLT, dtypes.bool, (inner_uop, self.const(-1.0, inner_uop.dtype))), - self.const(-1.0, inner_uop.dtype), inner_uop)) - clamped = UOp(Ops.WHERE, inner_uop.dtype, (UOp(Ops.CMPLT, dtypes.bool, (self.const(1.0, inner_uop.dtype), clamped)), - self.const(1.0, inner_uop.dtype), clamped)) - scaled = UOp(Ops.MUL, inner_uop.dtype, (clamped, self.const(32767.0, inner_uop.dtype))) - return UOp(Ops.CAST, dtypes.int16, (scaled,)), dtypes.int16 - elif fn_name == 'f16_to_unorm': - # Convert f16 to unsigned normalized u16 (0.0 to 1.0 -> 0 to 65535) - clamped = UOp(Ops.WHERE, inner_uop.dtype, (UOp(Ops.CMPLT, dtypes.bool, (inner_uop, self.const(0.0, inner_uop.dtype))), - self.const(0.0, inner_uop.dtype), inner_uop)) - clamped = UOp(Ops.WHERE, inner_uop.dtype, (UOp(Ops.CMPLT, dtypes.bool, (self.const(1.0, inner_uop.dtype), clamped)), - self.const(1.0, inner_uop.dtype), clamped)) - scaled = UOp(Ops.MUL, inner_uop.dtype, (clamped, self.const(65535.0, inner_uop.dtype))) - return UOp(Ops.CAST, dtypes.uint16, (scaled,)), dtypes.uint16 - elif fn_name == 'u16_to_f16': - # Convert u16 integer to f16 float - return UOp(Ops.CAST, dtypes.float16, (inner_uop,)), dtypes.float16 - elif fn_name == 'i16_to_f16': - # Convert i16 integer to f16 float - return UOp(Ops.CAST, dtypes.float16, (inner_uop,)), dtypes.float16 - elif fn_name == 'f16_to_u16': - # Convert f16 float to u16 integer - return UOp(Ops.CAST, dtypes.uint16, (inner_uop,)), dtypes.uint16 - elif fn_name == 'f16_to_i16': - # Convert f16 float to i16 integer - return UOp(Ops.CAST, dtypes.int16, (inner_uop,)), dtypes.int16 - elif fn_name == 'exp2' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.EXP2, dt, (inner_uop,)), dt - elif fn_name == 'log2' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.LOG2, dt, (inner_uop,)), dt - elif fn_name == 'sin' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.SIN, dt, (inner_uop,)), dt - elif fn_name == 'cos' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # cos(x) = sin(x + pi/2) - pi_2 = self.const(1.5707963267948966, dt) - return UOp(Ops.SIN, dt, (UOp(Ops.ADD, dt, (inner_uop, pi_2)),)), dt - elif fn_name == 'rcp' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - return UOp(Ops.RECIPROCAL, dt, (inner_uop,)), dt - elif fn_name == 'rsqrt' and len(args) == 1: - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - sqrt_val = UOp(Ops.SQRT, dt, (inner_uop,)) - return UOp(Ops.RECIPROCAL, dt, (sqrt_val,)), dt - elif fn_name == 'min' and len(args) == 2: - a, dt = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - # min(a,b) = a < b ? a : b - cond = UOp(Ops.CMPLT, dtypes.bool, (a, b)) - return UOp(Ops.WHERE, dt, (cond, a, b)), dt - elif fn_name == 'max' and len(args) == 2: - a, dt = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - # max(a,b) = a > b ? a : b - cond = UOp(Ops.CMPLT, dtypes.bool, (b, a)) - return UOp(Ops.WHERE, dt, (cond, a, b)), dt - elif fn_name == 'clamp' and len(args) == 3: - x, dt = self.parse_expr(args[0], dtype_hint) - lo, _ = self.parse_expr(args[1], dtype_hint) - hi, _ = self.parse_expr(args[2], dtype_hint) - # clamp(x, lo, hi) = min(max(x, lo), hi) - cond_lo = UOp(Ops.CMPLT, dtypes.bool, (x, lo)) - max_val = UOp(Ops.WHERE, dt, (cond_lo, lo, x)) - cond_hi = UOp(Ops.CMPLT, dtypes.bool, (hi, max_val)) - return UOp(Ops.WHERE, dt, (cond_hi, hi, max_val)), dt - elif fn_name == 'signext_from_bit' and len(args) == 2: - # signext_from_bit(val, width) - sign extend val from width bits to full type - val_uop, dt = self.parse_expr(args[0], dtype_hint) - width_uop, _ = self.parse_expr(args[1]) - # Sign bit is at position (width - 1) - # Sign extend: ((val ^ (1 << (width-1))) - (1 << (width-1))) - # But we need to handle width=0 case: return 0 - one = self.const(1, dt) - width_minus_1 = UOp(Ops.SUB, dt, (self.cast(width_uop, dt), one)) - sign_bit = UOp(Ops.SHL, dt, (one, width_minus_1)) - xored = UOp(Ops.XOR, dt, (val_uop, sign_bit)) - result = UOp(Ops.SUB, dt, (xored, sign_bit)) - # If width is 0, return 0 - width_is_zero = UOp(Ops.CMPEQ, dtypes.bool, (width_uop, self.const(0, width_uop.dtype))) - return UOp(Ops.WHERE, dt, (width_is_zero, self.const(0, dt), result)), dt - elif fn_name == 'ABSDIFF' and len(args) == 2: - # ABSDIFF(a, b) = |a - b| for unsigned values - a, _ = self.parse_expr(args[0]) - b, _ = self.parse_expr(args[1]) - # max(a,b) - min(a,b) - a_gt_b = UOp(Ops.CMPLT, dtypes.bool, (b, a)) - max_val = UOp(Ops.WHERE, dtypes.uint32, (a_gt_b, self.cast(a, dtypes.uint32), self.cast(b, dtypes.uint32))) - min_val = UOp(Ops.WHERE, dtypes.uint32, (a_gt_b, self.cast(b, dtypes.uint32), self.cast(a, dtypes.uint32))) - return UOp(Ops.SUB, dtypes.uint32, (max_val, min_val)), dtypes.uint32 - elif fn_name == 'exponent' and len(args) == 1: - # exponent(x) - extract IEEE exponent bits from float - # f16: bits[14:10] (5 bits), f32: bits[30:23] (8 bits), f64: bits[62:52] (11 bits) - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - if dt == dtypes.float64: - bits = UOp(Ops.BITCAST, dtypes.uint64, (inner_uop,)) - exp = UOp(Ops.SHR, dtypes.uint64, (bits, self.const(52, dtypes.uint64))) - exp = UOp(Ops.AND, dtypes.uint32, (self.cast(exp, dtypes.uint32), self.const(0x7ff, dtypes.uint32))) - elif dt == dtypes.float16: - bits = UOp(Ops.BITCAST, dtypes.uint16, (inner_uop,)) - exp = UOp(Ops.SHR, dtypes.uint16, (bits, self.const(10, dtypes.uint16))) - exp = UOp(Ops.AND, dtypes.uint32, (self.cast(exp, dtypes.uint32), self.const(0x1f, dtypes.uint32))) - else: # f32 - bits = UOp(Ops.BITCAST, dtypes.uint32, (inner_uop,)) - exp = UOp(Ops.SHR, dtypes.uint32, (bits, self.const(23, dtypes.uint32))) - exp = UOp(Ops.AND, dtypes.uint32, (exp, self.const(0xff, dtypes.uint32))) - return exp, dtypes.uint32 - elif fn_name == 'isEven' and len(args) == 1: - # isEven(x) - check if integer part is even - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # Cast to int64, check bit 0 - int_val = UOp(Ops.CAST, dtypes.int64, (inner_uop,)) - bit0 = UOp(Ops.AND, dtypes.int64, (int_val, self.const(1, dtypes.int64))) - return UOp(Ops.CMPEQ, dtypes.bool, (bit0, self.const(0, dtypes.int64))), dtypes.bool - elif fn_name == 'sign' and len(args) == 1: - # sign(x) - return 1 if x is negative (sign bit set), 0 otherwise - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - if dt == dtypes.float64: - bits = UOp(Ops.BITCAST, dtypes.uint64, (inner_uop,)) - sign_bit = UOp(Ops.SHR, dtypes.uint64, (bits, self.const(63, dtypes.uint64))) - return UOp(Ops.AND, dtypes.uint32, (self.cast(sign_bit, dtypes.uint32), self.const(1, dtypes.uint32))), dtypes.uint32 - elif dt == dtypes.float16: - bits = UOp(Ops.BITCAST, dtypes.uint16, (inner_uop,)) - sign_bit = UOp(Ops.SHR, dtypes.uint16, (bits, self.const(15, dtypes.uint16))) - return UOp(Ops.AND, dtypes.uint32, (self.cast(sign_bit, dtypes.uint32), self.const(1, dtypes.uint32))), dtypes.uint32 - else: # f32 - bits = UOp(Ops.BITCAST, dtypes.uint32, (inner_uop,)) - sign_bit = UOp(Ops.SHR, dtypes.uint32, (bits, self.const(31, dtypes.uint32))) - return UOp(Ops.AND, dtypes.uint32, (sign_bit, self.const(1, dtypes.uint32))), dtypes.uint32 - elif fn_name == 'fract' and len(args) == 1: - # fract(x) = x - floor(x) = x - trunc(x) for positive, need proper floor for negative - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - truncated = UOp(Ops.TRUNC, dt, (inner_uop,)) - return UOp(Ops.SUB, dt, (inner_uop, truncated)), dt - elif fn_name == 'mantissa' and len(args) == 1: - # mantissa(x) - extract IEEE mantissa bits from float - # f16: bits[9:0] (10 bits), f32: bits[22:0] (23 bits), f64: bits[51:0] (52 bits) - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - if dt == dtypes.float64: - bits = UOp(Ops.BITCAST, dtypes.uint64, (inner_uop,)) - mant = UOp(Ops.AND, dtypes.uint64, (bits, self.const(0xfffffffffffff, dtypes.uint64))) - return mant, dtypes.uint64 - elif dt == dtypes.float16: - bits = UOp(Ops.BITCAST, dtypes.uint16, (inner_uop,)) - mant = UOp(Ops.AND, dtypes.uint32, (self.cast(bits, dtypes.uint32), self.const(0x3ff, dtypes.uint32))) - return mant, dtypes.uint32 - else: # f32 - bits = UOp(Ops.BITCAST, dtypes.uint32, (inner_uop,)) - mant = UOp(Ops.AND, dtypes.uint32, (bits, self.const(0x7fffff, dtypes.uint32))) - return mant, dtypes.uint32 - elif fn_name == 'pow' and len(args) == 2: - # pow(base, exp) - when base is 2.0, use exp2 - base, base_dt = self.parse_expr(args[0], dtype_hint) - exp, _ = self.parse_expr(args[1], dtype_hint) - result_dt = base_dt if _is_float(base_dt) else dtype_hint or dtypes.float32 - # Check if base is 2.0 - if base.op == Ops.CONST and base.arg == 2.0: - # For exponent, use CAST (value conversion), not BITCAST - exp_uop = UOp(Ops.CAST, result_dt, (exp,)) if exp.dtype != result_dt else exp - return UOp(Ops.EXP2, result_dt, (exp_uop,)), result_dt - # General case: pow(a, b) = exp2(b * log2(a)) - base_cast = UOp(Ops.CAST, result_dt, (base,)) if base.dtype != result_dt else base - exp_cast = UOp(Ops.CAST, result_dt, (exp,)) if exp.dtype != result_dt else exp - log_a = UOp(Ops.LOG2, result_dt, (base_cast,)) - product = UOp(Ops.MUL, result_dt, (exp_cast, log_a)) - return UOp(Ops.EXP2, result_dt, (product,)), result_dt - elif fn_name == 'LT_NEG_ZERO' and len(args) == 2: - # LT_NEG_ZERO(a, b) - less than comparison where -0 < +0 - # This differs from IEEE where -0 == +0 - a, dt = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - # Compare as signed integers to make -0 < +0 - if dt == dtypes.float64: - a_bits = UOp(Ops.BITCAST, dtypes.int64, (a,)) - b_bits = UOp(Ops.BITCAST, dtypes.int64, (b,)) - elif dt == dtypes.float16: - a_bits = UOp(Ops.BITCAST, dtypes.int16, (a,)) - b_bits = UOp(Ops.BITCAST, dtypes.int16, (b,)) - else: # f32 - a_bits = UOp(Ops.BITCAST, dtypes.int32, (a,)) - b_bits = UOp(Ops.BITCAST, dtypes.int32, (b,)) - return UOp(Ops.CMPLT, dtypes.bool, (a_bits, b_bits)), dtypes.bool - elif fn_name == 'GT_NEG_ZERO' and len(args) == 2: - # GT_NEG_ZERO(a, b) - greater than comparison where -0 < +0 - a, dt = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - # Compare as signed integers - if dt == dtypes.float64: - a_bits = UOp(Ops.BITCAST, dtypes.int64, (a,)) - b_bits = UOp(Ops.BITCAST, dtypes.int64, (b,)) - elif dt == dtypes.float16: - a_bits = UOp(Ops.BITCAST, dtypes.int16, (a,)) - b_bits = UOp(Ops.BITCAST, dtypes.int16, (b,)) - else: # f32 - a_bits = UOp(Ops.BITCAST, dtypes.int32, (a,)) - b_bits = UOp(Ops.BITCAST, dtypes.int32, (b,)) - return UOp(Ops.CMPLT, dtypes.bool, (b_bits, a_bits)), dtypes.bool - elif fn_name == 'SAT8' and len(args) == 1: - # SAT8(x) - saturate to signed 8-bit range [-128, 127] - inner_uop, dt = self.parse_expr(args[0], dtype_hint) - # Clamp to [-128, 127] - lo = self.const(-128, dt) - hi = self.const(127, dt) - clamped_lo = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (inner_uop, lo)), lo, inner_uop)) - clamped = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (hi, clamped_lo)), hi, clamped_lo)) - return clamped, dt - # v_min/v_max functions - just forward to min/max - elif fn_name in ('v_min_f32', 'v_min_f16', 'v_min_f64', 'v_min_i32', 'v_min_i16', 'v_min_u32', 'v_min_u16') and len(args) == 2: - a, dt = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - cond = UOp(Ops.CMPLT, dtypes.bool, (a, b)) - return UOp(Ops.WHERE, dt, (cond, a, b)), dt - elif fn_name in ('v_max_f32', 'v_max_f16', 'v_max_f64', 'v_max_i32', 'v_max_i16', 'v_max_u32', 'v_max_u16') and len(args) == 2: - a, dt = self.parse_expr(args[0], dtype_hint) - b, _ = self.parse_expr(args[1], dtype_hint) - cond = UOp(Ops.CMPLT, dtypes.bool, (b, a)) - return UOp(Ops.WHERE, dt, (cond, a, b)), dt - - # Handle ternary: cond ? true_val : false_val - depth = bracket = 0 - q_pos = c_pos = -1 - for i, c in enumerate(expr): - if c == '(': depth += 1 - elif c == ')': depth -= 1 - elif c == '[': bracket += 1 - elif c == ']': bracket -= 1 - elif c == '?' and depth == 0 and bracket == 0 and q_pos < 0: q_pos = i - elif c == ':' and depth == 0 and bracket == 0 and q_pos >= 0: c_pos = i; break - if q_pos > 0 and c_pos > q_pos: - cond_uop, _ = self.parse_expr(expr[:q_pos].strip()) - true_uop, true_dt = self.parse_expr(expr[q_pos+1:c_pos].strip(), dtype_hint) - false_uop, false_dt = self.parse_expr(expr[c_pos+1:].strip(), dtype_hint) - return UOp(Ops.WHERE, true_dt, (cond_uop, true_uop, false_uop)), true_dt - - binop_map = { - '+': Ops.ADD, '-': Ops.SUB, '*': Ops.MUL, '/': Ops.FDIV, - '&': Ops.AND, '|': Ops.OR, '^': Ops.XOR, - '<<': Ops.SHL, '>>': Ops.SHR, - '<': Ops.CMPLT, '==': Ops.CMPEQ, '!=': Ops.CMPNE, - } - - # Handle binary operators (lowest precedence first) - # Order matters: check longer ops before shorter ones to avoid << matching < - # ** (exponentiation) is highest precedence among binary ops - for ops in [('||',), ('&&',), ('==', '!=', '<>', '<=', '>=', '<', '>'), ('|',), ('^',), ('&',), ('<<', '>>'), ('+', '-'), ('*', '/'), ('**',)]: - depth = bracket = 0 - for i in range(len(expr) - 1, -1, -1): - c = expr[i] - if c == ')': depth += 1 - elif c == '(': depth -= 1 - elif c == ']': bracket += 1 - elif c == '[': bracket -= 1 - elif depth == 0 and bracket == 0: - for op in sorted(ops, key=len, reverse=True): # longest first - if expr[i:i+len(op)] == op: - # Check we're not matching < when we should match << or <= - if op in ('<', '>') and i + 1 < len(expr) and expr[i+1] in '<>=': - continue - if op in ('<', '>') and i > 0 and expr[i-1] in '<>=': - continue - # Check we're not matching * when it's part of ** - if op == '*' and i + 1 < len(expr) and expr[i+1] == '*': - continue - if op == '*' and i > 0 and expr[i-1] == '*': - continue - left_expr = expr[:i].strip() - right_expr = expr[i+len(op):].strip() - if not left_expr: continue - # Skip if this looks like unary - after another operator - if op == '-' and left_expr and left_expr[-1] in '+-*/(<>=&|^': continue - left_uop, left_dt = self.parse_expr(left_expr) - right_uop, right_dt = self.parse_expr(right_expr, left_dt) - result_dt = left_dt if _is_float(left_dt) else right_dt if _is_float(right_dt) else left_dt - - if op == '||': - one, zero = self.const(1, dtypes.uint32), self.const(0, dtypes.uint32) - inner = UOp(Ops.WHERE, dtypes.uint32, (right_uop, one, zero)) - return UOp(Ops.WHERE, dtypes.uint32, (left_uop, one, inner)), dtypes.uint32 - if op == '&&': - one, zero = self.const(1, dtypes.uint32), self.const(0, dtypes.uint32) - inner = UOp(Ops.WHERE, dtypes.uint32, (right_uop, one, zero)) - return UOp(Ops.WHERE, dtypes.uint32, (left_uop, inner, zero)), dtypes.uint32 - if op == '<>': op = '!=' - - # Handle comparison ops that don't have direct UOp equivalents - if op == '>': - return UOp(Ops.CMPLT, dtypes.bool, (right_uop, left_uop)), dtypes.bool - if op == '>=': - lt = UOp(Ops.CMPLT, dtypes.bool, (left_uop, right_uop)) - return UOp(Ops.XOR, dtypes.bool, (lt, self.const(True, dtypes.bool))), dtypes.bool - if op == '<=': - lt = UOp(Ops.CMPLT, dtypes.bool, (right_uop, left_uop)) - return UOp(Ops.XOR, dtypes.bool, (lt, self.const(True, dtypes.bool))), dtypes.bool - - # Handle ** (exponentiation) - when base is 2.0, use exp2 - if op == '**': - # 2.0 ** x = exp2(x), 2.0F ** x = exp2(x) - # Check if left side is 2.0 constant - if left_uop.op == Ops.CONST and left_uop.arg == 2.0: - # For exponent, always use CAST (value conversion), not BITCAST - exp_uop = UOp(Ops.CAST, result_dt, (right_uop,)) if right_uop.dtype != result_dt else right_uop - return UOp(Ops.EXP2, result_dt, (exp_uop,)), result_dt - # General case: a ** b = exp2(b * log2(a)) - log_a = UOp(Ops.LOG2, result_dt, (left_uop,)) - exp_uop = UOp(Ops.CAST, result_dt, (right_uop,)) if right_uop.dtype != result_dt else right_uop - product = UOp(Ops.MUL, result_dt, (exp_uop, log_a)) - return UOp(Ops.EXP2, result_dt, (product,)), result_dt - - uop_op = binop_map.get(op) - if uop_op is None: raise ValueError(f"Unknown operator: {op}") - out_dt = dtypes.bool if uop_op in (Ops.CMPLT, Ops.CMPEQ, Ops.CMPNE) else result_dt - return UOp(uop_op, out_dt, (left_uop, right_uop)), out_dt - - # Unary operators - if expr.startswith('-'): - val_uop, dt = self.parse_expr(expr[1:]) - return UOp(Ops.NEG, dt, (val_uop,)), dt - if expr.startswith('~'): - val_uop, dt = self.parse_expr(expr[1:]) - return UOp(Ops.XOR, dt, (val_uop, self.const(-1, dt))), dt - if expr.startswith('!'): - val_uop, dt = self.parse_expr(expr[1:]) - return UOp(Ops.CMPEQ, dtypes.bool, (val_uop, self.const(0, dt))), dtypes.bool - - # Bit slice with type suffix: S0[4:0].u32, S0[15:0].f16 - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\s*\[\s*(\d+)\s*:\s*(\d+)\s*\]\.([a-z]\d+)$', expr): - var, high, low, typ = m.group(1), int(m.group(2)), int(m.group(3)), m.group(4) - dtype = DTYPE_MAP.get(typ, dtypes.uint32) - if high < low: high, low = low, high - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - mask = (1 << (high - low + 1)) - 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.const(low, base_uop.dtype))) if low > 0 else base_uop - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(mask, dtypes.uint32))) - # For float types, bitcast from integer bits - if _is_float(dtype): - if dtype == dtypes.float16: - return UOp(Ops.BITCAST, dtypes.float16, (self.cast(masked, dtypes.uint16),)), dtype - return UOp(Ops.BITCAST, dtype, (masked,)), dtype - return self.cast(masked, dtype), dtype - - # Bit slice with type prefix: S0.u32[31:24], S0.u[5:0] - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\.([a-z]\d*)\s*\[\s*(\d+)\s*:\s*(\d+)\s*\]$', expr): - var, typ, high, low = m.group(1), m.group(2), int(m.group(3)), int(m.group(4)) - if high < low: high, low = low, high - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - mask = (1 << (high - low + 1)) - 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.const(low, base_uop.dtype))) if low > 0 else base_uop - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(mask, dtypes.uint32))) - return masked, dtypes.uint32 - - # Bit slice with both type prefix and suffix: S0.u32[31:24].u32 - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\.([a-z]\d+)\s*\[\s*(\d+)\s*:\s*(\d+)\s*\]\.([a-z]\d+)$', expr): - var, var_typ, high, low, result_typ = m.group(1), m.group(2), int(m.group(3)), int(m.group(4)), m.group(5) - if high < low: high, low = low, high - dtype = DTYPE_MAP.get(result_typ, dtypes.uint32) - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - mask = (1 << (high - low + 1)) - 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.const(low, base_uop.dtype))) if low > 0 else base_uop - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(mask, dtypes.uint32))) - return self.cast(masked, dtype), dtype - - # Bit slice without type: S0[4:0] - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\s*\[\s*(\d+)\s*:\s*(\d+)\s*\]$', expr): - var, high, low = m.group(1), int(m.group(2)), int(m.group(3)) - if high < low: high, low = low, high - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - mask = (1 << (high - low + 1)) - 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.const(low, base_uop.dtype))) if low > 0 else base_uop - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(mask, dtypes.uint32))) - return masked, dtype_hint or dtypes.uint32 - - # Bit index with expression: S1.u32[expr] - extract single bit at computed index - # Handle complex expressions like S1.u32[sign(S0.f32) ? 5 : 6] - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\.([a-z]\d+)\[(.+)\]$', expr): - var, typ, idx_expr = m.group(1), m.group(2), m.group(3) - # Check it's not a bit range (digit:digit pattern without ?) - # Allow expressions containing : if they also have ? (ternary) - is_bit_range = ':' in idx_expr and '?' not in idx_expr and re.match(r'^\s*\d+\s*:\s*\d+\s*$', idx_expr) - if not is_bit_range: - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - # Parse idx_expr as an expression - if idx_expr in self.vars: - idx_uop = self.vars[idx_expr] - elif idx_expr.isdigit(): - idx_uop = self.const(int(idx_expr), dtypes.uint32) - else: - idx_uop, _ = self.parse_expr(idx_expr) - # (base >> idx) & 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.cast(idx_uop, base_uop.dtype))) - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(1, dtypes.uint32))) - return masked, dtypes.uint32 - - # Bit index with variable index AND result type: VCC.u64[laneId].u32 - extract single bit with result type - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\.([a-z]\d+)\[(\w+)\]\.([a-z]\d+)$', expr): - var, var_typ, idx_expr, result_typ = m.group(1), m.group(2), m.group(3), m.group(4) - dtype = DTYPE_MAP.get(result_typ, dtypes.uint32) - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - if idx_expr in self.vars: - idx_uop = self.vars[idx_expr] - else: - idx_uop = self.const(int(idx_expr), dtypes.uint32) - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.cast(idx_uop, base_uop.dtype))) - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(1, dtypes.uint32))) - return self.cast(masked, dtype), dtype - - # Bit index with result type: S2.u32[24].u8 - extract single bit with type - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\.([a-z]\d+)\[(\d+)\]\.([a-z]\d+)$', expr): - var, var_typ, bit_idx, result_typ = m.group(1), m.group(2), int(m.group(3)), m.group(4) - dtype = DTYPE_MAP.get(result_typ, dtypes.uint32) - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - # (base >> bit_idx) & 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.const(bit_idx, base_uop.dtype))) - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(1, dtypes.uint32))) - return self.cast(masked, dtype), dtype - - # Bit index without type: tmp[31], SIMM16.i16[15] - extract single bit - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)(?:\.([a-z]\d+))?\[(\d+)\]$', expr): - var, typ, bit_idx = m.group(1), m.group(2), int(m.group(3)) - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - # (base >> bit_idx) & 1 - shifted = UOp(Ops.SHR, base_uop.dtype, (base_uop, self.const(bit_idx, base_uop.dtype))) - masked = UOp(Ops.AND, dtypes.uint32, (self.cast(shifted, dtypes.uint32), self.const(1, dtypes.uint32))) - return masked, dtypes.uint32 - - # Typed variable: S0.f32, S0.u24, S0.i24, S0.f64, S0.f16, EXEC.u64, SIMM16.i16, tmp.u32, etc. - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\.([a-z]\d+)$', expr): - var, typ = m.group(1), m.group(2) - dtype = DTYPE_MAP.get(typ, dtypes.uint32) - # Handle VCCZ and EXECZ specially - they're computed from VCC/EXEC - if var == 'VCCZ': - vcc = self.vars.get('VCC') - return UOp(Ops.CMPEQ, dtypes.bool, (vcc, self.const(0, dtypes.uint64))), dtypes.bool - if var == 'EXECZ': - exec_mask = self.vars.get('EXEC') - return UOp(Ops.CMPEQ, dtypes.bool, (exec_mask, self.const(0, dtypes.uint64))), dtypes.bool - # For 64-bit types, use the _64 variant of the variable (for input vars only) - if typ in ('f64', 'u64', 'i64', 'b64') and var.isupper(): - base_uop = self.vars.get(var + '_64') - if base_uop is None: base_uop = self.vars.get(var) # fallback - else: - base_uop = self.vars.get(var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - # For 24-bit types, mask to 24 bits - if typ == 'u24': - masked = UOp(Ops.AND, dtypes.uint32, (base_uop, self.const(0xffffff, dtypes.uint32))) - return masked, dtypes.uint32 - if typ == 'i24': - # Sign-extend from 24 bits: ((x & 0xffffff) ^ 0x800000) - 0x800000 - masked = UOp(Ops.AND, dtypes.uint32, (base_uop, self.const(0xffffff, dtypes.uint32))) - xored = UOp(Ops.XOR, dtypes.int32, (masked, self.const(0x800000, dtypes.int32))) - sext = UOp(Ops.SUB, dtypes.int32, (xored, self.const(0x800000, dtypes.int32))) - return sext, dtypes.int32 - # For float types, bitcast from the integer representation - if typ in ('f32', 'f64'): - return UOp(Ops.BITCAST, dtype, (base_uop,)), dtype - if typ == 'f16': - # Mask to 16 bits and bitcast to f16 - masked = UOp(Ops.AND, dtypes.uint16, (self.cast(base_uop, dtypes.uint16), self.const(0xffff, dtypes.uint16))) - return UOp(Ops.BITCAST, dtypes.float16, (masked,)), dtypes.float16 - return self.cast(base_uop, dtype), dtype - - # Plain variable - if expr in self.vars: - uop = self.vars[expr] - dtype = dtype_hint or uop.dtype - return self.cast(uop, dtype), dtype - - # Special constants - import math - if expr == 'PI': return self.const(math.pi, dtype_hint or dtypes.float64), dtype_hint or dtypes.float64 - if expr in ('INF', '+INF'): return self.const(float('inf'), dtype_hint or dtypes.float64), dtype_hint or dtypes.float64 - if expr == '-INF': return self.const(float('-inf'), dtype_hint or dtypes.float64), dtype_hint or dtypes.float64 - # Mode constants - fixed at compile time for RDNA3 - if expr == 'WAVE_MODE.IEEE': return self.const(1, dtypes.uint32), dtypes.uint32 # IEEE mode enabled - if expr == 'WAVE32': return self.const(1, dtypes.uint32), dtypes.uint32 # 32-lane wavefront - if expr == 'WAVE64': return self.const(0, dtypes.uint32), dtypes.uint32 # not 64-lane wavefront - if expr == 'ROUND_MODE': return self.const(0, dtypes.uint32), dtypes.uint32 # round to nearest even - # PC is passed as input - if expr == 'PC': return self.vars.get('PC', self.const(0, dtypes.uint64)), dtypes.uint64 - # VCCZ and EXECZ - zero flags (VCC==0 and EXEC==0) - if expr == 'VCCZ': - vcc = self.vars.get('VCC') - return UOp(Ops.CMPEQ, dtypes.bool, (vcc, self.const(0, dtypes.uint64))), dtypes.bool - if expr == 'EXECZ': - exec_mask = self.vars.get('EXEC') - return UOp(Ops.CMPEQ, dtypes.bool, (exec_mask, self.const(0, dtypes.uint64))), dtypes.bool - - # Width-prefixed constants: 16'4 means 4 as 16-bit, 64'0 means 0 as 64-bit - if m := re.match(r"^(\d+)'(-?\d+)$", expr): - bits, val = int(m.group(1)), int(m.group(2)) - dtype_map = {8: dtypes.uint8, 16: dtypes.int16, 32: dtypes.int32, 64: dtypes.int64} - dtype = dtype_map.get(bits, dtypes.int32) - return self.const(val, dtype), dtype - - # Numeric literals - expr_clean = re.sub(r"(\d+)'([0-9a-fA-Fx]+)[UuLlFf]*", r'\2', expr) - expr_clean = re.sub(r'([0-9a-fA-Fx]+)[UuLlFf]+$', r'\1', expr_clean) - try: - if expr_clean.startswith('0x') or expr_clean.startswith('0X'): - return self.const(int(expr_clean, 16), dtype_hint or dtypes.uint32), dtype_hint or dtypes.uint32 - if '.' in expr_clean or 'e' in expr_clean.lower(): - return self.const(float(expr_clean), dtype_hint or dtypes.float32), dtype_hint or dtypes.float32 - return self.const(int(expr_clean), dtype_hint or dtypes.uint32), dtype_hint or dtypes.uint32 - except ValueError: - pass - - # Handle pack syntax: { hi, lo } -> (hi << N) | lo - # For .u32/.u16 concatenation: { S0.u32, S1.u32 } -> 64-bit, { S0.u16, S1.u16 } -> 32-bit - if expr.startswith('{') and expr.endswith('}'): - inner = expr[1:-1].strip() - # Find the comma that separates hi and lo - depth = 0 - comma_pos = -1 - for i, c in enumerate(inner): - if c in '([{': depth += 1 - elif c in ')]}': depth -= 1 - elif c == ',' and depth == 0: - comma_pos = i - break - if comma_pos > 0: - hi_expr = inner[:comma_pos].strip() - lo_expr = inner[comma_pos+1:].strip() - hi_uop, hi_dt = self.parse_expr(hi_expr) - lo_uop, lo_dt = self.parse_expr(lo_expr) - # Determine shift amount based on lo size - if lo_dt.itemsize >= 4: - # 32-bit elements -> 64-bit result - hi_ext = self.cast(hi_uop, dtypes.uint64) - lo_ext = self.cast(lo_uop, dtypes.uint64) - hi_shifted = UOp(Ops.SHL, dtypes.uint64, (hi_ext, self.const(32, dtypes.uint64))) - packed = UOp(Ops.OR, dtypes.uint64, (hi_shifted, lo_ext)) - return packed, dtypes.uint64 - else: - # 16-bit elements -> 32-bit result - hi_shifted = UOp(Ops.SHL, dtypes.uint32, (self.cast(hi_uop, dtypes.uint32), self.const(16, dtypes.uint32))) - lo_masked = UOp(Ops.AND, dtypes.uint32, (self.cast(lo_uop, dtypes.uint32), self.const(0xffff, dtypes.uint32))) - packed = UOp(Ops.OR, dtypes.uint32, (hi_shifted, lo_masked)) - return packed, dtypes.uint32 - - raise ValueError(f"Cannot parse expression: {expr}") - - def parse_stmt(self, line: str): - if '=' not in line or any(line.startswith(k) for k in ('if ', 'elsif ', 'for ', 'Set ')): - return - - # Handle += and -= operators - if '+=' in line or '-=' in line: - is_sub = '-=' in line - lhs, rhs = line.split('-=' if is_sub else '+=', 1) - lhs, rhs = lhs.strip(), rhs.strip() - var, dtype = self.parse_type(lhs) - curr_val = self.vars.get(var) - if curr_val is None: - curr_val = self.const(0, dtype) - inc_uop, _ = self.parse_expr(rhs, dtype) - # For float types, need to bitcast curr_val before arithmetic - if _is_float(dtype) and curr_val.dtype != dtype: - curr_val = UOp(Ops.BITCAST, dtype, (curr_val,)) - op = Ops.SUB if is_sub else Ops.ADD - result = UOp(op, dtype, (self.cast(curr_val, dtype), self.cast(inc_uop, dtype))) - # Store back as bits for floats - if _is_float(dtype): - result_bits = UOp(Ops.BITCAST, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64, (result,)) - self.vars[var] = result_bits - if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - self.outputs = [(n, u, d) for n, u, d in self.outputs if n != var] - self.outputs.append((var, result_bits, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64)) - else: - self.vars[var] = result - if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - self.outputs = [(n, u, d) for n, u, d in self.outputs if n != var] - self.outputs.append((var, result, dtype)) - return - - lhs, rhs = line.split('=', 1) - lhs = lhs.strip() - - # Handle bit set: D0.u64[laneId] = expr or EXEC.u64[laneId] = expr (sets single bit based on condition) - if m := re.match(r'^([A-Z][A-Z0-9]*)\.([a-z]\d+)\[(\w+)\]$', lhs): - var, typ, idx_var = m.group(1), m.group(2), m.group(3) - dtype = DTYPE_MAP.get(typ, dtypes.uint64) - base_uop = self.vars.get(var) - idx_uop = self.vars.get(idx_var) - if base_uop is None: raise ValueError(f"Unknown variable: {var}") - if idx_uop is None: raise ValueError(f"Unknown index variable: {idx_var}") - # Parse RHS as condition - cond_uop, _ = self.parse_expr(rhs.strip()) - # Set bit: (base & ~(1 << idx)) | (cond << idx) - one = self.const(1, dtype) - bit_mask = UOp(Ops.SHL, dtype, (one, self.cast(idx_uop, dtype))) - inv_mask = UOp(Ops.XOR, dtype, (bit_mask, self.const(-1, dtype))) - cleared = UOp(Ops.AND, dtype, (base_uop, inv_mask)) - cond_ext = self.cast(cond_uop, dtype) - cond_bit = UOp(Ops.SHL, dtype, (UOp(Ops.AND, dtype, (cond_ext, one)), self.cast(idx_uop, dtype))) - result = UOp(Ops.OR, dtype, (cleared, cond_bit)) - self.vars[var] = result - if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - self.outputs.append((var, result, dtype)) - return - - # Handle bit range assignment: D0[31:16].f16 = expr, D0[15:0].f16 = expr, tmp[31:16].i16 = expr - if m := re.match(r'^([A-Za-z][A-Za-z0-9]*)\s*\[\s*(\d+)\s*:\s*(\d+)\s*\]\.([a-z]\d+)$', lhs): - var, high, low, typ = m.group(1), int(m.group(2)), int(m.group(3)), m.group(4) - if high < low: high, low = low, high - dtype = DTYPE_MAP.get(typ, dtypes.uint32) - base_uop = self.vars.get(var) - if base_uop is None: base_uop = self.const(0, dtypes.uint32) - rhs_uop, _ = self.parse_expr(rhs.strip(), dtype) - # For float types, convert to bits - if _is_float(dtype): - if dtype == dtypes.float16: - rhs_bits = UOp(Ops.BITCAST, dtypes.uint16, (rhs_uop,)) - rhs_bits = self.cast(rhs_bits, dtypes.uint32) - else: - rhs_bits = UOp(Ops.BITCAST, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64, (rhs_uop,)) - else: - rhs_bits = self.cast(rhs_uop, dtypes.uint32) - # Create mask and insert bits - width = high - low + 1 - mask = (1 << width) - 1 - shifted_val = UOp(Ops.SHL, dtypes.uint32, (UOp(Ops.AND, dtypes.uint32, (rhs_bits, self.const(mask, dtypes.uint32))), self.const(low, dtypes.uint32))) - inv_mask = ~(mask << low) & 0xffffffff - cleared = UOp(Ops.AND, dtypes.uint32, (self.cast(base_uop, dtypes.uint32), self.const(inv_mask, dtypes.uint32))) - result = UOp(Ops.OR, dtypes.uint32, (cleared, shifted_val)) - self.vars[var] = result - if var in ('D0', 'D1', 'SCC', 'VCC'): - # Only add to outputs if not already there, or replace existing - self.outputs = [(n, u, d) for n, u, d in self.outputs if n != var] - self.outputs.append((var, result, dtypes.uint32)) - return - - var, dtype = self.parse_type(lhs) - rhs_uop, _ = self.parse_expr(rhs.strip(), dtype) - self.vars[var] = rhs_uop - # For 64-bit outputs, also update the _64 variant so subsequent reads find the computed value - if dtype.itemsize == 8 and var in ('D0', 'D1', 'S0', 'S1'): - self.vars[var + '_64'] = rhs_uop - if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - self.outputs.append((var, rhs_uop, dtype)) - def build_sink(self) -> UOp: - """Build a SINK UOp containing all outputs.""" - if not self.outputs: - return UOp(Ops.SINK, dtypes.void, ()) + if not self.outputs: return UOp(Ops.SINK, dtypes.void, ()) return UOp(Ops.SINK, dtypes.void, tuple(uop for _, uop, _ in self.outputs)) +# ═══════════════════════════════════════════════════════════════════════════════ +# AST -> UOP TRANSFORMER +# ═══════════════════════════════════════════════════════════════════════════════ + +def _get_var_dtype(name: str, qdt: QDType|None = None) -> tuple[str, DType]: + """Get variable name and dtype, handling 64-bit variants.""" + dt = _qdt(qdt) if qdt else dtypes.uint32 + if qdt in (QDType.F64, QDType.U64, QDType.I64, QDType.B64) and name.isupper(): + return name + '_64', dt + return name, dt + +def transform_expr(node, b: UOpBuilder, hint: DType = None) -> tuple[UOp, DType]: + """Transform qcode AST expression to UOp.""" + match node: + case Const(val, qdt): + dt = _qdt(qdt) if qdt != QDType.I32 or hint is None else hint + if isinstance(val, float) and not _is_float(dt): dt = dtypes.float32 + return b.const(val, dt), dt + + case Var(name): + # Special constants + if name == 'PI': return b.const(math.pi, hint or dtypes.float64), hint or dtypes.float64 + if name in ('INF', '+INF'): return b.const(float('inf'), hint or dtypes.float64), hint or dtypes.float64 + if name == '-INF': return b.const(float('-inf'), hint or dtypes.float64), hint or dtypes.float64 + if name == 'WAVE_MODE.IEEE': return b.const(1, dtypes.uint32), dtypes.uint32 + if name == 'WAVE32': return b.const(1, dtypes.uint32), dtypes.uint32 + if name == 'WAVE64': return b.const(0, dtypes.uint32), dtypes.uint32 + if name == 'ROUND_MODE': return b.const(0, dtypes.uint32), dtypes.uint32 + if name == 'VCCZ': + vcc = b.vars.get('VCC') + cmp = UOp(Ops.CMPEQ, dtypes.bool, (vcc, b.const(0, dtypes.uint64))) + return b.cast(cmp, dtypes.uint32), dtypes.uint32 + if name == 'EXECZ': + ex = b.vars.get('EXEC') + cmp = UOp(Ops.CMPEQ, dtypes.bool, (ex, b.const(0, dtypes.uint64))) + return b.cast(cmp, dtypes.uint32), dtypes.uint32 + if name.startswith('eval '): return b.vars.get('_eval', b.const(0, dtypes.uint32)), dtypes.uint32 + # Regular variable + if name not in b.vars: raise ValueError(f"Unknown variable: {name}") + uop = b.vars[name] + dt = hint or uop.dtype + return b.cast(uop, dt), dt + + case Typed(expr, qdt): + dt = _qdt(qdt) + var_name = expr.name if isinstance(expr, Var) else None + # Handle typed variable access + if var_name: + if var_name == 'VCCZ': + vcc = b.vars.get('VCC') + cmp = UOp(Ops.CMPEQ, dtypes.bool, (vcc, b.const(0, dtypes.uint64))) + # Cast to uint32 for integer comparisons + return b.cast(cmp, dtypes.uint32), dt + if var_name == 'EXECZ': + ex = b.vars.get('EXEC') + cmp = UOp(Ops.CMPEQ, dtypes.bool, (ex, b.const(0, dtypes.uint64))) + return b.cast(cmp, dtypes.uint32), dt + # For 64-bit types, use _64 variant + vn, vdt = _get_var_dtype(var_name, qdt) + base = b.vars.get(vn) if vn in b.vars else b.vars.get(var_name) + if base is None: raise ValueError(f"Unknown variable: {var_name}") + # Handle 24-bit types + if qdt == QDType.U24: + masked = UOp(Ops.AND, dtypes.uint32, (base, b.const(0xffffff, dtypes.uint32))) + return masked, dtypes.uint32 + if qdt == QDType.I24: + masked = UOp(Ops.AND, dtypes.uint32, (base, b.const(0xffffff, dtypes.uint32))) + xored = UOp(Ops.XOR, dtypes.int32, (masked, b.const(0x800000, dtypes.int32))) + return UOp(Ops.SUB, dtypes.int32, (xored, b.const(0x800000, dtypes.int32))), dtypes.int32 + # Float types need bitcast + if _is_float(dt): + if dt == dtypes.float16: + # Mask to 16 bits and bitcast to f16 + masked = UOp(Ops.AND, dtypes.uint16, (b.cast(base, dtypes.uint16), b.const(0xffff, dtypes.uint16))) + return UOp(Ops.BITCAST, dtypes.float16, (masked,)), dtypes.float16 + return UOp(Ops.BITCAST, dt, (base,)), dt + # For signed integer types, keep as unsigned to avoid overflow issues during simplify + # Return the unsigned base but report the signed dtype for semantic purposes + if dt == dtypes.int32: return base, dtypes.int32 + if dt == dtypes.int64: + base64 = b.vars.get(var_name + '_64') if (var_name + '_64') in b.vars else base + return base64, dtypes.int64 + if dt == dtypes.int16: return base, dtypes.int16 + if dt == dtypes.int8: return base, dtypes.int8 + return b.cast(base, dt), dt + # Non-variable typed expression + inner, _ = transform_expr(expr, b, dt) + if _is_float(dt): + if dt == dtypes.float16: + # For f16, need to cast to u16 first then bitcast + inner_u16 = b.cast(inner, dtypes.uint16) + return UOp(Ops.BITCAST, dt, (inner_u16,)), dt + return UOp(Ops.BITCAST, dt, (inner,)), dt + return b.cast(inner, dt), dt + + case Slice(expr, hi, lo): + base, base_dt = transform_expr(expr, b) + hi_uop, _ = transform_expr(hi, b) + lo_uop, _ = transform_expr(lo, b) + # For constant hi/lo, compute mask directly + if hi_uop.op == Ops.CONST and lo_uop.op == Ops.CONST: + hi_val, lo_val = int(hi_uop.arg), int(lo_uop.arg) + if hi_val < lo_val: hi_val, lo_val = lo_val, hi_val + mask = (1 << (hi_val - lo_val + 1)) - 1 + shifted = UOp(Ops.SHR, base_dt, (base, b.const(lo_val, base_dt))) if lo_val > 0 else base + masked = UOp(Ops.AND, dtypes.uint32, (b.cast(shifted, dtypes.uint32), b.const(mask, dtypes.uint32))) + return masked, hint or dtypes.uint32 + raise ValueError(f"Non-constant slice bounds not supported: {node}") + + case Index(expr, idx): + base, base_dt = transform_expr(expr, b) + idx_uop, _ = transform_expr(idx, b) + # Single bit extraction: (base >> idx) & 1 + shifted = UOp(Ops.SHR, base_dt, (base, b.cast(idx_uop, base_dt))) + masked = UOp(Ops.AND, dtypes.uint32, (b.cast(shifted, dtypes.uint32), b.const(1, dtypes.uint32))) + return masked, dtypes.uint32 + + case Cast(bits, typ, expr): + dtype_map = { + (16, 'I'): dtypes.int16, (16, 'U'): dtypes.uint16, (16, 'F'): dtypes.float16, + (32, 'I'): dtypes.int32, (32, 'U'): dtypes.uint32, (32, 'F'): dtypes.float32, (32, 'B'): dtypes.uint32, + (64, 'I'): dtypes.int64, (64, 'U'): dtypes.uint64, (64, 'F'): dtypes.float64, (64, 'B'): dtypes.uint64, + } + dt = dtype_map.get((bits, typ), dtypes.uint32) + inner, inner_dt = transform_expr(expr, b, dt) + if typ == 'F': return UOp(Ops.CAST, dt, (inner,)), dt + if inner_dt in (dtypes.uint32, dtypes.int32) and bits == 32: return inner, dt + if inner_dt in (dtypes.uint64, dtypes.int64) and bits == 64: return inner, dt + # For signed widening cast, first cast to signed type to get sign extension + if typ == 'I' and inner_dt in (dtypes.int32, dtypes.int16, dtypes.int8): + signed_inner = b.cast(inner, inner_dt) # BITCAST to signed + return UOp(Ops.CAST, dt, (signed_inner,)), dt + return b.cast(inner, dt), dt + + case Unary(op, expr): + val, dt = transform_expr(expr, b, hint) + if op == '-': return UOp(Ops.NEG, dt, (val,)), dt + if op == '~': return UOp(Ops.XOR, dt, (val, b.const(-1, dt))), dt + if op == '!': return UOp(Ops.CMPEQ, dtypes.bool, (val, b.const(0, dt))), dtypes.bool + raise ValueError(f"Unknown unary op: {op}") + + case Binary(op, left, right): + l, l_dt = transform_expr(left, b, hint) + r, r_dt = transform_expr(right, b, l_dt if _is_float(l_dt) else hint) + # Use actual UOp dtype for arithmetic to avoid type mismatches + # The semantic dtype (l_dt/r_dt) may be signed but UOp is unsigned + result_dt = l.dtype if _is_float(l.dtype) else r.dtype if _is_float(r.dtype) else l.dtype + + binop_map = {'+': Ops.ADD, '-': Ops.SUB, '*': Ops.MUL, '/': Ops.FDIV, '&': Ops.AND, '|': Ops.OR, '^': Ops.XOR, + '<<': Ops.SHL, '==': Ops.CMPEQ, '!=': Ops.CMPNE, '<>': Ops.CMPNE, '<': Ops.CMPLT} + # >> is logical shift for unsigned, arithmetic shift for signed + if op == '>>': + if l_dt in (dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8): + signed_l = b.cast(l, l_dt) + shifted = UOp(Ops.SHR, l_dt, (signed_l, r)) + return b.cast(shifted, l.dtype), l_dt + return UOp(Ops.SHR, result_dt, (l, r)), result_dt + if op == '||': + one, zero = b.const(1, dtypes.uint32), b.const(0, dtypes.uint32) + inner = UOp(Ops.WHERE, dtypes.uint32, (r, one, zero)) + return UOp(Ops.WHERE, dtypes.uint32, (l, one, inner)), dtypes.uint32 + if op == '&&': + one, zero = b.const(1, dtypes.uint32), b.const(0, dtypes.uint32) + inner = UOp(Ops.WHERE, dtypes.uint32, (r, one, zero)) + return UOp(Ops.WHERE, dtypes.uint32, (l, inner, zero)), dtypes.uint32 + # For signed comparisons, use the semantic dtype (l_dt) for comparison + def _cmp_operands(): + if l_dt in (dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8): + return b.cast(l, l_dt), b.cast(r, l_dt) + return l, r + if op == '>': + cmp_l, cmp_r = _cmp_operands() + return UOp(Ops.CMPLT, dtypes.bool, (cmp_r, cmp_l)), dtypes.bool + if op == '>=': + cmp_l, cmp_r = _cmp_operands() + lt = UOp(Ops.CMPLT, dtypes.bool, (cmp_l, cmp_r)) + return UOp(Ops.XOR, dtypes.bool, (lt, b.const(True, dtypes.bool))), dtypes.bool + if op == '<=': + cmp_l, cmp_r = _cmp_operands() + lt = UOp(Ops.CMPLT, dtypes.bool, (cmp_r, cmp_l)) + return UOp(Ops.XOR, dtypes.bool, (lt, b.const(True, dtypes.bool))), dtypes.bool + if op == '**': + if l.op == Ops.CONST and l.arg == 2.0: + # For signed exponents, cast to signed first to get correct sign extension + if r_dt in (dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8): + r_signed = b.cast(r, r_dt) + exp = UOp(Ops.CAST, result_dt, (r_signed,)) + else: + exp = UOp(Ops.CAST, result_dt, (r,)) if r.dtype != result_dt else r + return UOp(Ops.EXP2, result_dt, (exp,)), result_dt + log_a = UOp(Ops.LOG2, result_dt, (l,)) + exp = UOp(Ops.CAST, result_dt, (r,)) if r.dtype != result_dt else r + return UOp(Ops.EXP2, result_dt, (UOp(Ops.MUL, result_dt, (exp, log_a)),)), result_dt + if op == '%': + # a % b = a - (a / b) * b (integer modulo) + div = UOp(Ops.IDIV, result_dt, (l, r)) + return UOp(Ops.SUB, result_dt, (l, UOp(Ops.MUL, result_dt, (div, r)))), result_dt + uop_op = binop_map.get(op) + if uop_op is None: raise ValueError(f"Unknown binary op: {op}") + out_dt = dtypes.bool if uop_op in (Ops.CMPLT, Ops.CMPEQ, Ops.CMPNE) else result_dt + # For signed < comparison, cast operands to signed type + if op == '<' and l_dt in (dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8): + cmp_l, cmp_r = b.cast(l, l_dt), b.cast(r, l_dt) + return UOp(Ops.CMPLT, dtypes.bool, (cmp_l, cmp_r)), dtypes.bool + return UOp(uop_op, out_dt, (l, r)), out_dt + + case Ternary(cond, t, f): + c, _ = transform_expr(cond, b) + tv, t_dt = transform_expr(t, b, hint) + fv, f_dt = transform_expr(f, b, t_dt) + return UOp(Ops.WHERE, t_dt, (c, tv, fv)), t_dt + + case Call(name, args): + return _transform_call(name, args, b, hint) + + case Pack(exprs): + if len(exprs) == 2: + hi, hi_dt = transform_expr(exprs[0], b) + lo, lo_dt = transform_expr(exprs[1], b) + if lo_dt.itemsize >= 4: + hi_ext = b.cast(hi, dtypes.uint64) + lo_ext = b.cast(lo, dtypes.uint64) + hi_shifted = UOp(Ops.SHL, dtypes.uint64, (hi_ext, b.const(32, dtypes.uint64))) + return UOp(Ops.OR, dtypes.uint64, (hi_shifted, lo_ext)), dtypes.uint64 + else: + hi_shifted = UOp(Ops.SHL, dtypes.uint32, (b.cast(hi, dtypes.uint32), b.const(16, dtypes.uint32))) + lo_masked = UOp(Ops.AND, dtypes.uint32, (b.cast(lo, dtypes.uint32), b.const(0xffff, dtypes.uint32))) + return UOp(Ops.OR, dtypes.uint32, (hi_shifted, lo_masked)), dtypes.uint32 + raise ValueError(f"Pack with {len(exprs)} elements not supported") + + raise ValueError(f"Cannot transform expression: {node}") + +def _transform_call(name: str, args: tuple, b: UOpBuilder, hint: DType) -> tuple[UOp, DType]: + """Transform function call to UOp.""" + def arg(i, h=None): return transform_expr(args[i], b, h) + + # Memory access + if name == 'MEM': + addr, _ = arg(0) + return addr, hint or dtypes.uint32 + + # Math functions + if name == 'fma' and len(args) == 3: + a, _ = arg(0, hint); bv, _ = arg(1, hint); c, dt = arg(2, hint) + return UOp(Ops.MULACC, dt, (a, bv, c)), dt + if name == 'trunc' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.TRUNC, dt, (v,)), dt + if name == 'floor' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.TRUNC, dt, (v,)), dt # TODO: proper floor + if name == 'sqrt' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.SQRT, dt, (v,)), dt + if name == 'abs' and len(args) == 1: + v, dt = arg(0, hint) + neg = UOp(Ops.NEG, dt, (v,)) + cond = UOp(Ops.CMPLT, dtypes.bool, (v, b.const(0, dt))) + return UOp(Ops.WHERE, dt, (cond, neg, v)), dt + if name == 'exp2' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.EXP2, dt, (v,)), dt + if name == 'log2' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.LOG2, dt, (v,)), dt + if name == 'sin' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.SIN, dt, (v,)), dt + if name == 'cos' and len(args) == 1: + v, dt = arg(0, hint) + pi_2 = b.const(1.5707963267948966, dt) + return UOp(Ops.SIN, dt, (UOp(Ops.ADD, dt, (v, pi_2)),)), dt + if name == 'rcp' and len(args) == 1: + v, dt = arg(0, hint); return UOp(Ops.RECIPROCAL, dt, (v,)), dt + if name == 'rsqrt' and len(args) == 1: + v, dt = arg(0, hint) + return UOp(Ops.RECIPROCAL, dt, (UOp(Ops.SQRT, dt, (v,)),)), dt + if name == 'min' and len(args) == 2: + a, dt = arg(0, hint); bv, _ = arg(1, hint) + # For signed types, compare with signed dtype + cmp_a, cmp_b = (b.cast(a, dt), b.cast(bv, dt)) if dt in (dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8) else (a, bv) + return UOp(Ops.WHERE, a.dtype, (UOp(Ops.CMPLT, dtypes.bool, (cmp_a, cmp_b)), a, bv)), dt + if name == 'max' and len(args) == 2: + a, dt = arg(0, hint); bv, _ = arg(1, hint) + cmp_a, cmp_b = (b.cast(a, dt), b.cast(bv, dt)) if dt in (dtypes.int32, dtypes.int64, dtypes.int16, dtypes.int8) else (a, bv) + return UOp(Ops.WHERE, a.dtype, (UOp(Ops.CMPLT, dtypes.bool, (cmp_b, cmp_a)), a, bv)), dt + if name == 'clamp' and len(args) == 3: + x, dt = arg(0, hint); lo, _ = arg(1, hint); hi, _ = arg(2, hint) + cond_lo = UOp(Ops.CMPLT, dtypes.bool, (x, lo)) + max_val = UOp(Ops.WHERE, dt, (cond_lo, lo, x)) + cond_hi = UOp(Ops.CMPLT, dtypes.bool, (hi, max_val)) + return UOp(Ops.WHERE, dt, (cond_hi, hi, max_val)), dt + if name == 'fract' and len(args) == 1: + v, dt = arg(0, hint) + truncated = UOp(Ops.TRUNC, dt, (v,)) + return UOp(Ops.SUB, dt, (v, truncated)), dt + + # NaN/Inf checks + if name == 'isNAN' and len(args) == 1: + v, dt = arg(0, hint) + return UOp(Ops.CMPNE, dtypes.bool, (v, v)), dtypes.bool + if name == 'isQuietNAN' and len(args) == 1: + v, dt = arg(0, hint) + return UOp(Ops.CMPNE, dtypes.bool, (v, v)), dtypes.bool + if name == 'isSignalNAN' and len(args) == 1: + return b.const(0, dtypes.bool), dtypes.bool + if name == 'cvtToQuietNAN' and len(args) == 1: + v, dt = arg(0, hint); return v, dt + if name == 'isINF' and len(args) == 1: + v, dt = arg(0, hint) + inf = b.const(float('inf'), dt) + neg_inf = b.const(float('-inf'), dt) + is_pos = UOp(Ops.CMPEQ, dtypes.bool, (v, inf)) + is_neg = UOp(Ops.CMPEQ, dtypes.bool, (v, neg_inf)) + return UOp(Ops.OR, dtypes.bool, (is_pos, is_neg)), dtypes.bool + + # Type conversions + cvt_map = { + 'u32_to_f32': (dtypes.float32, False), 'i32_to_f32': (dtypes.float32, False), + 'f32_to_u32': (dtypes.uint32, True), 'f32_to_i32': (dtypes.int32, False), + 'f16_to_f32': (dtypes.float32, False), 'f32_to_f16': (dtypes.float16, False), + 'f32_to_u8': (dtypes.uint8, False), 'f32_to_i8': (dtypes.int8, False), + 'f32_to_u16': (dtypes.uint16, False), 'f32_to_i16': (dtypes.int16, False), + 'v_cvt_u16_f32': (dtypes.uint16, False), 'v_cvt_i16_f32': (dtypes.int16, False), + 'f64_to_i32': (dtypes.int32, False), 'f64_to_u32': (dtypes.uint32, True), + 'i32_to_f64': (dtypes.float64, False), 'u32_to_f64': (dtypes.float64, False), + 'f64_to_f32': (dtypes.float32, False), 'f32_to_f64': (dtypes.float64, False), + 'u16_to_f16': (dtypes.float16, False), 'i16_to_f16': (dtypes.float16, False), + 'f16_to_u16': (dtypes.uint16, False), 'f16_to_i16': (dtypes.int16, False), + } + if name in cvt_map and len(args) == 1: + v, v_dt = arg(0) + dt, clamp_neg = cvt_map[name] + if clamp_neg: + zero = b.const(0.0, v_dt) + v = UOp(Ops.WHERE, v_dt, (UOp(Ops.CMPLT, dtypes.bool, (v, zero)), zero, v)) + return UOp(Ops.CAST, dt, (v,)), dt + + if name == 'f16_to_snorm' and len(args) == 1: + v, dt = arg(0) + clamped = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (v, b.const(-1.0, dt))), b.const(-1.0, dt), v)) + clamped = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (b.const(1.0, dt), clamped)), b.const(1.0, dt), clamped)) + scaled = UOp(Ops.MUL, dt, (clamped, b.const(32767.0, dt))) + return UOp(Ops.CAST, dtypes.int16, (scaled,)), dtypes.int16 + if name == 'f16_to_unorm' and len(args) == 1: + v, dt = arg(0) + clamped = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (v, b.const(0.0, dt))), b.const(0.0, dt), v)) + clamped = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (b.const(1.0, dt), clamped)), b.const(1.0, dt), clamped)) + scaled = UOp(Ops.MUL, dt, (clamped, b.const(65535.0, dt))) + return UOp(Ops.CAST, dtypes.uint16, (scaled,)), dtypes.uint16 + + # Sign/exponent/mantissa extraction + if name == 'sign' and len(args) == 1: + v, dt = arg(0, hint) + if dt == dtypes.float64: + bits = UOp(Ops.BITCAST, dtypes.uint64, (v,)) + sign = UOp(Ops.SHR, dtypes.uint64, (bits, b.const(63, dtypes.uint64))) + return UOp(Ops.AND, dtypes.uint32, (b.cast(sign, dtypes.uint32), b.const(1, dtypes.uint32))), dtypes.uint32 + elif dt == dtypes.float16: + bits = UOp(Ops.BITCAST, dtypes.uint16, (v,)) + sign = UOp(Ops.SHR, dtypes.uint16, (bits, b.const(15, dtypes.uint16))) + return UOp(Ops.AND, dtypes.uint32, (b.cast(sign, dtypes.uint32), b.const(1, dtypes.uint32))), dtypes.uint32 + else: + bits = UOp(Ops.BITCAST, dtypes.uint32, (v,)) + sign = UOp(Ops.SHR, dtypes.uint32, (bits, b.const(31, dtypes.uint32))) + return UOp(Ops.AND, dtypes.uint32, (sign, b.const(1, dtypes.uint32))), dtypes.uint32 + + if name == 'exponent' and len(args) == 1: + v, dt = arg(0, hint) + if dt == dtypes.float64: + bits = UOp(Ops.BITCAST, dtypes.uint64, (v,)) + exp = UOp(Ops.SHR, dtypes.uint64, (bits, b.const(52, dtypes.uint64))) + return UOp(Ops.AND, dtypes.uint32, (b.cast(exp, dtypes.uint32), b.const(0x7ff, dtypes.uint32))), dtypes.uint32 + elif dt == dtypes.float16: + bits = UOp(Ops.BITCAST, dtypes.uint16, (v,)) + exp = UOp(Ops.SHR, dtypes.uint16, (bits, b.const(10, dtypes.uint16))) + return UOp(Ops.AND, dtypes.uint32, (b.cast(exp, dtypes.uint32), b.const(0x1f, dtypes.uint32))), dtypes.uint32 + else: + bits = UOp(Ops.BITCAST, dtypes.uint32, (v,)) + exp = UOp(Ops.SHR, dtypes.uint32, (bits, b.const(23, dtypes.uint32))) + return UOp(Ops.AND, dtypes.uint32, (exp, b.const(0xff, dtypes.uint32))), dtypes.uint32 + + if name == 'mantissa' and len(args) == 1: + v, dt = arg(0, hint) + if dt == dtypes.float64: + bits = UOp(Ops.BITCAST, dtypes.uint64, (v,)) + return UOp(Ops.AND, dtypes.uint64, (bits, b.const(0xfffffffffffff, dtypes.uint64))), dtypes.uint64 + elif dt == dtypes.float16: + bits = UOp(Ops.BITCAST, dtypes.uint16, (v,)) + return UOp(Ops.AND, dtypes.uint32, (b.cast(bits, dtypes.uint32), b.const(0x3ff, dtypes.uint32))), dtypes.uint32 + else: + bits = UOp(Ops.BITCAST, dtypes.uint32, (v,)) + return UOp(Ops.AND, dtypes.uint32, (bits, b.const(0x7fffff, dtypes.uint32))), dtypes.uint32 + + if name == 'isEven' and len(args) == 1: + v, dt = arg(0, hint) + int_val = UOp(Ops.CAST, dtypes.int64, (v,)) + bit0 = UOp(Ops.AND, dtypes.int64, (int_val, b.const(1, dtypes.int64))) + return UOp(Ops.CMPEQ, dtypes.bool, (bit0, b.const(0, dtypes.int64))), dtypes.bool + + if name == 'signext' and len(args) == 1: + v, dt = arg(0) + return b.cast(v, dtypes.int64), dtypes.int64 + + if name == 'signext_from_bit' and len(args) == 2: + val, dt = arg(0, hint) + width, _ = arg(1) + one = b.const(1, dt) + width_minus_1 = UOp(Ops.SUB, dt, (b.cast(width, dt), one)) + sign_bit = UOp(Ops.SHL, dt, (one, width_minus_1)) + xored = UOp(Ops.XOR, dt, (val, sign_bit)) + result = UOp(Ops.SUB, dt, (xored, sign_bit)) + width_is_zero = UOp(Ops.CMPEQ, dtypes.bool, (width, b.const(0, width.dtype))) + return UOp(Ops.WHERE, dt, (width_is_zero, b.const(0, dt), result)), dt + + if name == 'ABSDIFF' and len(args) == 2: + a, _ = arg(0); bv, _ = arg(1) + a_gt_b = UOp(Ops.CMPLT, dtypes.bool, (bv, a)) + max_v = UOp(Ops.WHERE, dtypes.uint32, (a_gt_b, b.cast(a, dtypes.uint32), b.cast(bv, dtypes.uint32))) + min_v = UOp(Ops.WHERE, dtypes.uint32, (a_gt_b, b.cast(bv, dtypes.uint32), b.cast(a, dtypes.uint32))) + return UOp(Ops.SUB, dtypes.uint32, (max_v, min_v)), dtypes.uint32 + + if name == 'pow' and len(args) == 2: + base, base_dt = arg(0, hint); exp, _ = arg(1, hint) + result_dt = base_dt if _is_float(base_dt) else hint or dtypes.float32 + if base.op == Ops.CONST and base.arg == 2.0: + exp_uop = UOp(Ops.CAST, result_dt, (exp,)) if exp.dtype != result_dt else exp + return UOp(Ops.EXP2, result_dt, (exp_uop,)), result_dt + base_cast = UOp(Ops.CAST, result_dt, (base,)) if base.dtype != result_dt else base + exp_cast = UOp(Ops.CAST, result_dt, (exp,)) if exp.dtype != result_dt else exp + log_a = UOp(Ops.LOG2, result_dt, (base_cast,)) + return UOp(Ops.EXP2, result_dt, (UOp(Ops.MUL, result_dt, (exp_cast, log_a)),)), result_dt + + if name == 'LT_NEG_ZERO' and len(args) == 2: + a, dt = arg(0, hint); bv, _ = arg(1, hint) + if dt == dtypes.float64: + a_bits = UOp(Ops.BITCAST, dtypes.int64, (a,)) + b_bits = UOp(Ops.BITCAST, dtypes.int64, (bv,)) + elif dt == dtypes.float16: + a_bits = UOp(Ops.BITCAST, dtypes.int16, (a,)) + b_bits = UOp(Ops.BITCAST, dtypes.int16, (bv,)) + else: + a_bits = UOp(Ops.BITCAST, dtypes.int32, (a,)) + b_bits = UOp(Ops.BITCAST, dtypes.int32, (bv,)) + return UOp(Ops.CMPLT, dtypes.bool, (a_bits, b_bits)), dtypes.bool + + if name == 'GT_NEG_ZERO' and len(args) == 2: + a, dt = arg(0, hint); bv, _ = arg(1, hint) + if dt == dtypes.float64: + a_bits = UOp(Ops.BITCAST, dtypes.int64, (a,)) + b_bits = UOp(Ops.BITCAST, dtypes.int64, (bv,)) + elif dt == dtypes.float16: + a_bits = UOp(Ops.BITCAST, dtypes.int16, (a,)) + b_bits = UOp(Ops.BITCAST, dtypes.int16, (bv,)) + else: + a_bits = UOp(Ops.BITCAST, dtypes.int32, (a,)) + b_bits = UOp(Ops.BITCAST, dtypes.int32, (bv,)) + return UOp(Ops.CMPLT, dtypes.bool, (b_bits, a_bits)), dtypes.bool + + if name == 'SAT8' and len(args) == 1: + v, dt = arg(0, hint) + lo, hi = b.const(-128, dt), b.const(127, dt) + clamped_lo = UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (v, lo)), lo, v)) + return UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (hi, clamped_lo)), hi, clamped_lo)), dt + + # v_min/v_max functions + if name.startswith('v_min_') and len(args) == 2: + a, dt = arg(0, hint); bv, _ = arg(1, hint) + return UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (a, bv)), a, bv)), dt + if name.startswith('v_max_') and len(args) == 2: + a, dt = arg(0, hint); bv, _ = arg(1, hint) + return UOp(Ops.WHERE, dt, (UOp(Ops.CMPLT, dtypes.bool, (bv, a)), a, bv)), dt + + raise ValueError(f"Unknown function: {name}") + +# ═══════════════════════════════════════════════════════════════════════════════ +# STATEMENT TRANSFORMER +# ═══════════════════════════════════════════════════════════════════════════════ + +def _get_lhs_info(lhs, b: UOpBuilder) -> tuple[str, DType, int|None, int|None, str|None]: + """Extract assignment target info: (var_name, dtype, hi_bit, lo_bit, idx_var)""" + match lhs: + case Typed(Var(name), qdt): return name, _qdt(qdt), None, None, None + case Typed(Slice(Var(name), Const(hi, _), Const(lo, _)), qdt): return name, _qdt(qdt), int(hi), int(lo), None + case Typed(Index(Typed(Var(name), _), Var(idx)), _): return name, dtypes.uint64, None, None, idx + case Typed(Index(Var(name), Var(idx)), qdt): return name, _qdt(qdt), None, None, idx + case Slice(Typed(Var(name), _), Const(hi, _), Const(lo, _)): return name, dtypes.uint32, int(hi), int(lo), None + case Slice(Var(name), Const(hi, _), Const(lo, _)): return name, dtypes.uint32, int(hi), int(lo), None + case Index(Typed(Var(name), qdt), Var(idx)): return name, _qdt(qdt), None, None, idx + case Var(name): return name, dtypes.uint32, None, None, None + raise ValueError(f"Cannot parse LHS: {lhs}") + +def transform_stmt(stmt, b: UOpBuilder): + """Transform statement and update builder state.""" + match stmt: + case Declare(_, _): pass # Skip declarations + + case Assign(lhs, rhs): + var, dtype, hi, lo, idx_var = _get_lhs_info(lhs, b) + + # Bit index assignment: D0.u64[laneId] = expr + if idx_var is not None: + base = b.vars.get(var) + idx = b.vars.get(idx_var) + if base is None: raise ValueError(f"Unknown variable: {var}") + if idx is None: raise ValueError(f"Unknown index variable: {idx_var}") + cond, _ = transform_expr(rhs, b) + one = b.const(1, dtype) + bit_mask = UOp(Ops.SHL, dtype, (one, b.cast(idx, dtype))) + inv_mask = UOp(Ops.XOR, dtype, (bit_mask, b.const(-1, dtype))) + cleared = UOp(Ops.AND, dtype, (base, inv_mask)) + cond_ext = b.cast(cond, dtype) + cond_bit = UOp(Ops.SHL, dtype, (UOp(Ops.AND, dtype, (cond_ext, one)), b.cast(idx, dtype))) + result = UOp(Ops.OR, dtype, (cleared, cond_bit)) + b.vars[var] = result + if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): + b.outputs.append((var, result, dtype)) + return + + # Bit range assignment: D0[31:16].f16 = expr + if hi is not None and lo is not None: + if hi < lo: hi, lo = lo, hi + base = b.vars[var] if var in b.vars else b.const(0, dtypes.uint32) + rhs_uop, _ = transform_expr(rhs, b, dtype) + if _is_float(dtype): + if dtype == dtypes.float16: + rhs_bits = UOp(Ops.BITCAST, dtypes.uint16, (rhs_uop,)) + rhs_bits = b.cast(rhs_bits, dtypes.uint32) + else: + rhs_bits = UOp(Ops.BITCAST, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64, (rhs_uop,)) + else: + rhs_bits = b.cast(rhs_uop, dtypes.uint32) + width = hi - lo + 1 + mask = (1 << width) - 1 + shifted_val = UOp(Ops.SHL, dtypes.uint32, (UOp(Ops.AND, dtypes.uint32, (rhs_bits, b.const(mask, dtypes.uint32))), b.const(lo, dtypes.uint32))) + inv_mask = ~(mask << lo) & 0xffffffff + cleared = UOp(Ops.AND, dtypes.uint32, (b.cast(base, dtypes.uint32), b.const(inv_mask, dtypes.uint32))) + result = UOp(Ops.OR, dtypes.uint32, (cleared, shifted_val)) + b.vars[var] = result + if var in ('D0', 'D1', 'SCC', 'VCC'): + b.outputs = [(n, u, d) for n, u, d in b.outputs if n != var] + b.outputs.append((var, result, dtypes.uint32)) + return + + # Simple assignment + rhs_uop, _ = transform_expr(rhs, b, dtype) + b.vars[var] = rhs_uop + if dtype.itemsize == 8 and var in ('D0', 'D1', 'S0', 'S1'): + b.vars[var + '_64'] = rhs_uop + if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): + b.outputs.append((var, rhs_uop, dtype)) + + case If(branches): + _transform_if(branches, b) + + case For(var, start, end, body): + _transform_for(var, start, end, body, b) + +def _transform_if(branches: tuple, b: UOpBuilder): + """Transform if/elsif/else to nested WHERE expressions.""" + # Parse all conditions + parsed_branches = [] + for cond, body in branches: + cond_uop = transform_expr(cond, b)[0] if cond else None + parsed_branches.append((cond_uop, body)) + + # Collect all assigned variables + assigned_vars = set() + for _, body in parsed_branches: + for stmt in body: + if isinstance(stmt, Assign): + var, _, _, _, _ = _get_lhs_info(stmt.lhs, b) + assigned_vars.add(var) + + # Build nested WHERE for each variable + for var in assigned_vars: + # Determine dtype from first assignment + dtype = dtypes.uint32 + for _, body in parsed_branches: + for stmt in body: + if isinstance(stmt, Assign): + v, dt, _, _, _ = _get_lhs_info(stmt.lhs, b) + if v == var: dtype = dt; break + + curr_val = b.vars[var] if var in b.vars else b.const(0, dtype) + result = curr_val + + # Process branches in reverse order + for cond_uop, body in reversed(parsed_branches): + branch_val = None + for stmt in body: + if isinstance(stmt, Assign): + v, dt, _, _, _ = _get_lhs_info(stmt.lhs, b) + if v == var: + branch_val, _ = transform_expr(stmt.rhs, b, dt) + dtype = dt + break + + if branch_val is not None: + if cond_uop is None: + result = branch_val + else: + if result.dtype != branch_val.dtype: + result = b.cast(result, branch_val.dtype) + result = UOp(Ops.WHERE, branch_val.dtype, (cond_uop, branch_val, result)) + + # Store result + if _is_float(dtype): + result_bits = UOp(Ops.BITCAST, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64, (result,)) + b.vars[var] = result_bits + if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): + b.outputs = [(n, u, d) for n, u, d in b.outputs if n != var] + b.outputs.append((var, result_bits, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64)) + else: + b.vars[var] = result + if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): + b.outputs = [(n, u, d) for n, u, d in b.outputs if n != var] + b.outputs.append((var, result, dtype)) + +def _transform_for(var: str, start, end, body: tuple, b: UOpBuilder): + """Unroll for loop and transform body.""" + start_val = start.value if isinstance(start, Const) else int(transform_expr(start, b)[0].arg) + end_val = end.value if isinstance(end, Const) else int(transform_expr(end, b)[0].arg) + + for loop_val in range(int(start_val), int(end_val) + 1): + # Set loop variable + b.vars[var] = b.const(loop_val, dtypes.uint32) + + for stmt in body: + if isinstance(stmt, If): + _transform_if(stmt.branches, b) + elif isinstance(stmt, Assign): + transform_stmt(stmt, b) + # ═══════════════════════════════════════════════════════════════════════════════ # HELPERS # ═══════════════════════════════════════════════════════════════════════════════ def _float_to_bits(val: float, dtype: DType) -> int: - import math if dtype == dtypes.float32: return struct.unpack(' 0 else 0xfc00 # f16 +/-inf - if abs(val) > 65504.0: return 0x7c00 if val > 0 else 0xfc00 # overflow to inf - if abs(val) < 6.103515625e-05 and val != 0: return 0x0000 if val > 0 else 0x8000 # underflow to zero + if math.isnan(val): return 0x7e00 + if math.isinf(val): return 0x7c00 if val > 0 else 0xfc00 + if abs(val) > 65504.0: return 0x7c00 if val > 0 else 0xfc00 + if abs(val) < 6.103515625e-05 and val != 0: return 0x0000 if val > 0 else 0x8000 return struct.unpack(' int: # ═══════════════════════════════════════════════════════════════════════════════ def _compile_pseudocode(pseudocode: str) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp]]: - """Compile pseudocode to UOp graph. Returns (sink, output_info, input_vars).""" - builder = UOpBuilder() - lines = [line.split('//')[0].strip().rstrip(';') for line in pseudocode.strip().split('\n')] - lines = [l for l in lines if l] + """Compile pseudocode to UOp graph using qcode parser.""" + ast = parse(pseudocode) + b = UOpBuilder() - i = 0 - while i < len(lines): - line = lines[i] + for stmt in ast: + transform_stmt(stmt, b) - # Skip declare statements - if line.startswith('declare '): - i += 1 - continue - - # Handle for loops: for i in START : END do ... endfor - if line.startswith('for ') and ' do' in line: - # Parse: for VAR in START : END do - m = re.match(r'for\s+(\w+)\s+in\s+(.+?)\s*:\s*(.+?)\s+do', line) - if m: - loop_var, start_expr, end_expr = m.group(1), m.group(2).strip(), m.group(3).strip() - # Parse start and end as constants - start_val = int(start_expr.replace("'", "").rstrip('U')) - end_val = int(end_expr.replace("'", "").rstrip('U')) - - # Collect loop body until endfor - i += 1 - loop_body = [] - depth = 1 - while i < len(lines) and depth > 0: - if lines[i].startswith('for ') and ' do' in lines[i]: - depth += 1 - elif lines[i] == 'endfor': - depth -= 1 - if depth > 0: - loop_body.append(lines[i]) - i += 1 - - # Unroll the loop - process loop body for each iteration - def expand_line(line, var, val): - """Substitute loop variable and evaluate bracket expressions.""" - expanded = re.sub(rf'\b{var}\b', str(val), line) - def eval_bracket(m): - try: return '[' + str(eval(m.group(1))) + ']' - except: return m.group(0) - return re.sub(r'\[([^\]]+)\]', eval_bracket, expanded) - - for loop_val in range(start_val, end_val + 1): - # Process loop body with index tracking - body_idx = 0 - while body_idx < len(loop_body): - body_line = loop_body[body_idx] - expanded = expand_line(body_line, loop_var, loop_val) - - if expanded.startswith('if ') and ' then' in expanded: - # Handle if inside loop - cond_str = expanded[3:expanded.index(' then')].strip() - cond_uop, _ = builder.parse_expr(cond_str) - # Collect if body until endif - body_idx += 1 - if_stmts = [] - while body_idx < len(loop_body) and loop_body[body_idx] != 'endif': - stmt = expand_line(loop_body[body_idx], loop_var, loop_val) - if_stmts.append(stmt) - body_idx += 1 - body_idx += 1 # Skip endif - # Process if body with condition - for stmt in if_stmts: - if '=' not in stmt: - continue - lhs, rhs = stmt.split('=', 1) - lhs, rhs = lhs.strip(), rhs.strip() - var, dtype = builder.parse_type(lhs) - curr_val = builder.vars.get(var) - if curr_val is None: - curr_val = builder.const(0, dtype) - new_val, _ = builder.parse_expr(rhs, dtype) - if new_val.dtype != curr_val.dtype: - curr_val = builder.cast(curr_val, new_val.dtype) - result = UOp(Ops.WHERE, new_val.dtype, (cond_uop, new_val, curr_val)) - builder.vars[var] = result - if var in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - builder.outputs = [(n, u, d) for n, u, d in builder.outputs if n != var] - builder.outputs.append((var, result, dtype)) - elif '=' in expanded and not expanded.startswith('if '): - builder.parse_stmt(expanded) - body_idx += 1 - else: - body_idx += 1 - continue - - i += 1 - continue - - # Handle if/elsif/else/endif blocks - if line.startswith('if ') and ' then' in line: - # Collect all branches: [(condition_str, body_stmts), ...] - # Last entry may have condition_str=None for 'else' branch - branches = [] - cond_str = line[3:line.index(' then')].strip() - i += 1 - current_body = [] - - while i < len(lines) and lines[i] != 'endif': - if lines[i].startswith('elsif ') and ' then' in lines[i]: - # Save current branch - branches.append((cond_str, current_body)) - # Start new elsif branch - cond_str = lines[i][6:lines[i].index(' then')].strip() - current_body = [] - i += 1 - elif lines[i] == 'else': - # Save current branch - branches.append((cond_str, current_body)) - cond_str = None # else has no condition - current_body = [] - i += 1 - else: - current_body.append(lines[i]) - i += 1 - - # Save final branch - branches.append((cond_str, current_body)) - - # Parse all conditions - parsed_branches = [] - for cond, body in branches: - cond_uop = builder.parse_expr(cond)[0] if cond else None - parsed_branches.append((cond_uop, body)) - - # Helper to extract assignment info from a statement - def parse_assignment(stmt): - if '=' not in stmt or stmt.startswith('if ') or stmt.startswith('for ') or stmt.startswith('elsif '): - return None - if '+=' in stmt or '-=' in stmt: - is_sub = '-=' in stmt - lhs, rhs = stmt.split('-=' if is_sub else '+=', 1) - return ('compound', lhs.strip(), rhs.strip(), is_sub) - lhs, rhs = stmt.split('=', 1) - return ('simple', lhs.strip(), rhs.strip(), None) - - # Collect all variables that are assigned in any branch - assigned_vars = set() - for _, body in parsed_branches: - for stmt in body: - info = parse_assignment(stmt) - if info: - var, _ = builder.parse_type(info[1]) - assigned_vars.add(var) - - # For each assigned variable, build a nested WHERE chain - for var in assigned_vars: - # Get current value as the default (used if no branch assigns) - var_name, dtype = builder.parse_type(var) - curr_val = builder.vars.get(var_name) - if curr_val is None: - curr_val = builder.const(0, dtype) - - # Build nested WHERE from last branch to first - # result = cond1 ? val1 : (cond2 ? val2 : (cond3 ? val3 : else_val)) - result = curr_val # default if no else branch - - # Process branches in reverse order - for cond_uop, body in reversed(parsed_branches): - # Find assignment to this var in this branch - branch_val = None - for stmt in body: - info = parse_assignment(stmt) - if info: - stmt_var, stmt_dtype = builder.parse_type(info[1]) - if stmt_var == var_name: - if info[0] == 'compound': - # += or -= - is_sub = info[3] - inc_uop, _ = builder.parse_expr(info[2], stmt_dtype) - base = builder.vars.get(var_name, builder.const(0, stmt_dtype)) - if _is_float(stmt_dtype) and base.dtype != stmt_dtype: - base = UOp(Ops.BITCAST, stmt_dtype, (base,)) - op = Ops.SUB if is_sub else Ops.ADD - branch_val = UOp(op, stmt_dtype, (base, inc_uop)) - else: - # simple assignment - branch_val, _ = builder.parse_expr(info[2], stmt_dtype) - dtype = stmt_dtype - break - - if branch_val is not None: - if cond_uop is None: - # This is the else branch - it becomes the new default - result = branch_val - else: - # Conditional branch - wrap in WHERE - if result.dtype != branch_val.dtype: - result = builder.cast(result, branch_val.dtype) - result = UOp(Ops.WHERE, branch_val.dtype, (cond_uop, branch_val, result)) - - # Store result - if _is_float(dtype): - result_bits = UOp(Ops.BITCAST, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64, (result,)) - builder.vars[var_name] = result_bits - if var_name in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - builder.outputs = [(n, u, d) for n, u, d in builder.outputs if n != var_name] - builder.outputs.append((var_name, result_bits, dtypes.uint32 if dtype == dtypes.float32 else dtypes.uint64)) - else: - builder.vars[var_name] = result - if var_name in ('D0', 'D1', 'SCC', 'VCC', 'EXEC', 'PC'): - builder.outputs = [(n, u, d) for n, u, d in builder.outputs if n != var_name] - builder.outputs.append((var_name, result, dtype)) - - i += 1 # Skip endif - continue - - # Regular statement - builder.parse_stmt(line) - i += 1 - - sink = builder.build_sink() - output_info = [(name, dtype) for name, _, dtype in builder.outputs] - return sink, output_info, builder.input_vars + sink = b.build_sink() + output_info = [(name, dtype) for name, _, dtype in b.outputs] + return sink, output_info, b.input_vars def _make_uop_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp]): """Create a runtime function that evaluates the UOp graph via simplify.""" def fn(s0, s1, s2, d0, scc, vcc, laneId, exec_mask, literal, VGPR, src0_idx=0, vdst_idx=0, pc=None): - # Build substitution map: DEFINE_VAR -> CONST - # SIMM16 is passed via literal for SOPK instructions - may be unsigned 16-bit, convert to signed if literal is not None: simm16 = literal if -32768 <= literal <= 32767 else (literal - 65536 if literal < 65536 else 0) else: @@ -1183,23 +745,21 @@ def _make_uop_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: di input_vars['SIMM32']: UOp.const(dtypes.uint32, literal if literal is not None else 0), input_vars['PC']: UOp.const(dtypes.uint64, pc if pc is not None else 0), } - - # Substitute and simplify all at once + simplified_sink = sink.substitute(dvars).simplify() assert simplified_sink.op == Ops.SINK, f"expected SINK, got {simplified_sink.op}" - + result = {} for i, (name, dtype) in enumerate(output_info): out_uop = simplified_sink.src[i] assert out_uop.op == Ops.CONST, f"simplify did not produce CONST for {name}, got {out_uop.op}" val = out_uop.arg - # Convert to bits if _is_float(dtype): bits = _float_to_bits(val, dtype) else: bits = int(val) & (0xffffffff if dtype in (dtypes.uint32, dtypes.int32) else 0xffffffffffffffff) result[name] = bits - + return result return fn @@ -1207,21 +767,17 @@ def _make_uop_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: di # PUBLIC API # ═══════════════════════════════════════════════════════════════════════════════ -# Ops that ucode.py supports - only include ops that compile successfully -# NOTE: Float comparisons using <=, >=, or NOT are excluded due to NaN handling issues SUPPORTED_OPS: set[str] = { # VOP (153 ops) 'V_ADD3_U32', 'V_ADD_CO_CI_U32', 'V_ADD_CO_U32', 'V_ADD_F16', 'V_ADD_F32', 'V_ADD_F64', 'V_ADD_LSHL_U32', 'V_ADD_NC_I16', 'V_ADD_NC_I32', 'V_ADD_NC_U16', 'V_ADD_NC_U32', 'V_ALIGNBIT_B32', 'V_ALIGNBYTE_B32', 'V_AND_B16', 'V_AND_B32', 'V_AND_OR_B32', 'V_ASHRREV_I16', 'V_ASHRREV_I32', 'V_ASHRREV_I64', 'V_BFE_I32', 'V_BFE_U32', 'V_BFI_B32', 'V_BFM_B32', - 'V_CNDMASK_B16', 'V_CNDMASK_B32', 'V_COS_F16', 'V_COS_F32', 'V_CUBEID_F32', 'V_CUBESC_F32', 'V_CVT_F16_F32', 'V_CVT_F32_F16', 'V_CVT_F32_I32', 'V_CVT_F32_U32', 'V_CVT_F32_UBYTE0', 'V_CVT_F32_UBYTE1', 'V_CVT_F32_UBYTE2', 'V_CVT_F32_UBYTE3', 'V_CVT_FLOOR_I32_F32', 'V_CVT_I32_F32', 'V_CVT_I32_I16', 'V_CVT_NEAREST_I32_F32', 'V_CVT_PK_I16_F32', 'V_CVT_PK_U16_F32', 'V_CVT_PK_U8_F32', 'V_CVT_U32_F32', 'V_CVT_U32_U16', 'V_DOT2_F16_F16', 'V_DOT2_F32_F16', 'V_DOT2ACC_F32_F16', - 'V_FMA_DX9_ZERO_F32', 'V_FMA_F16', 'V_FMA_F32', 'V_FMA_F64', 'V_FMAAK_F16', 'V_FMAAK_F32', 'V_FMAC_DX9_ZERO_F32', 'V_FMAC_F16', 'V_FMAC_F32', 'V_FMAMK_F16', 'V_FMAMK_F32', 'V_FREXP_EXP_I16_F16', 'V_FREXP_EXP_I32_F32', 'V_FREXP_EXP_I32_F64', @@ -1239,23 +795,20 @@ SUPPORTED_OPS: set[str] = { 'V_PK_SUB_I16', 'V_PK_SUB_U16', 'V_RNDNE_F16', 'V_RNDNE_F32', 'V_RNDNE_F64', 'V_SAD_U8', 'V_SAD_U16', 'V_SAD_U32', 'V_SIN_F16', 'V_SIN_F32', 'V_SQRT_F16', 'V_SQRT_F32', 'V_SQRT_F64', - # Conversions 'V_CVT_F32_F64', 'V_CVT_F64_F32', 'V_CVT_F64_I32', 'V_CVT_F64_U32', 'V_CVT_I32_F64', 'V_CVT_U32_F64', 'V_CVT_NORM_I16_F16', 'V_CVT_NORM_U16_F16', 'V_CVT_PK_NORM_I16_F16', 'V_CVT_PK_NORM_U16_F16', 'V_CVT_PK_RTZ_F16_F32', 'V_SUB_CO_CI_U32', 'V_SUB_CO_U32', 'V_SUB_F16', 'V_SUB_F32', 'V_SUB_NC_I16', 'V_SUB_NC_I32', 'V_SUB_NC_U16', 'V_SUB_NC_U32', 'V_SUBREV_CO_CI_U32', 'V_SUBREV_CO_U32', 'V_SUBREV_F16', 'V_SUBREV_F32', 'V_SUBREV_NC_U32', 'V_SWAP_B16', 'V_SWAP_B32', 'V_TRUNC_F16', 'V_TRUNC_F32', 'V_TRUNC_F64', 'V_WRITELANE_B32', 'V_XAD_U32', 'V_XNOR_B32', 'V_XOR3_B32', 'V_XOR_B16', 'V_XOR_B32', - # Additional VOP ops (newly supported) 'V_CVT_F16_I16', 'V_CVT_F16_U16', 'V_CVT_I16_F16', 'V_CVT_U16_F16', 'V_EXP_F16', 'V_EXP_F32', 'V_LDEXP_F16', 'V_LDEXP_F32', 'V_LDEXP_F64', 'V_CUBEMA_F32', 'V_CUBETC_F32', 'V_SAT_PK_U8_I16', - # min3/max3/minmax/maxmin ops 'V_MAX3_I16', 'V_MAX3_I32', 'V_MAX3_U16', 'V_MAX3_U32', 'V_MIN3_I16', 'V_MIN3_I32', 'V_MIN3_U16', 'V_MIN3_U32', 'V_MAXMIN_I32', 'V_MAXMIN_U32', 'V_MINMAX_I32', 'V_MINMAX_U32', - # VOPC - integer and float comparisons (112 ops) + # VOPC (112 ops) 'V_CMP_EQ_F16', 'V_CMP_EQ_F32', 'V_CMP_EQ_F64', 'V_CMP_EQ_I16', 'V_CMP_EQ_I32', 'V_CMP_EQ_I64', 'V_CMP_EQ_U16', 'V_CMP_EQ_U32', 'V_CMP_EQ_U64', 'V_CMP_F_F16', 'V_CMP_F_F32', 'V_CMP_F_F64', 'V_CMP_F_I32', 'V_CMP_F_I64', 'V_CMP_F_U32', 'V_CMP_F_U64', 'V_CMP_GE_F16', 'V_CMP_GE_F32', 'V_CMP_GE_F64', 'V_CMP_GE_I16', 'V_CMP_GE_I32', 'V_CMP_GE_I64', 'V_CMP_GE_U16', 'V_CMP_GE_U32', 'V_CMP_GE_U64', @@ -1270,7 +823,7 @@ SUPPORTED_OPS: set[str] = { 'V_CMP_NLT_F16', 'V_CMP_NLT_F32', 'V_CMP_NLT_F64', 'V_CMP_O_F16', 'V_CMP_O_F32', 'V_CMP_O_F64', 'V_CMP_T_F16', 'V_CMP_T_F32', 'V_CMP_T_F64', 'V_CMP_T_I32', 'V_CMP_T_I64', 'V_CMP_T_U32', 'V_CMP_T_U64', 'V_CMP_U_F16', 'V_CMP_U_F32', 'V_CMP_U_F64', - # VOPCX - compare and write exec (112 ops) + # VOPCX (112 ops) 'V_CMPX_EQ_F16', 'V_CMPX_EQ_F32', 'V_CMPX_EQ_F64', 'V_CMPX_EQ_I16', 'V_CMPX_EQ_I32', 'V_CMPX_EQ_I64', 'V_CMPX_EQ_U16', 'V_CMPX_EQ_U32', 'V_CMPX_EQ_U64', 'V_CMPX_F_F16', 'V_CMPX_F_F32', 'V_CMPX_F_F64', 'V_CMPX_F_I32', 'V_CMPX_F_I64', 'V_CMPX_F_U32', 'V_CMPX_F_U64', 'V_CMPX_GE_F16', 'V_CMPX_GE_F32', 'V_CMPX_GE_F64', 'V_CMPX_GE_I16', 'V_CMPX_GE_I32', 'V_CMPX_GE_I64', 'V_CMPX_GE_U16', 'V_CMPX_GE_U32', 'V_CMPX_GE_U64', @@ -1308,11 +861,9 @@ SUPPORTED_OPS: set[str] = { 'S_SENDMSG_RTN_B32', 'S_SENDMSG_RTN_B64', 'S_SETPC_B64', 'S_SEXT_I32_I16', 'S_SEXT_I32_I8', 'S_SUB_F16', 'S_SUB_F32', 'S_SUB_I32', 'S_SUB_U32', 'S_SUBB_U32', 'S_TRUNC_F16', 'S_TRUNC_F32', 'S_VERSION', - # Additional SOP ops (newly supported) 'S_BITCMP0_B32', 'S_BITCMP0_B64', 'S_BITCMP1_B32', 'S_BITCMP1_B64', 'S_MAX_F16', 'S_MAX_F32', 'S_MIN_F16', 'S_MIN_F32', 'S_WAITCNT_EXPCNT', 'S_WAITCNT_LGKMCNT', 'S_WAITCNT_VMCNT', 'S_WAITCNT_VSCNT', - # Branch/control flow ops 'S_BRANCH', 'S_CALL_B64', 'S_CBRANCH_EXECNZ', 'S_CBRANCH_EXECZ', 'S_CBRANCH_SCC0', 'S_CBRANCH_SCC1', 'S_CBRANCH_VCCNZ', 'S_CBRANCH_VCCZ', 'S_GETPC_B64', 'S_XNOR_B32', 'S_XNOR_B64', 'S_XNOR_SAVEEXEC_B32', 'S_XNOR_SAVEEXEC_B64', @@ -1334,7 +885,6 @@ SUPPORTED_OPS: set[str] = { @functools.cache def compile_uop(cls_name: str, op_name: str, pseudocode: str): """Compile pseudocode to UOp-based function. Returns None if unsupported.""" - if op_name not in SUPPORTED_OPS: - return None + if op_name not in SUPPORTED_OPS: return None sink, output_info, input_vars = _compile_pseudocode(pseudocode) return _make_uop_fn(sink, output_info, input_vars)