From 55677ff8efeb45c00c75ab98e6abb5bb3f1c0ada Mon Sep 17 00:00:00 2001 From: George Hotz Date: Fri, 9 Jan 2026 17:02:02 -0800 Subject: [PATCH] comments --- extra/assembly/amd/README | 4 +++- extra/assembly/amd/pcode_parse.py | 10 ++++++++++ extra/assembly/amd/test/test_pcode_parse.py | 20 +++++++++++++++----- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/extra/assembly/amd/README b/extra/assembly/amd/README index d4b8697d6c..62dfad2986 100644 --- a/extra/assembly/amd/README +++ b/extra/assembly/amd/README @@ -9,6 +9,9 @@ Test with `PYTHONPATH="." pytest -n12 extra/assembly/amd/` * asm.py -- an asm/disasm function to transform to and from AMD assembly syntax * emu.py -- an emulator for RDNA that runs in tinygrad with `AMD=1 MOCKGPU=1 PYTHON_REMU=1` +* pcode_parse.py -- Parse the psuedocode from strings into ill-formed UOps +* pcode_transform.py -- Transform the ill-formed UOps into correctly formed and simplified UOps + The code should be as readable and deduplicated as possible. asm and emu shouldn't be required for dsl. The autogen folder is autogenerated from the AMD PDFs with `python3 -m extra.assembly.amd.pdf --arch all` @@ -36,4 +39,3 @@ IMPORTANT: if a test is failing in the emulator, it's an instruction bug. Use DE Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines. Get line count with `cloc --by-file extra/assembly/amd/*.py` - diff --git a/extra/assembly/amd/pcode_parse.py b/extra/assembly/amd/pcode_parse.py index 5d9ff4a9e1..49152078ec 100644 --- a/extra/assembly/amd/pcode_parse.py +++ b/extra/assembly/amd/pcode_parse.py @@ -94,16 +94,26 @@ def _infer_fn_dtype(name: str, srcs: tuple[UOp, ...]) -> DType: return dtypes.void # Statement types (control flow, not expressions) + +# these are UStatement (just one UOp) + +# TODO: this should be Ops.ASSIGN @dataclass(frozen=True) class Assign: lhs: UOp; rhs: UOp + +# TODO: this should be Ops.DEFINE_VAR @dataclass(frozen=True) class Declare: name: str; dtype: DType + +# this can all be late substitutes @dataclass(frozen=True) class If: branches: tuple[tuple[UOp|None, tuple[Stmt, ...]], ...] @dataclass(frozen=True) class For: var: str; start: UOp; end: UOp; body: tuple[Stmt, ...] @dataclass(frozen=True) class Lambda: name: str; params: tuple[str, ...]; body: tuple[Stmt, ...]|UOp + +# when are these two used? @dataclass(frozen=True) class Break: pass @dataclass(frozen=True) diff --git a/extra/assembly/amd/test/test_pcode_parse.py b/extra/assembly/amd/test/test_pcode_parse.py index 7f6610b2cb..2c6e4a5a30 100644 --- a/extra/assembly/amd/test/test_pcode_parse.py +++ b/extra/assembly/amd/test/test_pcode_parse.py @@ -1,8 +1,9 @@ -import unittest, re, os +import unittest, re, os, random from tinygrad.dtype import dtypes from tinygrad.uop import Ops from tinygrad.uop.ops import UOp from extra.assembly.amd.pcode_parse import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return +from extra.assembly.amd.pcode_transform import parse_transform from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS as RDNA3_PCODE from extra.assembly.amd.autogen.rdna4.str_pcode import PSEUDOCODE_STRINGS as RDNA4_PCODE from extra.assembly.amd.autogen.cdna.str_pcode import PSEUDOCODE_STRINGS as CDNA_PCODE @@ -230,8 +231,15 @@ def _pp(stmt, indent=0) -> str: def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98): ok, fail, match, void_ok, void_bad = 0, 0, 0, 0, 0 errs: dict[str, list[str]] = {} + + # test in random order + triples = [] for cls, ops in pcode_strings.items(): - for op, pc in ops.items(): + for op, pc in ops.items(): triples.append((cls, op, pc)) + random.shuffle(triples) + + if True: + for cls, op, pc in triples: try: ast = parse(pc) ok += 1 @@ -251,13 +259,15 @@ def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98): if DEBUG >= 2: print(f"{'='*60}\n\033[32m{op.name}\033[0m\n{'='*60}") print(pc) - for stmt in ast: print(_pp(stmt)) + if DEBUG >= 3: + ast_pt = parse_transform(pc) + for stmt in ast_pt: print(_pp(stmt)) elif DEBUG: orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()] rend_lines = [l for l in rendered.split('\n') if l.strip()] max_lines = max(len(orig_lines), len(rend_lines)) - print(f"{'='*60}\n{op.name}\n{'='*60}") - w = 50 + print(f"{'='*60}\n\033[31m{op.name}\033[0m\n{'='*60}") + w = 60 for i in range(max_lines): oline = orig_lines[i] if i < len(orig_lines) else '' rline = rend_lines[i] if i < len(rend_lines) else ''