This commit is contained in:
George Hotz
2026-01-05 19:54:56 -08:00
parent ffba806b65
commit 6de310c87f
2 changed files with 38 additions and 12 deletions

View File

@@ -48,7 +48,9 @@ class For: var: str; start: UOp; end: UOp; body: tuple[Stmt, ...]
class Lambda: name: str; params: tuple[str, ...]; body: tuple[Stmt, ...]|UOp
@dataclass(frozen=True)
class Break: pass
Stmt = Assign|Declare|If|For|Lambda|Break
@dataclass(frozen=True)
class Return: value: UOp
Stmt = Assign|Declare|If|For|Lambda|Break|Return
def _match(s, i, o, c):
d = 1
@@ -186,9 +188,10 @@ def expr(s: str) -> UOp:
raise ValueError(f"Cannot parse expression: {s}")
def stmt(line: str) -> Stmt|None:
line = line.split('//')[0].strip().rstrip(';')
line = line.split('//')[0].strip().rstrip(';').rstrip('.')
if not line: return None
if line == 'break': return Break()
if line[:7] == 'return ': return Return(expr(line[7:]))
if line[:5] == 'eval ': return Assign(UOp(Ops.DEFINE_VAR, dtypes.void, arg=('_eval', None, None)), UOp(Ops.DEFINE_VAR, dtypes.void, arg=(line, None, None)))
if line[:8] == 'declare ' and ':' in line:
n, t = line[8:].split(':', 1)
@@ -208,10 +211,29 @@ def stmt(line: str) -> Stmt|None:
eq = line.index('=')
if eq > 0 and line[eq-1] not in '!<>=' and eq < len(line)-1 and line[eq+1] != '=':
return Assign(expr(line[:eq]), expr(line[eq+1:]))
return None
# Bare function call (e.g., nop())
if re.match(r'\w+\([^)]*\)$', line):
return expr(line)
raise ValueError(f"Cannot parse statement: {line}")
def parse(code: str) -> tuple[Stmt, ...]:
lines = [l.split('//')[0].strip() for l in code.strip().split('\n') if l.split('//')[0].strip()]
# Join continuation lines (unbalanced parens) - but not for control flow or lambdas
joined, j = [], 0
while j < len(lines):
ln = lines[j]
# Don't join lambda lines - they have their own multiline handling
if '= lambda(' not in ln:
while ln.count('(') > ln.count(')') and j + 1 < len(lines):
next_ln = lines[j + 1]
# Don't join if next line is control flow or looks like a new statement
if next_ln[:3] == 'if ' or next_ln[:4] == 'for ' or next_ln[:6] == 'elsif ' or next_ln == 'else' or \
next_ln == 'endif' or next_ln == 'endfor' or '= lambda(' in next_ln: break
j += 1
ln += ' ' + next_ln
joined.append(ln)
j += 1
lines = joined
stmts, i = [], 0
while i < len(lines):
ln = lines[i].rstrip(';')
@@ -224,7 +246,7 @@ def parse(code: str) -> tuple[Stmt, ...]:
while i < len(lines) and not lines[i-1].rstrip().endswith(');'):
body_lines.append(lines[i])
i += 1
body_text = ' '.join(body_lines).strip()
body_text = '\n'.join(body_lines).strip()
if body_text.endswith(');'): body_text = body_text[:-2]
# Try to parse as expression first, then as statements
try:
@@ -241,22 +263,25 @@ def parse(code: str) -> tuple[Stmt, ...]:
if d > 0: body.append(lines[i])
i += 1
stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body)))); continue
if ln[:3] == 'if ' and ' then' in ln:
br, cond, body, i, depth = [], ln[3:ln.index(' then')], [], i+1, 1
if ln[:3] == 'if ':
cond = ln[3:ln.index(' then')] if ' then' in ln else ln[3:]
br, body, i, depth = [], [], i+1, 1
while i < len(lines) and depth > 0:
line_i = lines[i].rstrip(';').rstrip('.')
if line_i[:3] == 'if ' and ' then' in line_i: depth += 1; body.append(lines[i])
if line_i[:3] == 'if ': depth += 1; body.append(lines[i])
elif line_i == 'endif':
depth -= 1
if depth > 0: body.append(lines[i])
elif depth == 1 and line_i[:6] == 'elsif ' and ' then' in line_i:
br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:line_i.index(' then')], []
elif depth == 1 and line_i[:6] == 'elsif ':
cond_end = line_i.index(' then') if ' then' in line_i else len(line_i)
br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:cond_end], []
elif depth == 1 and line_i == 'else':
br.append((expr(cond), parse('\n'.join(body)))); cond, body = None, []
else: body.append(lines[i])
i += 1
br.append((expr(cond) if cond else None, parse('\n'.join(body)))); stmts.append(If(tuple(br))); continue
if ln == 'else' or ln[:6] == 'elsif ': raise ValueError(f"Unexpected {ln.split()[0]} without matching if")
if s := stmt(ln): stmts.append(s)
s = stmt(ln)
if s is not None: stmts.append(s)
i += 1
return tuple(stmts)

View File

@@ -2,7 +2,7 @@ import unittest, re, os
from tinygrad.dtype import dtypes
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
DEBUG = int(os.getenv("DEBUG", "0"))
@@ -72,6 +72,7 @@ def _pr(n, d=0):
return "\n".join(parts) + f"\n{p}endif"
case For(v, s, e, b): return f"{p}for {v} in {_pr(s)} : {_pr(e)} do\n" + "\n".join(_pr(x, d) for x in b) + f"\n{p}endfor"
case Break(): return f"{p}break"
case Return(v): return f"{p}return {_pr(v)}"
case Lambda(name, params, body):
body_str = _pr(body) if isinstance(body, UOp) else "\n".join(_pr(x, d) for x in body)
return f"{p}{name} = lambda({', '.join(params)}) (\n{body_str});"
@@ -135,7 +136,7 @@ class TestQcodeParseAndRoundtrip(unittest.TestCase):
if DEBUG:
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
self.assertGreater(parse_rate, 98.5, f"Parse rate {parse_rate:.1f}% should be >98.5%")
self.assertGreater(parse_rate, 98, f"Parse rate {parse_rate:.1f}% should be >98%")
self.assertGreater(roundtrip_rate, 98, f"Roundtrip rate {roundtrip_rate:.1f}% should be >98%")
if __name__ == "__main__":