# Mirror of https://github.com/tinygrad/tinygrad.git
# Synced 2026-01-09 23:18:04 -05:00
# 436 lines, 22 KiB, Python
# Minimal parser for AMD GPU pseudocode -> UOps
|
|
from __future__ import annotations
|
|
import re
|
|
from dataclasses import dataclass
|
|
from tinygrad.dtype import dtypes, DType
|
|
from tinygrad.uop import Ops
|
|
from tinygrad.uop.ops import UOp
|
|
# DType lookup table for AMD pseudocode type suffixes
|
|
from tinygrad.dtype import INVERSE_DTYPES_DICT
|
|
# Pseudocode type suffix (e.g. the "u32" in S0.u32) -> DType.
# Widths without a ready-made dtypes.* entry get an ad-hoc DType built with DType.new; within this table the first
# DType.new argument is 4 for the small float formats, 5 for signed ints and 6 for unsigned / raw-bit types.
_QDTYPES: dict[str, DType] = {
  # floating point
  'f64': dtypes.float64, 'f32': dtypes.float32, 'f16': dtypes.float16, 'bf16': dtypes.bfloat16,
  'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None),
  'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None),
  'fp4': DType.new(4, 4, "fp4", None), 'i4': DType.new(5, 4, "i4", None),
  # standard integer widths
  'u64': dtypes.uint64, 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u8': dtypes.uint8,
  'i64': dtypes.int64, 'i32': dtypes.int32, 'i16': dtypes.int16, 'i8': dtypes.int8,
  # raw bit vectors (bN); 64/32/16/8 bits reuse the unsigned integer dtypes
  'b1201': DType.new(6, 1201, "b1201", None), 'b1024': DType.new(6, 1024, "b1024", None),
  'b512': DType.new(6, 512, "b512", None),
  'b192': DType.new(6, 192, "b192", None), 'b128': DType.new(6, 128, "b128", None),
  'b65': DType.new(6, 65, "b65", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32,
  'b23': DType.new(6, 23, "b23", None), 'b16': dtypes.uint16, 'b8': dtypes.uint8,
  'b4': DType.new(6, 4, "b4", None),
  # odd-width unsigned integers
  'u1201': DType.new(6, 1201, "u1201", None), 'u65': DType.new(6, 65, "u65", None),
  'u24': DType.new(6, 24, "u24", None), 'u23': DType.new(6, 23, "u23", None),
  'u6': DType.new(6, 6, "u6", None), 'u4': DType.new(6, 4, "u4", None),
  # i1/u1 are used for carry/overflow bits in 64-bit multiply-add ops (e.g., { D1.u1, D0.u64 } = 65-bit result)
  'u3': DType.new(6, 3, "u3", None), 'u1': DType.new(6, 1, "u1", None),
  'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None),
  'i1': DType.new(5, 1, "i1", None),
  # width-less suffixes default to 32 bits
  'u': dtypes.uint32, 'i': dtypes.int32, 'f': dtypes.float32,
}

# Register custom dtypes for repr (names that are already registered keep their existing entry)
for k, v in _QDTYPES.items():
  if v.name in INVERSE_DTYPES_DICT: continue
  INVERSE_DTYPES_DICT[v.name] = k
|
|
|
|
# String to Ops mapping
_BINOPS: dict[str, Ops] = {
  # arithmetic
  '+': Ops.ADD, '-': Ops.SUB, '*': Ops.MUL, '/': Ops.FDIV, '%': Ops.MOD, '**': Ops.POW,
  # bitwise and shifts
  '&': Ops.AND, '|': Ops.OR, '^': Ops.XOR, '<<': Ops.SHL, '>>': Ops.SHR,
  # equality ('<>' is an alternative spelling of '!=')
  '==': Ops.CMPEQ, '!=': Ops.CMPNE, '<>': Ops.CMPNE,
  # ordering: '>' and '>=' reuse CMPLT/CMPLE; expr() swaps their operands and tags the UOp 'flipped'
  '<': Ops.CMPLT, '>': Ops.CMPLT, '<=': Ops.CMPLE, '>=': Ops.CMPLE,
  # logical operators share the bitwise ops
  '||': Ops.OR, '&&': Ops.AND,
}

# NOTE: ~ is bitwise NOT (XOR with -1), ! is logical NOT (compare == 0). NEG is arithmetic negation, not suitable here.
_UNOPS: dict[str, Ops] = {
  '-': Ops.NEG,
  '~': Ops.XOR,
  '!': Ops.CMPEQ,
}

# Direct function -> UOp mappings (parsed directly, not as CUSTOM)
_DIRECT_OPS: dict[str, Ops] = {
  'trunc': Ops.TRUNC,
  'sqrt': Ops.SQRT,
  'exp2': Ops.EXP2,
  'log2': Ops.LOG2,
  'sin': Ops.SIN,
  'rcp': Ops.RECIPROCAL,
}
|
|
|
|
def _typed_const(src: UOp, val) -> UOp:
  """Build a constant `val` whose dtype follows `src`.

  When `src` has a concrete dtype this is a plain UOp.const of that dtype. When `src` is still void, the
  constant is emitted as a void CONST that keeps `src` as its source, so the dtype can be filled in later.
  """
  if src.dtype == dtypes.void:
    return UOp(Ops.CONST, dtypes.void, (src,), val)
  return UOp.const(src.dtype, val)
|
|
|
|
# Function return type inference for CUSTOM ops
# Functions whose result is a bool
_BOOL_FNS = {
  'isNAN', 'isINF', 'isDENORM', 'isQuietNAN', 'isSignalNAN', 'isEven', 'LT_NEG_ZERO', 'GT_NEG_ZERO',
}
# Functions whose result has the dtype of their first argument
_PASSTHRU_FNS = {
  'abs', 'floor', 'fract', 'sqrt', 'sin', 'cos', 'trunc', 'fma', 'clamp', 'min', 'max', 'ldexp',
  'cvtToQuietNAN', 'pow', 'rcp', 'rsqrt', 'exp2', 'log2', 'mantissa',
  'v_min_f16', 'v_min_f32', 'v_max_f16', 'v_max_f32',
  'v_min3_f16', 'v_min3_f32', 'v_max3_f16', 'v_max3_f32',
}
# Functions that always produce uint32
_U32_FNS = {
  'sign', 'exponent', 'ABSDIFF', 'SAT8', 'BYTE_PERMUTE', 'count_ones', 'countbits', 'reverse_bits',
  'u8_to_u32', 'u4_to_u32', 'u32_to_u16', 's_ff1_i32_b32', 's_ff1_i32_b64', 'v_sad_u8', 'v_msad_u8',
  'v_min_u16', 'v_min_u32', 'v_max_u16', 'v_max_u32',
  'v_min3_u16', 'v_min3_u32', 'v_max3_u16', 'v_max3_u32',
}
# Functions that always produce int32
_I32_FNS = {
  'v_min_i16', 'v_min_i32', 'v_max_i16', 'v_max_i32',
  'v_min3_i16', 'v_min3_i32', 'v_max3_i16', 'v_max3_i32',
}
_CVT_FNS = {  # conversion functions: name -> output dtype
  # from f32
  'f32_to_i32': dtypes.int32, 'f32_to_u32': dtypes.uint32, 'f32_to_f16': dtypes.float16, 'f32_to_f64': dtypes.float64,
  'f32_to_i8': dtypes.int8, 'f32_to_u8': dtypes.uint8, 'f32_to_i16': dtypes.int16, 'f32_to_u16': dtypes.uint16,
  # from f64 / f16
  'f64_to_i32': dtypes.int32, 'f64_to_u32': dtypes.uint32, 'f64_to_f32': dtypes.float32,
  'f16_to_f32': dtypes.float32, 'f16_to_i16': dtypes.int16, 'f16_to_u16': dtypes.uint16,
  # from integers
  'i32_to_f32': dtypes.float32, 'i32_to_f64': dtypes.float64, 'i32_to_i16': dtypes.int16,
  'u32_to_f32': dtypes.float32, 'u32_to_f64': dtypes.float64,
  'i16_to_f16': dtypes.float16, 'u16_to_f16': dtypes.float16,
  # bfloat16
  'bf16_to_f32': dtypes.float32, 'f32_to_bf16': dtypes.bfloat16,
  # instruction-named conversions
  'v_cvt_u16_f32': dtypes.uint16, 'v_cvt_i16_f32': dtypes.int16,
  # normalized conversions
  'f16_to_snorm': dtypes.int16, 'f16_to_unorm': dtypes.uint16, 'f32_to_snorm': dtypes.int16, 'f32_to_unorm': dtypes.uint16,
  # sign extension
  'signext': dtypes.int64, 'signext_from_bit': dtypes.int64,
}
|
|
|
|
def _infer_fn_dtype(name: str, srcs: tuple[UOp, ...]) -> DType:
  """Pick the result dtype of the pseudocode function `name` called with `srcs`.

  The name is checked against the fixed-result tables first (bool, pass-through of the first argument, uint32,
  int32, conversions, trig_preop_result). Any other function takes the dtype of its first argument that is not
  void; with no such argument the result is void.
  """
  if name in _BOOL_FNS:
    return dtypes.bool
  if name in _PASSTHRU_FNS:
    # a void first argument passes through as void as well
    return srcs[0].dtype if srcs else dtypes.void
  if name in _U32_FNS:
    return dtypes.uint32
  if name in _I32_FNS:
    return dtypes.int32
  if name in _CVT_FNS:
    return _CVT_FNS[name]
  if name == 'trig_preop_result':
    return dtypes.float64
  # Default: inherit from first non-void source, or void
  return next((src.dtype for src in srcs if src.dtype != dtypes.void), dtypes.void)
|
|
|
|
# Statement types (control flow, not expressions)
|
|
@dataclass(frozen=True)
|
|
class Assign: lhs: UOp; rhs: UOp
|
|
@dataclass(frozen=True)
|
|
class Declare: name: str; dtype: DType
|
|
@dataclass(frozen=True)
|
|
class If: branches: tuple[tuple[UOp|None, tuple[Stmt, ...]], ...]
|
|
@dataclass(frozen=True)
|
|
class For: var: str; start: UOp; end: UOp; body: tuple[Stmt, ...]
|
|
@dataclass(frozen=True)
|
|
class Lambda: name: str; params: tuple[str, ...]; body: tuple[Stmt, ...]|UOp
|
|
@dataclass(frozen=True)
|
|
class Break: pass
|
|
@dataclass(frozen=True)
|
|
class Return: value: UOp
|
|
Stmt = Assign|Declare|If|For|Lambda|Break|Return
|
|
|
|
# Parse context for tracking variable dtypes (module-level, set during parse())
# Maps a variable name to the dtype recorded for it: stmt() writes entries for `declare` lines and for bare
# variable assignments, expr() reads them when it meets a plain variable name, and parse() rebinds it to {}.
_var_dtypes: dict[str, DType] = {}
|
|
|
|
def _match(s, i, o, c):
|
|
d = 1
|
|
for j in range(i+1, len(s)):
|
|
if s[j] == o: d += 1
|
|
elif s[j] == c: d -= 1
|
|
if d == 0: return j
|
|
return -1
|
|
|
|
def _split(s):
|
|
r, d, l = [], 0, 0
|
|
for i, c in enumerate(s):
|
|
if c in '([{': d += 1
|
|
elif c in ')]}': d -= 1
|
|
elif c == ',' and d == 0: r.append(s[l:i].strip()); l = i+1
|
|
if s[l:].strip(): r.append(s[l:].strip())
|
|
return r
|
|
|
|
def _fop(s, ops):
|
|
d = b = 0
|
|
for i in range(len(s)-1, -1, -1):
|
|
c = s[i]
|
|
if c == ')': d += 1
|
|
elif c == '(': d -= 1
|
|
elif c == ']': b += 1
|
|
elif c == '[': b -= 1
|
|
elif d == 0 and b == 0:
|
|
for op in sorted(ops, key=len, reverse=True):
|
|
if s[i:i+len(op)] == op:
|
|
if op in ('<', '>') and (i+1 < len(s) and s[i+1] in '<>=' or i > 0 and s[i-1] in '<>='): continue
|
|
if op == '*' and (i+1 < len(s) and s[i+1] == '*' or i > 0 and s[i-1] == '*'): continue
|
|
if op == '-' and (not s[:i].rstrip() or s[:i].rstrip()[-1] in '+-*/(<>=&|^,'): continue
|
|
return i
|
|
return -1
|
|
|
|
def _get_dtype(name: str) -> DType | None:
  """Look up a type suffix in _QDTYPES, ignoring case; returns None when the suffix is unknown."""
  return _QDTYPES.get(name.lower())
|
|
|
|
def _ternary_split(s: str) -> tuple[int, int]|None:
  """Locate the outermost `cond ? a : b` in `s`.

  Returns (index of the leftmost top-level '?', index of its matching ':'), or None when there is no such pair.
  Only characters outside () and [] count. Every further '?' met after the first one opens a nested ternary and
  consumes one ':' before the outer ':' is accepted, which makes `?:` right-associative:
  `a ? b : c ? d : e` splits into `a`, `b` and `c ? d : e`.
  """
  d = b = nested = 0
  q = -1
  for i, c in enumerate(s):
    if c == '(': d += 1
    elif c == ')': d -= 1
    elif c == '[': b += 1
    elif c == ']': b -= 1
    elif d != 0 or b != 0: continue
    elif c == '?':
      if q < 0:
        if i > 0: q = i  # a '?' at position 0 has no condition in front of it
      else: nested += 1
    elif c == ':' and q >= 0:
      if nested == 0: return q, i
      nested -= 1
  return None

def expr(s: str) -> UOp:
  """Parse one pseudocode expression into a UOp tree.

  The forms below are tried in this order; the first one that applies wins:
    1. fully parenthesized expression
    2. `{ hi, lo }` pack                      -> CAT
    3. typed cast `32'U(x)`                   -> CAST
    4. typed constants `32'-5I`, `32'1.5`, `32'0xff`
    5. function call                          -> direct op, small expansion, or CUSTOM(name)
    6. `MEM[addr]` / `MEM[addr].type`         -> CUSTOM('MEM') / BITCAST
    7. ternary `c ? a : b`                    -> WHERE (right-associative)
    8. binary operators, lowest precedence first
    9. unary `-`, `~`, `!`
    10. slice / index `x[hi:lo]`, `x[start +: width]`, `x[i]` -> CUSTOMI
    11. `x.type`                              -> BITCAST
    12. variable name                         -> DEFINE_VAR (dtype from _var_dtypes, else void)
    13. numeric literal                       -> CONST

  Raises ValueError when the text is empty or matches none of the forms. A typed cast/constant whose width has
  no entry in _QDTYPES raises KeyError from the table lookup.
  """
  s = s.strip().rstrip(';')
  # drop a trailing '.', except when it follows a digit (keeps literals such as "1.")
  if s.endswith('.') and not (len(s) > 1 and s[-2].isdigit()): s = s[:-1]
  s = s.strip()
  if not s: raise ValueError("Empty expression")
  if s == '+INF': s = 'INF'

  # Parentheses: only when the first '(' closes at the very end
  if s[0] == '(' and (e := _match(s, 0, '(', ')')) == len(s)-1: return expr(s[1:e])

  # Pack -> CAT: { hi, lo } concatenates to larger type
  if s[0] == '{' and s[-1] == '}':
    parts = tuple(expr(a) for a in _split(s[1:-1]))
    # Infer combined bitwidth from parts (e.g., {u32, u32} -> u64)
    total_bits = sum(p.dtype.bitsize for p in parts if p.dtype != dtypes.void)
    cat_dtype = dtypes.uint64 if total_bits > 32 else dtypes.uint32 if total_bits > 0 else dtypes.void
    return UOp(Ops.CAT, cat_dtype, parts)

  # Typed cast: 32'U(expr) - value conversion (vs .type which is bit reinterpretation)
  if m := re.match(r"^(\d+)'([IUFB])\(", s):
    if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1:
      cast_dtype = _QDTYPES[f"{m[2].lower()}{m[1]}"]
      assert cast_dtype != dtypes.void, "CAST target type should not be void"
      return UOp(Ops.CAST, cast_dtype, (expr(s[m.end():e]),))

  # Typed constant: 32'-5I (kind letter defaults to I), 32'1.5 (float), 32'0xff (unsigned)
  if m := re.match(r"^(\d+)'(-?\d+)([IUFB])?$", s):
    return UOp(Ops.CONST, _QDTYPES[f"{(m[3] or 'I').lower()}{m[1]}"], arg=int(m[2]))
  if m := re.match(r"^(\d+)'(-?[\d.]+)$", s):
    return UOp(Ops.CONST, _QDTYPES[f"f{m[1]}"], arg=float(m[2]))
  if m := re.match(r"^(\d+)'(0x[0-9a-fA-F]+)$", s):
    return UOp(Ops.CONST, _QDTYPES[f"u{m[1]}"], arg=int(m[2], 16))

  # Function call -> direct UOp or CUSTOM (only when the call's ')' is the last character)
  if m := re.match(r"^([A-Za-z_]\w*)\(", s):
    if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1:
      a = _split(s[m.end():e])
      srcs = tuple(expr(x) for x in a) if a != [''] else ()
      name = m[1]
      # Direct UOp mappings for functions
      if name in _DIRECT_OPS: return UOp(_DIRECT_OPS[name], srcs[0].dtype, srcs)
      if name == 'fma': return UOp(Ops.MULACC, srcs[2].dtype, (srcs[0], srcs[1], srcs[2]))
      # x != x
      if name == 'isNAN': return UOp(Ops.CMPNE, dtypes.bool, (srcs[0], srcs[0]))
      # 1 / sqrt(x)
      if name == 'rsqrt': return UOp(Ops.RECIPROCAL, srcs[0].dtype, (UOp(Ops.SQRT, srcs[0].dtype, (srcs[0],)),))
      if name == 'clamp':
        # c = x < lo ? lo : x; result = hi < c ? hi : c
        x, lo, hi = srcs[0], srcs[1], srcs[2]
        c = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, lo)), lo, x))
        return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (hi, c)), hi, c))
      if name == 'abs':
        # x < 0 ? -x : x
        x = srcs[0]
        return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, _typed_const(x, 0))), UOp(Ops.NEG, x.dtype, (x,)), x))
      if name == 'cos':
        # sin(x + pi/2)
        x = srcs[0]
        return UOp(Ops.SIN, x.dtype, (UOp(Ops.ADD, x.dtype, (x, _typed_const(x, 1.5707963267948966))),))
      if name == 'floor':
        # x < trunc(x) ? trunc(x) - 1 : trunc(x)
        x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
        return UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
      if name == 'fract':
        # x - floor(x), with floor expanded as above
        x, trunc = srcs[0], UOp(Ops.TRUNC, srcs[0].dtype, (srcs[0],))
        floor = UOp(Ops.WHERE, x.dtype, (UOp(Ops.CMPLT, dtypes.bool, (x, trunc)), UOp(Ops.SUB, x.dtype, (trunc, _typed_const(x, 1))), trunc))
        return UOp(Ops.SUB, x.dtype, (x, floor))
      output_dtype = _infer_fn_dtype(name, srcs)
      return UOp(Ops.CUSTOM, output_dtype, srcs, arg=name)

  # MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST
  if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1:
    r, b = s[e+1:], UOp(Ops.CUSTOM, dtypes.void, (expr(s[4:e]),), arg='MEM')
    if not r: return b
    if r[:1] == '.' and (dt := _get_dtype(r[1:])):
      assert dt != dtypes.void, "BITCAST target type should not be void"
      return UOp(Ops.BITCAST, dt, (b,))
    # anything else after the ']' falls through to the generic rules below

  # Ternary: cond ? t : f -> WHERE
  if (tern := _ternary_split(s)) is not None:
    q, colon = tern
    gate, lhs, rhs = expr(s[:q]), expr(s[q+1:colon]), expr(s[colon+1:])
    # Infer output dtype from lhs/rhs (prefer non-void)
    out_dtype = lhs.dtype if lhs.dtype != dtypes.void else rhs.dtype
    # Gate can be bool, void (unresolved), or integer (C-style truthiness: 0=false, non-zero=true)
    assert gate.dtype == dtypes.void or gate.dtype == dtypes.bool or dtypes.is_int(gate.dtype), \
      f"gate on WHERE must be bool or int, got {gate.dtype}"
    return UOp(Ops.WHERE, out_dtype, (gate, lhs, rhs))

  # Binary ops, from lowest to highest precedence; _fop returns the rightmost match of each group
  for ops in [('||',),('&&',),('|',),('^',),('&',),('==','!=','<>'),('<=','>=','<','>'),('<<','>>'),
              ('+','-'),('*','/','%'),('**',)]:
    if (p := _fop(s, ops)) > 0:
      op = next(o for o in sorted(ops, key=len, reverse=True) if s[p:p+len(o)] == o)
      l, r = s[:p].strip(), s[p+len(op):].strip()
      if l and r:
        lhs, rhs = expr(l), expr(r)
        # '>' and '>=' are emitted as CMPLT / CMPLE with the operands swapped
        flipped = op in ('>', '>=')
        if flipped: lhs, rhs = rhs, lhs
        tag = 'flipped' if flipped else ('<>' if op == '<>' else None)
        uop_op = _BINOPS[op]
        output_dtype = lhs.dtype if lhs.dtype != dtypes.void else rhs.dtype
        if uop_op in {Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE, Ops.CMPLT}: output_dtype = dtypes.bool
        return UOp(uop_op, output_dtype, (lhs, rhs), tag=tag)

  # Unary ops: - (negate), ~ (bitwise NOT), ! (logical NOT)
  if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='):
    src = expr(s[1:])
    # ! is logical NOT (compare to 0), returns bool; - and ~ preserve dtype
    out_dtype = dtypes.bool if s[0] == '!' else src.dtype
    return UOp(_UNOPS[s[0]], out_dtype, (src,))

  # Slice/Index -> CUSTOMI
  if '[' in s and s[-1] == ']':
    # walk back from the final ']' to the '[' that opens it
    d = 0
    for i in range(len(s)-1, -1, -1):
      if s[i] == ']': d += 1
      elif s[i] == '[': d -= 1
      if d == 0 and s[i] == '[': break
    b, n = s[:i], s[i+1:-1]
    if '+:' in n:  # Verilog [start +: width]
      st, w = expr(n.split('+:', 1)[0]), expr(n.split('+:', 1)[1])
      # hi = start + width - 1; use int32 for index computations
      idx_dt = st.dtype if st.dtype != dtypes.void else dtypes.int32
      hi = UOp(Ops.SUB, idx_dt, (UOp(Ops.ADD, idx_dt, (st, w)), UOp(Ops.CONST, dtypes.int32, arg=1)))
      # Infer slice dtype from width (if constant)
      slice_dtype = dtypes.uint64 if (w.op == Ops.CONST and w.arg > 32) else dtypes.uint32
      # NOTE: CUSTOMI is used for bit slicing; SHRINK would be for tensor operations
      return UOp(Ops.CUSTOMI, slice_dtype, (expr(b), hi, st))
    if ':' in n and '?' not in n:
      # [hi:lo] - split on the first ':' that is not nested in brackets
      d = 0
      for j, c in enumerate(n):
        if c in '([{': d += 1
        elif c in ')]}': d -= 1
        elif c == ':' and d == 0:
          hi_expr, lo_expr = expr(n[:j]), expr(n[j+1:])
          # Infer slice dtype from constant bounds (hi - lo + 1 bits)
          if hi_expr.op == Ops.CONST and lo_expr.op == Ops.CONST:
            width = abs(int(hi_expr.arg) - int(lo_expr.arg)) + 1
            slice_dtype = dtypes.uint64 if width > 32 else dtypes.uint32
          else:
            slice_dtype = dtypes.uint32  # default for dynamic slices
          return UOp(Ops.CUSTOMI, slice_dtype, (expr(b), hi_expr, lo_expr))
    idx = expr(n)
    # Single bit index returns uint32 (1 bit result fits in u32)
    return UOp(Ops.CUSTOMI, dtypes.uint32, (expr(b), idx, idx))

  # Bitcast: expr.type (rightmost '.' whose suffix is a known type)
  if '.' in s:
    for i in range(len(s)-1, 0, -1):
      if s[i] == '.' and (dt := _get_dtype(s[i+1:])):
        assert dt != dtypes.void, f"BITCAST target type should not be void: {s}"
        return UOp(Ops.BITCAST, dt, (expr(s[:i]),))

  # Variable
  if s[:5] == 'eval ': return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None))
  if re.match(r'^[A-Za-z_][\w.]*$', s):
    var_dtype = _var_dtypes.get(s, dtypes.void)
    return UOp(Ops.DEFINE_VAR, var_dtype, arg=(s, None, None))

  # Numeric literal
  # NOTE: hex constants are unsigned (uint32) even without U suffix
  try:
    if s[:2].lower() == '0x':
      m = re.match(r'0[xX]([0-9a-fA-F]+)([UuLl]*)$', s)
      if m:
        val, suf = int(m[1], 16), m[2].lower()
        if 'll' in suf: return UOp(Ops.CONST, dtypes.uint64 if 'u' in suf else dtypes.int64, arg=val)
        return UOp(Ops.CONST, dtypes.uint32, arg=val)
    suffix = re.search(r'([UuLlFf]+)$', s)
    suf = suffix[1].lower() if suffix else ''
    c = re.sub(r'[FfLlUu]+$', '', s)
    if '.' in c or 'e' in c.lower() or 'f' in suf: return UOp(Ops.CONST, dtypes.float32, arg=float(c))
    if 'u' in suf: return UOp(Ops.CONST, dtypes.uint64 if 'll' in suf else dtypes.uint32, arg=int(c))
    if 'll' in suf: return UOp(Ops.CONST, dtypes.int64, arg=int(c))
    return UOp(Ops.CONST, dtypes.int32, arg=int(c))
  except ValueError: pass  # not a number either: report it as unparseable below
  raise ValueError(f"Cannot parse expression: {s}")
