mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tests
This commit is contained in:
@@ -3,7 +3,9 @@ from tinygrad.dtype import dtypes
|
|||||||
from tinygrad.uop import Ops
|
from tinygrad.uop import Ops
|
||||||
from tinygrad.uop.ops import UOp
|
from tinygrad.uop.ops import UOp
|
||||||
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
|
from extra.assembly.amd.qcode import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
|
||||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS
|
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS as RDNA3_PCODE
|
||||||
|
from extra.assembly.amd.autogen.rdna4.str_pcode import PSEUDOCODE_STRINGS as RDNA4_PCODE
|
||||||
|
from extra.assembly.amd.autogen.cdna.str_pcode import PSEUDOCODE_STRINGS as CDNA_PCODE
|
||||||
|
|
||||||
DEBUG = int(os.getenv("DEBUG", "0"))
|
DEBUG = int(os.getenv("DEBUG", "0"))
|
||||||
|
|
||||||
@@ -101,43 +103,47 @@ def _norm(s, keep_structure=False):
|
|||||||
s = re.sub(r'\|\|', '|', s)
|
s = re.sub(r'\|\|', '|', s)
|
||||||
return s.strip()
|
return s.strip()
|
||||||
|
|
||||||
|
def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||||
|
ok, fail, match, errs = 0, 0, 0, {}
|
||||||
|
for cls, ops in pcode_strings.items():
|
||||||
|
for op, pc in ops.items():
|
||||||
|
try:
|
||||||
|
ast = parse(pc)
|
||||||
|
ok += 1
|
||||||
|
except Exception as e:
|
||||||
|
fail += 1
|
||||||
|
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
||||||
|
continue
|
||||||
|
rendered = _pr(ast)
|
||||||
|
if _norm(pc) == _norm(rendered):
|
||||||
|
match += 1
|
||||||
|
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
|
||||||
|
elif DEBUG:
|
||||||
|
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
||||||
|
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
||||||
|
max_lines = max(len(orig_lines), len(rend_lines))
|
||||||
|
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
||||||
|
w = 50
|
||||||
|
for i in range(max_lines):
|
||||||
|
oline = orig_lines[i] if i < len(orig_lines) else ''
|
||||||
|
rline = rend_lines[i] if i < len(rend_lines) else ''
|
||||||
|
line_match = _norm(oline) == _norm(rline)
|
||||||
|
color = '' if line_match else '\033[31m'
|
||||||
|
reset = '' if line_match else '\033[0m'
|
||||||
|
print(f"{color}{oline:<{w}} | {rline}{reset}")
|
||||||
|
total = ok + fail
|
||||||
|
parse_rate = 100 * ok / total
|
||||||
|
roundtrip_rate = 100 * match / ok if ok > 0 else 0
|
||||||
|
if DEBUG:
|
||||||
|
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
||||||
|
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
||||||
|
test.assertGreater(parse_rate, min_parse, f"Parse rate {parse_rate:.1f}% should be >{min_parse}%")
|
||||||
|
test.assertGreater(roundtrip_rate, min_roundtrip, f"Roundtrip rate {roundtrip_rate:.1f}% should be >{min_roundtrip}%")
|
||||||
|
|
||||||
class TestQcodeParseAndRoundtrip(unittest.TestCase):
|
class TestQcodeParseAndRoundtrip(unittest.TestCase):
|
||||||
def test_parse_and_roundtrip(self):
|
def test_rdna3(self): _test_arch(self, RDNA3_PCODE)
|
||||||
ok, fail, match, errs = 0, 0, 0, {}
|
def test_rdna4(self): _test_arch(self, RDNA4_PCODE, min_parse=96)
|
||||||
for cls, ops in PSEUDOCODE_STRINGS.items():
|
def test_cdna(self): _test_arch(self, CDNA_PCODE, min_parse=78)
|
||||||
for op, pc in ops.items():
|
|
||||||
try:
|
|
||||||
ast = parse(pc)
|
|
||||||
ok += 1
|
|
||||||
except Exception as e:
|
|
||||||
fail += 1
|
|
||||||
errs[str(e)[:60]] = errs.get(str(e)[:60], 0) + 1
|
|
||||||
continue
|
|
||||||
rendered = _pr(ast)
|
|
||||||
if _norm(pc) == _norm(rendered):
|
|
||||||
match += 1
|
|
||||||
if DEBUG >= 2: print(f"\033[32m{op.name}\033[0m")
|
|
||||||
elif DEBUG:
|
|
||||||
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
|
||||||
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
|
||||||
max_lines = max(len(orig_lines), len(rend_lines))
|
|
||||||
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
|
||||||
w = 50
|
|
||||||
for i in range(max_lines):
|
|
||||||
oline = orig_lines[i] if i < len(orig_lines) else ''
|
|
||||||
rline = rend_lines[i] if i < len(rend_lines) else ''
|
|
||||||
line_match = _norm(oline) == _norm(rline)
|
|
||||||
color = '' if line_match else '\033[31m'
|
|
||||||
reset = '' if line_match else '\033[0m'
|
|
||||||
print(f"{color}{oline:<{w}} | {rline}{reset}")
|
|
||||||
total = ok + fail
|
|
||||||
parse_rate = 100 * ok / total
|
|
||||||
roundtrip_rate = 100 * match / ok if ok > 0 else 0
|
|
||||||
if DEBUG:
|
|
||||||
print(f"Parsed: {ok}/{total} ({parse_rate:.1f}%), Match: {match}/{ok} ({roundtrip_rate:.1f}%)")
|
|
||||||
for e, c in sorted(errs.items(), key=lambda x: -x[1])[:10]: print(f" {c}: {e}")
|
|
||||||
self.assertGreater(parse_rate, 98, f"Parse rate {parse_rate:.1f}% should be >98%")
|
|
||||||
self.assertGreater(roundtrip_rate, 98, f"Roundtrip rate {roundtrip_rate:.1f}% should be >98%")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user