fix parser bugs

This commit is contained in:
George Hotz
2026-01-08 05:02:55 -08:00
parent 544a877960
commit 894230d0a9

View File

@@ -51,9 +51,7 @@ def _floor(x):
def _cvt(src_dt: DType, dst_dt: DType):
"""Create a conversion function that asserts input type and casts to output type."""
def convert(x: UOp) -> UOp:
# Allow: exact match, void (unresolved), or uint32 (unresolved array access/slice)
# TODO: should only allow exact match
assert x.dtype == src_dt or x.dtype == dtypes.void or x.dtype == dtypes.uint32, f"Expected {src_dt}, got {x.dtype}"
assert x.dtype == src_dt or x.dtype == dtypes.void, f"Expected {src_dt}, got {x.dtype}"
return UOp(Ops.CAST, dst_dt, (x,))
return convert
@@ -290,8 +288,10 @@ def expr(s: str) -> UOp:
slice_dtype = dtypes.uint32 # default for dynamic slices
return UOp(Ops.CUSTOMI, slice_dtype, (expr(b), hi_expr, lo_expr))
idx = expr(n)
# Single bit index returns uint32 (1 bit result fits in u32)
return UOp(Ops.CUSTOMI, dtypes.uint32, (expr(b), idx, idx))
base = expr(b)
# For array element access, use scalar type of the array; for bit index, use uint32
elem_dtype = base.dtype.scalar() if base.dtype != dtypes.void and base.dtype.count > 1 else dtypes.uint32
return UOp(Ops.CUSTOMI, elem_dtype, (base, idx, idx))
# Bitcast: expr.type
if '.' in s:
for i in range(len(s)-1, 0, -1):
@@ -379,9 +379,9 @@ def stmt(line: str) -> Stmt|None:
return expr(line)
raise ValueError(f"Cannot parse statement: {line}")
def parse(code: str) -> tuple[Stmt, ...]:
def parse(code: str, _toplevel: bool = True) -> tuple[Stmt, ...]:
global _var_dtypes
_var_dtypes = {} # reset for each parse
if _toplevel: _var_dtypes = {} # only reset at top level, preserve for recursive calls
lines = [l.split('//')[0].strip() for l in code.strip().split('\n') if l.split('//')[0].strip()]
# Join continuation lines (unbalanced parens) - but not for control flow or lambdas
joined, j = [], 0
@@ -417,7 +417,7 @@ def parse(code: str) -> tuple[Stmt, ...]:
try:
body = expr(body_text)
except ValueError:
body = parse(body_text)
body = parse(body_text, _toplevel=False)
stmts.append(Lambda(name, params, body)); continue
if ln[:4] == 'for ' and ' do' in ln and (m := re.match(r'for\s+(\w+)\s+in\s+(.+?)\s*:\s*(.+?)\s+do', ln)):
i, body, d = i+1, [], 1
@@ -427,7 +427,7 @@ def parse(code: str) -> tuple[Stmt, ...]:
elif line_i == 'endfor': d -= 1
if d > 0: body.append(lines[i])
i += 1
stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body)))); continue
stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body), _toplevel=False))); continue
if ln[:3] == 'if ':
cond = ln[3:ln.index(' then')] if ' then' in ln else ln[3:]
br, body, i, depth = [], [], i+1, 1
@@ -439,12 +439,12 @@ def parse(code: str) -> tuple[Stmt, ...]:
if depth > 0: body.append(lines[i])
elif depth == 1 and line_i[:6] == 'elsif ':
cond_end = line_i.index(' then') if ' then' in line_i else len(line_i)
br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:cond_end], []
br.append((expr(cond), parse('\n'.join(body), _toplevel=False))); cond, body = line_i[6:cond_end], []
elif depth == 1 and line_i == 'else':
br.append((expr(cond), parse('\n'.join(body)))); cond, body = None, []
br.append((expr(cond), parse('\n'.join(body), _toplevel=False))); cond, body = None, []
else: body.append(lines[i])
i += 1
br.append((expr(cond) if cond else None, parse('\n'.join(body)))); stmts.append(If(tuple(br))); continue
br.append((expr(cond) if cond else None, parse('\n'.join(body), _toplevel=False))); stmts.append(If(tuple(br))); continue
if ln == 'else' or ln[:6] == 'elsif ': raise ValueError(f"Unexpected {ln.split()[0]} without matching if")
s = stmt(ln)
if s is not None: stmts.append(s)