mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
tests
This commit is contained in:
@@ -3,7 +3,9 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS as RDNA3_PCODE
|
||||
from extra.assembly.amd.autogen.rdna4.str_pcode import PSEUDOCODE_STRINGS as RDNA4_PCODE
|
||||
from extra.assembly.amd.autogen.cdna.str_pcode import PSEUDOCODE_STRINGS as CDNA_PCODE
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", "0"))
|
||||
|
||||
@@ -101,43 +103,47 @@ def _norm(s, keep_structure=False):
|
||||
s = re.sub(r'\|\|', '|', s)
|
||||
return s.strip()
|
||||
|
||||
def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||
ok, fail, match, errs = 0, 0, 0, {}
|
||||
for cls, ops in pcode_strings.items():
|
||||
for op, pc in ops.items():
|
||||
try:
|
||||
ast = parse(pc)
|
||||
ok += 1
|
||||
except Exception as e:
|
||||
fail += 1
|
||||
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
||||
continue
|
||||
rendered = _pr(ast)
|
||||
if _norm(pc) == _norm(rendered):
|
||||
match += 1
|
||||
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
|
||||
elif DEBUG:
|
||||
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
||||
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
||||
max_lines = max(len(orig_lines), len(rend_lines))
|
||||
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
||||
w = 50
|
||||
for i in range(max_lines):
|
||||
oline = orig_lines[i] if i < len(orig_lines) else ''
|
||||
rline = rend_lines[i] if i < len(rend_lines) else ''
|
||||
line_match = _norm(oline) == _norm(rline)
|
||||
color = '' if line_match else '\033[31m'
|
||||
reset = '' if line_match else '\033[0m'
|
||||
print(f"{color}{oline:<{w}} | {rline}{reset}")
|
||||
total = ok + fail
|
||||
parse_rate = 100 * ok / total
|
||||
roundtrip_rate = 100 * match / ok if ok > 0 else 0
|
||||
if DEBUG:
|
||||
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||
test.assertGreater(parse_rate, min_parse, f"Parse rate {parse_rate:.1f}% should be >{min_parse}%")
|
||||
test.assertGreater(roundtrip_rate, min_roundtrip, f"Roundtrip rate {roundtrip_rate:.1f}% should be >{min_roundtrip}%")
|
||||
|
||||
class TestQcodeParseAndRoundtrip(unittest.TestCase):
|
||||
def test_parse_and_roundtrip(self):
|
||||
ok, fail, match, errs = 0, 0, 0, {}
|
||||
for cls, ops in PSEUDOCODE_STRINGS.items():
|
||||
for op, pc in ops.items():
|
||||
try:
|
||||
ast = parse(pc)
|
||||
ok += 1
|
||||
except Exception as e:
|
||||
fail += 1
|
||||
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
||||
continue
|
||||
rendered = _pr(ast)
|
||||
if _norm(pc) == _norm(rendered):
|
||||
match += 1
|
||||
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
|
||||
elif DEBUG:
|
||||
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
||||
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
||||
max_lines = max(len(orig_lines), len(rend_lines))
|
||||
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
||||
w = 50
|
||||
for i in range(max_lines):
|
||||
oline = orig_lines[i] if i < len(orig_lines) else ''
|
||||
rline = rend_lines[i] if i < len(rend_lines) else ''
|
||||
line_match = _norm(oline) == _norm(rline)
|
||||
color = '' if line_match else '\033[31m'
|
||||
reset = '' if line_match else '\033[0m'
|
||||
print(f"{color}{oline:<{w}} | {rline}{reset}")
|
||||
total = ok + fail
|
||||
parse_rate = 100 * ok / total
|
||||
roundtrip_rate = 100 * match / ok if ok > 0 else 0
|
||||
if DEBUG:
|
||||
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||
self.assertGreater(parse_rate, 98, f"Parse rate {parse_rate:.1f}% should be >98%")
|
||||
self.assertGreater(roundtrip_rate, 98, f"Roundtrip rate {roundtrip_rate:.1f}% should be >98%")
|
||||
def test_rdna3(self): _test_arch(self, RDNA3_PCODE)
|
||||
def test_rdna4(self): _test_arch(self, RDNA4_PCODE, min_parse=96)
|
||||
def test_cdna(self): _test_arch(self, CDNA_PCODE, min_parse=78)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user