mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
parsing
This commit is contained in:
@@ -48,7 +48,9 @@ class For: var: str; start: UOp; end: UOp; body: tuple[Stmt, ...]
|
||||
class Lambda: name: str; params: tuple[str, ...]; body: tuple[Stmt, ...]|UOp
|
||||
@dataclass(frozen=True)
|
||||
class Break: pass
|
||||
Stmt = Assign|Declare|If|For|Lambda|Break
|
||||
@dataclass(frozen=True)
|
||||
class Return: value: UOp
|
||||
Stmt = Assign|Declare|If|For|Lambda|Break|Return
|
||||
|
||||
def _match(s, i, o, c):
|
||||
d = 1
|
||||
@@ -186,9 +188,10 @@ def expr(s: str) -> UOp:
|
||||
raise ValueError(f"Cannot parse expression: {s}")
|
||||
|
||||
def stmt(line: str) -> Stmt|None:
|
||||
line = line.split('//')[0].strip().rstrip(';')
|
||||
line = line.split('//')[0].strip().rstrip(';').rstrip('.')
|
||||
if not line: return None
|
||||
if line == 'break': return Break()
|
||||
if line[:7] == 'return ': return Return(expr(line[7:]))
|
||||
if line[:5] == 'eval ': return Assign(UOp(Ops.DEFINE_VAR, dtypes.void, arg=('_eval', None, None)), UOp(Ops.DEFINE_VAR, dtypes.void, arg=(line, None, None)))
|
||||
if line[:8] == 'declare ' and ':' in line:
|
||||
n, t = line[8:].split(':', 1)
|
||||
@@ -208,10 +211,29 @@ def stmt(line: str) -> Stmt|None:
|
||||
eq = line.index('=')
|
||||
if eq > 0 and line[eq-1] not in '!<>=' and eq < len(line)-1 and line[eq+1] != '=':
|
||||
return Assign(expr(line[:eq]), expr(line[eq+1:]))
|
||||
return None
|
||||
# Bare function call (e.g., nop())
|
||||
if re.match(r'\w+\([^)]*\)$', line):
|
||||
return expr(line)
|
||||
raise ValueError(f"Cannot parse statement: {line}")
|
||||
|
||||
def parse(code: str) -> tuple[Stmt, ...]:
|
||||
lines = [l.split('//')[0].strip() for l in code.strip().split('\n') if l.split('//')[0].strip()]
|
||||
# Join continuation lines (unbalanced parens) - but not for control flow or lambdas
|
||||
joined, j = [], 0
|
||||
while j < len(lines):
|
||||
ln = lines[j]
|
||||
# Don't join lambda lines - they have their own multiline handling
|
||||
if '= lambda(' not in ln:
|
||||
while ln.count('(') > ln.count(')') and j + 1 < len(lines):
|
||||
next_ln = lines[j + 1]
|
||||
# Don't join if next line is control flow or looks like a new statement
|
||||
if next_ln[:3] == 'if ' or next_ln[:4] == 'for ' or next_ln[:6] == 'elsif ' or next_ln == 'else' or \
|
||||
next_ln == 'endif' or next_ln == 'endfor' or '= lambda(' in next_ln: break
|
||||
j += 1
|
||||
ln += ' ' + next_ln
|
||||
joined.append(ln)
|
||||
j += 1
|
||||
lines = joined
|
||||
stmts, i = [], 0
|
||||
while i < len(lines):
|
||||
ln = lines[i].rstrip(';')
|
||||
@@ -224,7 +246,7 @@ def parse(code: str) -> tuple[Stmt, ...]:
|
||||
while i < len(lines) and not lines[i-1].rstrip().endswith(');'):
|
||||
body_lines.append(lines[i])
|
||||
i += 1
|
||||
body_text = ' '.join(body_lines).strip()
|
||||
body_text = '\n'.join(body_lines).strip()
|
||||
if body_text.endswith(');'): body_text = body_text[:-2]
|
||||
# Try to parse as expression first, then as statements
|
||||
try:
|
||||
@@ -241,22 +263,25 @@ def parse(code: str) -> tuple[Stmt, ...]:
|
||||
if d > 0: body.append(lines[i])
|
||||
i += 1
|
||||
stmts.append(For(m[1], expr(m[2]), expr(m[3]), parse('\n'.join(body)))); continue
|
||||
if ln[:3] == 'if ' and ' then' in ln:
|
||||
br, cond, body, i, depth = [], ln[3:ln.index(' then')], [], i+1, 1
|
||||
if ln[:3] == 'if ':
|
||||
cond = ln[3:ln.index(' then')] if ' then' in ln else ln[3:]
|
||||
br, body, i, depth = [], [], i+1, 1
|
||||
while i < len(lines) and depth > 0:
|
||||
line_i = lines[i].rstrip(';').rstrip('.')
|
||||
if line_i[:3] == 'if ' and ' then' in line_i: depth += 1; body.append(lines[i])
|
||||
if line_i[:3] == 'if ': depth += 1; body.append(lines[i])
|
||||
elif line_i == 'endif':
|
||||
depth -= 1
|
||||
if depth > 0: body.append(lines[i])
|
||||
elif depth == 1 and line_i[:6] == 'elsif ' and ' then' in line_i:
|
||||
br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:line_i.index(' then')], []
|
||||
elif depth == 1 and line_i[:6] == 'elsif ':
|
||||
cond_end = line_i.index(' then') if ' then' in line_i else len(line_i)
|
||||
br.append((expr(cond), parse('\n'.join(body)))); cond, body = line_i[6:cond_end], []
|
||||
elif depth == 1 and line_i == 'else':
|
||||
br.append((expr(cond), parse('\n'.join(body)))); cond, body = None, []
|
||||
else: body.append(lines[i])
|
||||
i += 1
|
||||
br.append((expr(cond) if cond else None, parse('\n'.join(body)))); stmts.append(If(tuple(br))); continue
|
||||
if ln == 'else' or ln[:6] == 'elsif ': raise ValueError(f"Unexpected {ln.split()[0]} without matching if")
|
||||
if s := stmt(ln): stmts.append(s)
|
||||
s = stmt(ln)
|
||||
if s is not None: stmts.append(s)
|
||||
i += 1
|
||||
return tuple(stmts)
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest, re, os
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break
|
||||
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", "0"))
|
||||
@@ -72,6 +72,7 @@ def _pr(n, d=0):
|
||||
return "\n".join(parts) + f"\n{p}endif"
|
||||
case For(v, s, e, b): return f"{p}for {v} in {_pr(s)} : {_pr(e)} do\n" + "\n".join(_pr(x, d) for x in b) + f"\n{p}endfor"
|
||||
case Break(): return f"{p}break"
|
||||
case Return(v): return f"{p}return {_pr(v)}"
|
||||
case Lambda(name, params, body):
|
||||
body_str = _pr(body) if isinstance(body, UOp) else "\n".join(_pr(x, d) for x in body)
|
||||
return f"{p}{name} = lambda({', '.join(params)}) (\n{body_str});"
|
||||
@@ -135,7 +136,7 @@ class TestQcodeParseAndRoundtrip(unittest.TestCase):
|
||||
if DEBUG:
|
||||
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||
self.assertGreater(parse_rate, 98.5, f"Parse rate {parse_rate:.1f}% should be >98.5%")
|
||||
self.assertGreater(parse_rate, 98, f"Parse rate {parse_rate:.1f}% should be >98%")
|
||||
self.assertGreater(roundtrip_rate, 98, f"Roundtrip rate {roundtrip_rate:.1f}% should be >98%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user