From 894230d0a923dd92d101fb4bbe4e9498728f197d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 8 Jan 2026 05:02:55 -0800 Subject: [PATCH] fix parser bugs --- extra/assembly/amd/pcode_parse.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/extra/assembly/amd/pcode_parse.py b/extra/assembly/amd/pcode_parse.py index 8996038d57..cf4f23523a 100644 --- a/extra/assembly/amd/pcode_parse.py +++ b/extra/assembly/amd/pcode_parse.py @@ -51,9 +51,7 @@ def _floor(x): def _cvt(src_dt: DType, dst_dt: DType): """Create a conversion function that asserts input type and casts to output type.""" def convert(x: UOp) -> UOp: - # Allow: exact match, void (unresolved), or uint32 (unresolved array access/slice) - # TODO: should only allow exact match - assert x.dtype == src_dt or x.dtype == dtypes.void or x.dtype == dtypes.uint32, f"Expected {src_dt}, got {x.dtype}" + assert x.dtype == src_dt or x.dtype == dtypes.void, f"Expected {src_dt}, got {x.dtype}" return UOp(Ops.CAST, dst_dt, (x,)) return convert @@ -290,8 +288,10 @@ def expr(s: str) -> UOp: slice_dtype = dtypes.uint32 # default for dynamic slices return UOp(Ops.CUSTOMI, slice_dtype, (expr(b), hi_expr, lo_expr)) idx = expr(n) - # Single bit index returns uint32 (1 bit result fits in u32) - return UOp(Ops.CUSTOMI, dtypes.uint32, (expr(b), idx, idx)) + base = expr(b) + # For array element access, use scalar type of the array; for bit index, use uint32 + elem_dtype = base.dtype.scalar() if base.dtype != dtypes.void and base.dtype.count > 1 else dtypes.uint32 + return UOp(Ops.CUSTOMI, elem_dtype, (base, idx, idx)) # Bitcast: expr.type if '.' in s: for i in range(len(s)-1, 0, -1): @@ -379,9 +379,9 @@ def stmt(line: str) -> Stmt|None: return expr(line) raise ValueError(f"Cannot parse statement: {line}") -def parse(code: str) -> tuple[Stmt, ...]: +def parse(code: str, _toplevel: bool = True) -> tuple[Stmt, ...]: global _var_dtypes - _var_dtypes = {} # reset for each parse + if _toplevel: _var_dtypes = {} # only reset at top level, preserve for recursive calls lines = [l.split('//')[0].strip() for l in code.strip().split('\n') if l.split('//')[0].strip()] # Join continuation lines (unbalanced parens) - but not for control flow or lambdas joined, j = [], 0 @@ -417,7 +417,7 @@ def parse(code: str) -> tuple[Stmt, ...]: try: body = expr(body_text) except ValueError: - body = parse(body_text) + body = parse(body_text, _toplevel=False) stmts.append(Lambda(name, params, body)); continue if ln[:4] == 'for ' and ' do' in ln and (m := re.match(r'for\s+(\w+)\s+in\s+(.+?)\s*:\s*(.+?)\s+do', ln)): i, body, d = i+1, [], 1 @@ -427,7 +427,7 @@ def parse(code: str) -> tuple[Stmt, ...]: elif line_i == 'endfor': d -= 1 if d > 0: body.append(lines[i]) i += 1 - stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body)))); continue + stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body), _toplevel=False))); continue if ln[:3] == 'if ': cond = ln[3:ln.index(' then')] if ' then' in ln else ln[3:] br, body, i, depth = [], [], i+1, 1 @@ -439,12 +439,12 @@ def parse(code: str) -> tuple[Stmt, ...]: if depth > 0: body.append(lines[i]) elif depth == 1 and line_i[:6] == 'elsif ': cond_end = line_i.index(' then') if ' then' in line_i else len(line_i) - br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:cond_end], [] + br.append((expr(cond), parse('\n'.join(body), _toplevel=False))); cond, body = line_i[6:cond_end], [] elif depth == 1 and line_i == 'else': - br.append((expr(cond), parse('\n'.join(body)))); cond, body = None, [] + br.append((expr(cond), parse('\n'.join(body), _toplevel=False))); cond, body = None, [] else: body.append(lines[i]) i += 1 - br.append((expr(cond) if cond else None, parse('\n'.join(body)))); stmts.append(If(tuple(br))); continue + br.append((expr(cond) if cond else None, parse('\n'.join(body), _toplevel=False))); stmts.append(If(tuple(br))); continue if ln == 'else' or ln[:6] == 'elsif ': raise ValueError(f"Unexpected {ln.split()[0]} without matching if") s = stmt(ln) if s is not None: stmts.append(s)