|
|
|
|
def _assign_eq(text: str) -> int:
  """Index of the leftmost plain assignment '=' in `text` (not part of ==, <=, >=, !=), or -1 if there is none."""
  for i in range(1, len(text) - 1):
    if text[i] == '=' and text[i-1] not in '!<>=' and text[i+1] != '=': return i
  return -1

def stmt(line: str) -> Stmt|UOp|None:
  """Parse a single non-control-flow pseudocode line.

  Returns
    - None for a blank/comment-only line and for a `declare` whose type is not of the form N'I / N'U / N'F / N'B
    - Break / Return / Declare / Assign for the matching statement forms (compound assignments such as `a += b`
      become Assign(a, a + b); chained assignments `a = b = c` become Assign(a, Assign(b, c)))
    - the UOp produced by expr() for a bare function call such as `nop()`

  Side effect: `declare` lines and assignments to a so far untyped bare variable record the dtype in _var_dtypes.
  Raises ValueError when the line matches none of the forms (or from expr() for an unparseable operand).
  """
  # NOTE: variable dtypes are resolved in ucode.py via INPUT_VARS (SCC=uint32, ADDR=uint64, etc.)
  line = line.split('//')[0].strip().rstrip(';').rstrip('.')
  if not line: return None
  if line == 'break': return Break()
  if line[:7] == 'return ': return Return(expr(line[7:]))
  # `eval ...` lines are kept verbatim as the name of a void variable assigned to '_eval'
  if line[:5] == 'eval ': return Assign(UOp(Ops.DEFINE_VAR, dtypes.void, arg=('_eval', None, None)), UOp(Ops.DEFINE_VAR, dtypes.void, arg=(line, None, None)))

  if line[:8] == 'declare ' and ':' in line:
    n, t = line[8:].split(':', 1)
    t = t.strip()
    vec_count = int(m[1]) if (m := re.search(r'\[(\d+)\]$', t)) else 1
    t = t.split('[')[0]  # strip array suffix like [64]
    if m := re.match(r"^(\d+)'([IUFB])$", t):
      dt = _QDTYPES[f"{m[2].lower()}{m[1]}"]
      final_dt = dt.vec(vec_count) if vec_count > 1 else dt
      _var_dtypes[n.strip()] = final_dt  # track dtype for subsequent expr() calls
      return Declare(n.strip(), final_dt)
    return None  # unsupported declare type

  # Compound assignment: a OP= b  ->  a = a OP b
  for op, uop in [('+=', Ops.ADD), ('-=', Ops.SUB), ('|=', Ops.OR), ('&=', Ops.AND), ('^=', Ops.XOR), ('<<=', Ops.SHL), ('>>=', Ops.SHR)]:
    if op in line:
      l, r = line.split(op, 1)
      lhs, rhs = expr(l), expr(r)
      # Infer result dtype from operands (dtype resolution happens in ucode._stmt)
      result_dtype = lhs.dtype if lhs.dtype != dtypes.void else rhs.dtype
      return Assign(lhs, UOp(uop, result_dtype, (lhs, rhs)))

  if '=' in line and not any(line[:k] == p for k, p in [(3,'if '),(6,'elsif '),(4,'for ')]):
    # Find leftmost assignment = (not ==, <=, >=, !=) for chained assignment support
    eq = _assign_eq(line)
    if eq > 0:
      rhs = line[eq+1:].strip()
      # Check if RHS contains another assignment = (not ==, <=, >=, !=)
      if _assign_eq(rhs) != -1:
        rhs_parsed = stmt(rhs)
        if isinstance(rhs_parsed, Assign):
          lhs = expr(line[:eq])
          return Assign(lhs, rhs_parsed)
      lhs, rhs_expr = expr(line[:eq]), expr(rhs)
      # Track dtype for bare variable assignments (e.g., tmp = S0.u32)
      if lhs.op == Ops.DEFINE_VAR and lhs.dtype == dtypes.void and rhs_expr.dtype != dtypes.void:
        _var_dtypes[lhs.arg[0]] = rhs_expr.dtype
        lhs = UOp(Ops.DEFINE_VAR, rhs_expr.dtype, arg=lhs.arg)
      return Assign(lhs, rhs_expr)

  # Bare function call (e.g., nop())
  if re.match(r'\w+\([^)]*\)$', line):
    return expr(line)
  raise ValueError(f"Cannot parse statement: {line}")
