This commit is contained in:
George Hotz
2026-01-05 19:57:30 -08:00
parent 6de310c87f
commit c8c6346336

View File

@@ -3,7 +3,9 @@ from tinygrad.dtype import dtypes
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS as RDNA3_PCODE
from extra.assembly.amd.autogen.rdna4.str_pcode import PSEUDOCODE_STRINGS as RDNA4_PCODE
from extra.assembly.amd.autogen.cdna.str_pcode import PSEUDOCODE_STRINGS as CDNA_PCODE
DEBUG = int(os.getenv("DEBUG", "0"))
@@ -101,43 +103,47 @@ def _norm(s, keep_structure=False):
s = re.sub(r'\|\|', '|', s)
return s.strip()
def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
ok, fail, match, errs = 0, 0, 0, {}
for cls, ops in pcode_strings.items():
for op, pc in ops.items():
try:
ast = parse(pc)
ok += 1
except Exception as e:
fail += 1
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
continue
rendered = _pr(ast)
if _norm(pc) == _norm(rendered):
match += 1
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
elif DEBUG:
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
rend_lines = [l for l in rendered.split('\n') if l.strip()]
max_lines = max(len(orig_lines), len(rend_lines))
print(f"{'='*60}\n{op.name}\n{'='*60}")
w = 50
for i in range(max_lines):
oline = orig_lines[i] if i < len(orig_lines) else ''
rline = rend_lines[i] if i < len(rend_lines) else ''
line_match = _norm(oline) == _norm(rline)
color = '' if line_match else '\033[31m'
reset = '' if line_match else '\033[0m'
print(f"{color}{oline:<{w}} | {rline}{reset}")
total = ok + fail
parse_rate = 100 * ok / total
roundtrip_rate = 100 * match / ok if ok > 0 else 0
if DEBUG:
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
test.assertGreater(parse_rate, min_parse, f"Parse rate {parse_rate:.1f}% should be >{min_parse}%")
test.assertGreater(roundtrip_rate, min_roundtrip, f"Roundtrip rate {roundtrip_rate:.1f}% should be >{min_roundtrip}%")
class TestQcodeParseAndRoundtrip(unittest.TestCase):
def test_parse_and_roundtrip(self):
ok, fail, match, errs = 0, 0, 0, {}
for cls, ops in PSEUDOCODE_STRINGS.items():
for op, pc in ops.items():
try:
ast = parse(pc)
ok += 1
except Exception as e:
fail += 1
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
continue
rendered = _pr(ast)
if _norm(pc) == _norm(rendered):
match += 1
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
elif DEBUG:
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
rend_lines = [l for l in rendered.split('\n') if l.strip()]
max_lines = max(len(orig_lines), len(rend_lines))
print(f"{'='*60}\n{op.name}\n{'='*60}")
w = 50
for i in range(max_lines):
oline = orig_lines[i] if i < len(orig_lines) else ''
rline = rend_lines[i] if i < len(rend_lines) else ''
line_match = _norm(oline) == _norm(rline)
color = '' if line_match else '\033[31m'
reset = '' if line_match else '\033[0m'
print(f"{color}{oline:<{w}} | {rline}{reset}")
total = ok + fail
parse_rate = 100 * ok / total
roundtrip_rate = 100 * match / ok if ok > 0 else 0
if DEBUG:
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
self.assertGreater(parse_rate, 98, f"Parse rate {parse_rate:.1f}% should be >98%")
self.assertGreater(roundtrip_rate, 98, f"Roundtrip rate {roundtrip_rate:.1f}% should be >98%")
def test_rdna3(self): _test_arch(self, RDNA3_PCODE)
def test_rdna4(self): _test_arch(self, RDNA4_PCODE, min_parse=96)
def test_cdna(self): _test_arch(self, CDNA_PCODE, min_parse=78)
if __name__ == "__main__":
unittest.main()