mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix parser bugs
This commit is contained in:
@@ -51,9 +51,7 @@ def _floor(x):
|
||||
def _cvt(src_dt: DType, dst_dt: DType):
|
||||
"""Create a conversion function that asserts input type and casts to output type."""
|
||||
def convert(x: UOp) -> UOp:
|
||||
# Allow: exact match, void (unresolved), or uint32 (unresolved array access/slice)
|
||||
# TODO: should only allow exact match
|
||||
assert x.dtype == src_dt or x.dtype == dtypes.void or x.dtype == dtypes.uint32, f"Expected {src_dt}, got {x.dtype}"
|
||||
assert x.dtype == src_dt or x.dtype == dtypes.void, f"Expected {src_dt}, got {x.dtype}"
|
||||
return UOp(Ops.CAST, dst_dt, (x,))
|
||||
return convert
|
||||
|
||||
@@ -290,8 +288,10 @@ def expr(s: str) -> UOp:
|
||||
slice_dtype = dtypes.uint32 # default for dynamic slices
|
||||
return UOp(Ops.CUSTOMI, slice_dtype, (expr(b), hi_expr, lo_expr))
|
||||
idx = expr(n)
|
||||
# Single bit index returns uint32 (1 bit result fits in u32)
|
||||
return UOp(Ops.CUSTOMI, dtypes.uint32, (expr(b), idx, idx))
|
||||
base = expr(b)
|
||||
# For array element access, use scalar type of the array; for bit index, use uint32
|
||||
elem_dtype = base.dtype.scalar() if base.dtype != dtypes.void and base.dtype.count > 1 else dtypes.uint32
|
||||
return UOp(Ops.CUSTOMI, elem_dtype, (base, idx, idx))
|
||||
# Bitcast: expr.type
|
||||
if '.' in s:
|
||||
for i in range(len(s)-1, 0, -1):
|
||||
@@ -379,9 +379,9 @@ def stmt(line: str) -> Stmt|None:
|
||||
return expr(line)
|
||||
raise ValueError(f"Cannot parse statement: {line}")
|
||||
|
||||
def parse(code: str) -> tuple[Stmt, ...]:
|
||||
def parse(code: str, _toplevel: bool = True) -> tuple[Stmt, ...]:
|
||||
global _var_dtypes
|
||||
_var_dtypes = {} # reset for each parse
|
||||
if _toplevel: _var_dtypes = {} # only reset at top level, preserve for recursive calls
|
||||
lines = [l.split('//')[0].strip() for l in code.strip().split('\n') if l.split('//')[0].strip()]
|
||||
# Join continuation lines (unbalanced parens) - but not for control flow or lambdas
|
||||
joined, j = [], 0
|
||||
@@ -417,7 +417,7 @@ def parse(code: str) -> tuple[Stmt, ...]:
|
||||
try:
|
||||
body = expr(body_text)
|
||||
except ValueError:
|
||||
body = parse(body_text)
|
||||
body = parse(body_text, _toplevel=False)
|
||||
stmts.append(Lambda(name, params, body)); continue
|
||||
if ln[:4] == 'for ' and ' do' in ln and (m := re.match(r'for\s+(\w+)\s+in\s+(.+?)\s*:\s*(.+?)\s+do', ln)):
|
||||
i, body, d = i+1, [], 1
|
||||
@@ -427,7 +427,7 @@ def parse(code: str) -> tuple[Stmt, ...]:
|
||||
elif line_i == 'endfor': d -= 1
|
||||
if d > 0: body.append(lines[i])
|
||||
i += 1
|
||||
stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body)))); continue
|
||||
stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body), _toplevel=False))); continue
|
||||
if ln[:3] == 'if ':
|
||||
cond = ln[3:ln.index(' then')] if ' then' in ln else ln[3:]
|
||||
br, body, i, depth = [], [], i+1, 1
|
||||
@@ -439,12 +439,12 @@ def parse(code: str) -> tuple[Stmt, ...]:
|
||||
if depth > 0: body.append(lines[i])
|
||||
elif depth == 1 and line_i[:6] == 'elsif ':
|
||||
cond_end = line_i.index(' then') if ' then' in line_i else len(line_i)
|
||||
br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:cond_end], []
|
||||
br.append((expr(cond), parse('\n'.join(body), _toplevel=False))); cond, body = line_i[6:cond_end], []
|
||||
elif depth == 1 and line_i == 'else':
|
||||
br.append((expr(cond), parse('\n'.join(body)))); cond, body = None, []
|
||||
br.append((expr(cond), parse('\n'.join(body), _toplevel=False))); cond, body = None, []
|
||||
else: body.append(lines[i])
|
||||
i += 1
|
||||
br.append((expr(cond) if cond else None, parse('\n'.join(body)))); stmts.append(If(tuple(br))); continue
|
||||
br.append((expr(cond) if cond else None, parse('\n'.join(body), _toplevel=False))); stmts.append(If(tuple(br))); continue
|
||||
if ln == 'else' or ln[:6] == 'elsif ': raise ValueError(f"Unexpected {ln.split()[0]} without matching if")
|
||||
s = stmt(ln)
|
||||
if s is not None: stmts.append(s)
|
||||
|
||||
Reference in New Issue
Block a user