|
|
|
|
def parse(code: str) -> tuple[Stmt, ...]:
  """Parse a block of pseudocode into a tuple of statements.

  Steps:
    1. strip `//` comments and blank lines
    2. join continuation lines (a line with more '(' than ')' absorbs the following lines), except across
       control-flow keywords and lambda definitions
    3. walk the lines: lambdas, `for ... do ... endfor` and `if ... [elsif ...] [else] ... endif` collect their
       body lines and parse them recursively; every other line goes through stmt()

  Raises ValueError for an `else` / `elsif` without a matching `if`, and whatever stmt() / expr() raise.
  """
  global _var_dtypes
  # NOTE(review): this also runs for the recursive parse() calls on if/for/lambda bodies, so dtypes recorded
  # before a nested block are dropped inside it and after it - confirm that this is intended.
  _var_dtypes = {}  # reset for each parse
  lines = [l.split('//')[0].strip() for l in code.strip().split('\n') if l.split('//')[0].strip()]

  # Join continuation lines (unbalanced parens) - but not for control flow or lambdas
  joined, j = [], 0
  while j < len(lines):
    ln = lines[j]
    # Don't join lambda lines - they have their own multiline handling
    if '= lambda(' not in ln:
      while ln.count('(') > ln.count(')') and j + 1 < len(lines):
        next_ln = lines[j + 1]
        # Don't join if next line is control flow or looks like a new statement
        if next_ln.startswith(('if ', 'for ', 'elsif ')) or next_ln in ('else', 'endif', 'endfor') or \
           '= lambda(' in next_ln: break
        j += 1
        ln += ' ' + next_ln
    joined.append(ln)
    j += 1
  lines = joined

  stmts, i = [], 0
  while i < len(lines):
    ln = lines[i].rstrip(';')

    # Lambda: NAME = lambda(params) ( body );
    if '= lambda(' in ln and (m := re.match(r'(\w+)\s*=\s*lambda\(([^)]*)\)\s*\(', ln)):
      name = m[1]
      params = tuple(p.strip() for p in m[2].split(',')) if m[2].strip() else ()
      first = ln[m.end():]
      # A lambda closed on its own first line already lost the ';' of its ');' to the rstrip above, so the
      # ');' check further down cannot see it: drop the remaining closing ')' here instead.
      if lines[i].rstrip().endswith(');') and first.endswith(')'): first = first[:-1]
      # Collect lambda body until closing );
      body_lines = [first]
      i += 1
      while i < len(lines) and not lines[i-1].rstrip().endswith(');'):
        body_lines.append(lines[i])
        i += 1
      body_text = '\n'.join(body_lines).strip()
      if body_text.endswith(');'): body_text = body_text[:-2]
      # Try to parse as expression first, then as statements
      try:
        body = expr(body_text)
      except ValueError:
        body = parse(body_text)
      stmts.append(Lambda(name, params, body))
      continue

    # for VAR in START : END do ... endfor
    if ln[:4] == 'for ' and ' do' in ln and (m := re.match(r'for\s+(\w+)\s+in\s+(.+?)\s*:\s*(.+?)\s+do', ln)):
      i, body, d = i+1, [], 1
      # d counts open for-loops so that a nested endfor does not end this one
      while i < len(lines) and d > 0:
        line_i = lines[i].rstrip(';').rstrip('.')
        if line_i[:4] == 'for ' and ' do' in line_i: d += 1
        elif line_i == 'endfor': d -= 1
        if d > 0: body.append(lines[i])
        i += 1
      stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body))))
      continue

    # if COND then ... [elsif COND then ...] [else ...] endif
    if ln[:3] == 'if ':
      cond = ln[3:ln.index(' then')] if ' then' in ln else ln[3:]
      br, body, i, depth = [], [], i+1, 1
      while i < len(lines) and depth > 0:
        line_i = lines[i].rstrip(';').rstrip('.')
        if line_i[:3] == 'if ':
          # nested if: keep it (and everything up to its endif) as part of the current branch body
          depth += 1
          body.append(lines[i])
        elif line_i == 'endif':
          depth -= 1
          if depth > 0: body.append(lines[i])
        elif depth == 1 and line_i[:6] == 'elsif ':
          # close the current branch and start the next one with the elsif condition
          cond_end = line_i.index(' then') if ' then' in line_i else len(line_i)
          br.append((expr(cond), parse('\n'.join(body))))
          cond, body = line_i[6:cond_end], []
        elif depth == 1 and line_i == 'else':
          br.append((expr(cond), parse('\n'.join(body))))
          cond, body = None, []
        else: body.append(lines[i])
        i += 1
      # last branch: its condition is None when it is the else branch
      br.append((expr(cond) if cond else None, parse('\n'.join(body))))
      stmts.append(If(tuple(br)))
      continue

    if ln == 'else' or ln[:6] == 'elsif ': raise ValueError(f"Unexpected {ln.split()[0]} without matching if")
    s = stmt(ln)
    if s is not None: stmts.append(s)
    i += 1
  return tuple(stmts)
|