mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
test qcode
This commit is contained in:
@@ -236,128 +236,3 @@ def parse(code: str) -> tuple[Stmt, ...]:
|
||||
if s := stmt(ln): stmts.append(s)
|
||||
i += 1
|
||||
return tuple(stmts)
|
||||
|
||||
_OP_SYMS = {v: k for k, v in _BINOPS.items() if k not in ('>', '>=', '<>', '||', '&&')}
|
||||
_DT_STR = {v: k for k, v in _QDTYPES.items() if k in ('u64', 'u32', 'u16', 'u8', 'i64', 'i32', 'i16', 'i8', 'f64', 'f32', 'f16', 'bf16')}
|
||||
|
||||
def _dt_bits(dt: DType) -> int:
|
||||
if m := re.search(r'(\d+)', dt.name): return int(m[1])
|
||||
return dt.itemsize * 8
|
||||
|
||||
def _pr(n, d=0):
|
||||
p = " "*d
|
||||
match n:
|
||||
case UOp(Ops.CONST, dt, _, v):
|
||||
if dt == dtypes.int32: return str(int(v))
|
||||
if dt == dtypes.int64: return f"{int(v)}LL"
|
||||
if dt == dtypes.float32: return f"{v}F"
|
||||
if dt == dtypes.float16: return f"16'{v}"
|
||||
if dt == dtypes.uint32: return f"{int(v)}U"
|
||||
if dt == dtypes.uint64: return f"{int(v)}ULL"
|
||||
if dt == dtypes.int16: return f"16'{int(v)}"
|
||||
bits = _dt_bits(dt)
|
||||
if 'u' in dt.name or 'b' in dt.name: return f"{bits}'{int(v)}U"
|
||||
if 'f' in dt.name or 'float' in dt.name: return f"{bits}'{v}"
|
||||
if 'i' in dt.name or 'int' in dt.name: return f"{bits}'{int(v)}"
|
||||
return f"{v}"
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, _, _)): return name
|
||||
case UOp(Ops.BITCAST, dt, (e,)): return f"{_pr(e)}.{_DT_STR.get(dt, dt.name)}"
|
||||
case UOp(Ops.CUSTOMI, _, (e, h, l)):
|
||||
if h is l: return f"{_pr(e)}[{_pr(h)}]"
|
||||
# Detect [start +: width] pattern: hi = (start + width) - 1, lo = start
|
||||
if h.op == Ops.SUB and h.src[1].op == Ops.CONST and h.src[1].arg == 1 and h.src[0].op == Ops.ADD and h.src[0].src[0] == l:
|
||||
return f"{_pr(e)}[{_pr(l)} +: {_pr(h.src[0].src[1])}]"
|
||||
return f"{_pr(e)}[{_pr(h)} : {_pr(l)}]"
|
||||
case UOp(Ops.CAST, dt, (e,)): return f"{_dt_bits(dt)}'{_DT_STR.get(dt, dt.name)[0].upper()}({_pr(e)})"
|
||||
case UOp(Ops.NEG, _, (x,)): return f"-{_pr(x)}"
|
||||
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
|
||||
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
|
||||
case UOp(_, _, (l, r), _) if n.op in _OP_SYMS:
|
||||
sym = _OP_SYMS[n.op]
|
||||
left, right = l, r
|
||||
# Restore > / >= if tag indicates flipped
|
||||
if n.tag == 'flipped' and n.op == Ops.CMPLT: sym, left, right = '>', r, l
|
||||
if n.tag == 'flipped' and n.op == Ops.CMPLE: sym, left, right = '>=', r, l
|
||||
# Use <> for unordered not-equal
|
||||
if n.tag == '<>' and n.op == Ops.CMPNE: sym = '<>'
|
||||
|
||||
return f"{_pr(left)} {sym} {_pr(right)}"
|
||||
case UOp(Ops.WHERE, _, (c, t, f)): return f"{_pr(c)} ? {_pr(t)} : {_pr(f)}"
|
||||
case UOp(Ops.CUSTOM, _, args, 'MEM'): return f"MEM[{_pr(args[0])}]"
|
||||
case UOp(Ops.CUSTOM, _, args, name): return f"{name}({', '.join(_pr(x) for x in args)})"
|
||||
case UOp(Ops.CAT, _, exprs): return f"{{{', '.join(_pr(x) for x in exprs)}}}"
|
||||
case Assign(l, r):
|
||||
# Detect compound assignment: lhs = lhs op rhs -> lhs op= rhs (but not for PC)
|
||||
compound = {Ops.ADD: '+=', Ops.SUB: '-=', Ops.OR: '|=', Ops.AND: '&=', Ops.XOR: '^=', Ops.SHL: '<<=', Ops.SHR: '>>='}
|
||||
is_pc = l.op == Ops.DEFINE_VAR and l.arg[0] == 'PC'
|
||||
if r.op in compound and len(r.src) == 2 and r.src[0] == l and not is_pc:
|
||||
return f"{p}{_pr(l)} {compound[r.op]} {_pr(r.src[1])}"
|
||||
return f"{p}{_pr(l)} = {_pr(r)}"
|
||||
case Declare(name, dt):
|
||||
base = dt.scalar() if dt.count > 1 else dt
|
||||
suffix = f"[{dt.count}]" if dt.count > 1 else ""
|
||||
return f"{p}declare {name} : {_dt_bits(base)}'{_DT_STR.get(base, base.name)[0].upper()}{suffix}"
|
||||
case If(br):
|
||||
parts = []
|
||||
for i, (c, b) in enumerate(br):
|
||||
kw = "if" if i == 0 else "elsif" if c is not None else "else"
|
||||
cond = f" {_pr(c)} then" if c is not None else ""
|
||||
body = "\n".join(_pr(s, d) for s in b)
|
||||
parts.append(f"{p}{kw}{cond}\n{body}")
|
||||
return "\n".join(parts) + f"\n{p}endif"
|
||||
case For(v, s, e, b): return f"{p}for {v} in {_pr(s)} : {_pr(e)} do\n" + "\n".join(_pr(x, d) for x in b) + f"\n{p}endfor"
|
||||
case tuple(): return "\n".join(_pr(x, d) for x in n)
|
||||
case _: return f"{p}{n}"
|
||||
|
||||
def _norm(s, keep_structure=False):
|
||||
# Strip leading description lines (lines without = or ; that aren't declare/if/for)
|
||||
while True:
|
||||
m = re.match(r'^(?!declare|if |for )[^=;\n]+\n', s)
|
||||
if not m: break
|
||||
s = s[m.end():]
|
||||
s = re.sub(r'//[^\n]*', '', s) # strip comments
|
||||
if keep_structure:
|
||||
s = re.sub(r';', '', s) # strip semicolons only
|
||||
s = re.sub(r'\n\s*\n', '\n', s) # collapse blank lines
|
||||
else:
|
||||
s = re.sub(r'[;()\s]', '', s) # strip semicolons, parens, whitespace
|
||||
s = re.sub(r'_eval=', '', s) # strip _eval= prefix from eval statements
|
||||
s = re.sub(r'0x[0-9a-fA-F]+', lambda m: str(int(m[0], 16)), s) # hex to decimal
|
||||
s = re.sub(r'\.b(\d+)', r'.u\1', s) # .bXX -> .uXX
|
||||
s = re.sub(r"'B", "'U", s) # 'B -> 'U
|
||||
s = re.sub(r'(\d+\.\d+)F', r'\1', s) # strip F suffix from floats
|
||||
s = re.sub(r'\+INF', 'INF', s) # +INF -> INF
|
||||
s = re.sub(r'&&', '&', s) # && -> &
|
||||
s = re.sub(r'\|\|', '|', s) # || -> |
|
||||
return s.strip()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
DEBUG = int(os.getenv("DEBUG", "0"))
|
||||
ok, fail, match, errs = 0, 0, 0, {}
|
||||
for cls, ops in PSEUDOCODE_STRINGS.items():
|
||||
for op, pc in ops.items():
|
||||
try: ast = parse(pc); ok += 1
|
||||
except Exception as e: fail += 1; errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
||||
else:
|
||||
rendered = _pr(ast)
|
||||
if _norm(pc) == _norm(rendered):
|
||||
match += 1
|
||||
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
|
||||
elif DEBUG:
|
||||
# Show diff side by side - normalize AMD pcode (left), raw rendered (right)
|
||||
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
||||
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
||||
max_lines = max(len(orig_lines), len(rend_lines))
|
||||
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
||||
w = 50 # column width
|
||||
for i in range(max_lines):
|
||||
oline = orig_lines[i] if i < len(orig_lines) else ''
|
||||
rline = rend_lines[i] if i < len(rend_lines) else ''
|
||||
line_match = _norm(oline) == _norm(rline)
|
||||
color = '' if line_match else '\033[31m'
|
||||
reset = '' if line_match else '\033[0m'
|
||||
print(f"{color}{oline:<{w}} | {rline}{reset}")
|
||||
print(f"Parsed: {ok}/{ok+fail} ({100*ok/(ok+fail):.1f}%), Match: {match}/{ok} ({100*match/ok:.1f}%)")
|
||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||
|
||||
137
extra/assembly/amd/test/test_qcode.py
Normal file
137
extra/assembly/amd/test/test_qcode.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import unittest, re, os
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", "0"))
|
||||
|
||||
_OP_SYMS = {v: k for k, v in _BINOPS.items() if k not in ('>', '>=', '<>', '||', '&&')}
|
||||
_DT_STR = {v: k for k, v in _QDTYPES.items() if k in ('u64', 'u32', 'u16', 'u8', 'i64', 'i32', 'i16', 'i8', 'f64', 'f32', 'f16', 'bf16')}
|
||||
|
||||
def _dt_bits(dt):
|
||||
if m := re.search(r'(\d+)', dt.name): return int(m[1])
|
||||
return dt.itemsize * 8
|
||||
|
||||
def _pr(n, d=0):
|
||||
p = " "*d
|
||||
match n:
|
||||
case UOp(Ops.CONST, dt, _, v):
|
||||
if dt == dtypes.int32: return str(int(v))
|
||||
if dt == dtypes.int64: return f"{int(v)}LL"
|
||||
if dt == dtypes.float32: return f"{v}F"
|
||||
if dt == dtypes.float16: return f"16'{v}"
|
||||
if dt == dtypes.uint32: return f"{int(v)}U"
|
||||
if dt == dtypes.uint64: return f"{int(v)}ULL"
|
||||
if dt == dtypes.int16: return f"16'{int(v)}"
|
||||
bits = _dt_bits(dt)
|
||||
if 'u' in dt.name or 'b' in dt.name: return f"{bits}'{int(v)}U"
|
||||
if 'f' in dt.name or 'float' in dt.name: return f"{bits}'{v}"
|
||||
if 'i' in dt.name or 'int' in dt.name: return f"{bits}'{int(v)}"
|
||||
return f"{v}"
|
||||
case UOp(Ops.DEFINE_VAR, _, _, (name, _, _)): return name
|
||||
case UOp(Ops.BITCAST, dt, (e,)): return f"{_pr(e)}.{_DT_STR.get(dt, dt.name)}"
|
||||
case UOp(Ops.CUSTOMI, _, (e, h, l)):
|
||||
if h is l: return f"{_pr(e)}[{_pr(h)}]"
|
||||
if h.op == Ops.SUB and h.src[1].op == Ops.CONST and h.src[1].arg == 1 and h.src[0].op == Ops.ADD and h.src[0].src[0] == l:
|
||||
return f"{_pr(e)}[{_pr(l)} +: {_pr(h.src[0].src[1])}]"
|
||||
return f"{_pr(e)}[{_pr(h)} : {_pr(l)}]"
|
||||
case UOp(Ops.CAST, dt, (e,)): return f"{_dt_bits(dt)}'{_DT_STR.get(dt, dt.name)[0].upper()}({_pr(e)})"
|
||||
case UOp(Ops.NEG, _, (x,)): return f"-{_pr(x)}"
|
||||
case UOp(Ops.XOR, _, (x,)) if len(n.src) == 1: return f"~{_pr(x)}"
|
||||
case UOp(Ops.CMPEQ, _, (x,)) if len(n.src) == 1: return f"!{_pr(x)}"
|
||||
case UOp(_, _, (l, r), _) if n.op in _OP_SYMS:
|
||||
sym = _OP_SYMS[n.op]
|
||||
left, right = l, r
|
||||
if n.tag == 'flipped' and n.op == Ops.CMPLT: sym, left, right = '>', r, l
|
||||
if n.tag == 'flipped' and n.op == Ops.CMPLE: sym, left, right = '>=', r, l
|
||||
if n.tag == '<>' and n.op == Ops.CMPNE: sym = '<>'
|
||||
return f"{_pr(left)} {sym} {_pr(right)}"
|
||||
case UOp(Ops.WHERE, _, (c, t, f)): return f"{_pr(c)} ? {_pr(t)} : {_pr(f)}"
|
||||
case UOp(Ops.CUSTOM, _, args, 'MEM'): return f"MEM[{_pr(args[0])}]"
|
||||
case UOp(Ops.CUSTOM, _, args, name): return f"{name}({', '.join(_pr(x) for x in args)})"
|
||||
case UOp(Ops.CAT, _, exprs): return f"{{{', '.join(_pr(x) for x in exprs)}}}"
|
||||
case Assign(l, r):
|
||||
compound = {Ops.ADD: '+=', Ops.SUB: '-=', Ops.OR: '|=', Ops.AND: '&=', Ops.XOR: '^=', Ops.SHL: '<<=', Ops.SHR: '>>='}
|
||||
is_pc = l.op == Ops.DEFINE_VAR and l.arg[0] == 'PC'
|
||||
if r.op in compound and len(r.src) == 2 and r.src[0] == l and not is_pc:
|
||||
return f"{p}{_pr(l)} {compound[r.op]} {_pr(r.src[1])}"
|
||||
return f"{p}{_pr(l)} = {_pr(r)}"
|
||||
case Declare(name, dt):
|
||||
base = dt.scalar() if dt.count > 1 else dt
|
||||
suffix = f"[{dt.count}]" if dt.count > 1 else ""
|
||||
return f"{p}declare {name} : {_dt_bits(base)}'{_DT_STR.get(base, base.name)[0].upper()}{suffix}"
|
||||
case If(br):
|
||||
parts = []
|
||||
for i, (c, b) in enumerate(br):
|
||||
kw = "if" if i == 0 else "elsif" if c is not None else "else"
|
||||
cond = f" {_pr(c)} then" if c is not None else ""
|
||||
body = "\n".join(_pr(s, d) for s in b)
|
||||
parts.append(f"{p}{kw}{cond}\n{body}")
|
||||
return "\n".join(parts) + f"\n{p}endif"
|
||||
case For(v, s, e, b): return f"{p}for {v} in {_pr(s)} : {_pr(e)} do\n" + "\n".join(_pr(x, d) for x in b) + f"\n{p}endfor"
|
||||
case tuple(): return "\n".join(_pr(x, d) for x in n)
|
||||
case _: return f"{p}{n}"
|
||||
|
||||
def _norm(s, keep_structure=False):
|
||||
while True:
|
||||
m = re.match(r'^(?!declare|if |for )[^=;\n]+\n', s)
|
||||
if not m: break
|
||||
s = s[m.end():]
|
||||
s = re.sub(r'//[^\n]*', '', s)
|
||||
if keep_structure:
|
||||
s = re.sub(r';', '', s)
|
||||
s = re.sub(r'\n\s*\n', '\n', s)
|
||||
else:
|
||||
s = re.sub(r'[;()\s]', '', s)
|
||||
s = re.sub(r'_eval=', '', s)
|
||||
s = re.sub(r'0x[0-9a-fA-F]+', lambda m: str(int(m[0], 16)), s)
|
||||
s = re.sub(r'\.b(\d+)', r'.u\1', s)
|
||||
s = re.sub(r"'B", "'U", s)
|
||||
s = re.sub(r'(\d+\.\d+)F', r'\1', s)
|
||||
s = re.sub(r'\+INF', 'INF', s)
|
||||
s = re.sub(r'&&', '&', s)
|
||||
s = re.sub(r'\|\|', '|', s)
|
||||
return s.strip()
|
||||
|
||||
class TestQcodeParseAndRoundtrip(unittest.TestCase):
|
||||
def test_parse_and_roundtrip(self):
|
||||
ok, fail, match, errs = 0, 0, 0, {}
|
||||
for cls, ops in PSEUDOCODE_STRINGS.items():
|
||||
for op, pc in ops.items():
|
||||
try:
|
||||
ast = parse(pc)
|
||||
ok += 1
|
||||
except Exception as e:
|
||||
fail += 1
|
||||
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
||||
continue
|
||||
rendered = _pr(ast)
|
||||
if _norm(pc) == _norm(rendered):
|
||||
match += 1
|
||||
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
|
||||
elif DEBUG:
|
||||
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
||||
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
||||
max_lines = max(len(orig_lines), len(rend_lines))
|
||||
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
||||
w = 50
|
||||
for i in range(max_lines):
|
||||
oline = orig_lines[i] if i < len(orig_lines) else ''
|
||||
rline = rend_lines[i] if i < len(rend_lines) else ''
|
||||
line_match = _norm(oline) == _norm(rline)
|
||||
color = '' if line_match else '\033[31m'
|
||||
reset = '' if line_match else '\033[0m'
|
||||
print(f"{color}{oline:<{w}} | {rline}{reset}")
|
||||
total = ok + fail
|
||||
parse_rate = 100 * ok / total
|
||||
roundtrip_rate = 100 * match / ok if ok > 0 else 0
|
||||
if DEBUG:
|
||||
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||
self.assertGreater(parse_rate, 99, f"Parse rate {parse_rate:.1f}% should be >99%")
|
||||
self.assertGreater(roundtrip_rate, 98, f"Roundtrip rate {roundtrip_rate:.1f}% should be >98%")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user