mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
comments
This commit is contained in:
@@ -9,6 +9,9 @@ Test with `PYTHONPATH="." pytest -n12 extra/assembly/amd/`
|
||||
* asm.py -- an asm/disasm function to transform to and from AMD assembly syntax
|
||||
* emu.py -- an emulator for RDNA that runs in tinygrad with `AMD=1 MOCKGPU=1 PYTHON_REMU=1`
|
||||
|
||||
* pcode_parse.py -- Parse the psuedocode from strings into ill-formed UOps
|
||||
* pcode_transform.py -- Transform the ill-formed UOps into correctly formed and simplified UOps
|
||||
|
||||
The code should be as readable and deduplicated as possible. asm and emu shouldn't be required for dsl.
|
||||
|
||||
The autogen folder is autogenerated from the AMD PDFs with `python3 -m extra.assembly.amd.pdf --arch all`
|
||||
@@ -36,4 +39,3 @@ IMPORTANT: if a test is failing in the emulator, it's an instruction bug. Use DE
|
||||
|
||||
Currently, only RDNA3 is well supported, but when finished, this will support RDNA3+RDNA4+CDNA in ~2000 lines.
|
||||
Get line count with `cloc --by-file extra/assembly/amd/*.py`
|
||||
|
||||
|
||||
@@ -94,16 +94,26 @@ def _infer_fn_dtype(name: str, srcs: tuple[UOp, ...]) -> DType:
|
||||
return dtypes.void
|
||||
|
||||
# Statement types (control flow, not expressions)
|
||||
|
||||
# these are UStatement (just one UOp)
|
||||
|
||||
# TODO: this should be Ops.ASSIGN
|
||||
@dataclass(frozen=True)
|
||||
class Assign: lhs: UOp; rhs: UOp
|
||||
|
||||
# TODO: this should be Ops.DEFINE_VAR
|
||||
@dataclass(frozen=True)
|
||||
class Declare: name: str; dtype: DType
|
||||
|
||||
# this can all be late substitutes
|
||||
@dataclass(frozen=True)
|
||||
class If: branches: tuple[tuple[UOp|None, tuple[Stmt, ...]], ...]
|
||||
@dataclass(frozen=True)
|
||||
class For: var: str; start: UOp; end: UOp; body: tuple[Stmt, ...]
|
||||
@dataclass(frozen=True)
|
||||
class Lambda: name: str; params: tuple[str, ...]; body: tuple[Stmt, ...]|UOp
|
||||
|
||||
# when are these two used?
|
||||
@dataclass(frozen=True)
|
||||
class Break: pass
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import unittest, re, os
|
||||
import unittest, re, os, random
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop import Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from extra.assembly.amd.pcode_parse import parse, _BINOPS, _QDTYPES, Assign, Declare, If, For, Lambda, Break, Return
|
||||
from extra.assembly.amd.pcode_transform import parse_transform
|
||||
from extra.assembly.amd.autogen.rdna3.str_pcode import PSEUDOCODE_STRINGS as RDNA3_PCODE
|
||||
from extra.assembly.amd.autogen.rdna4.str_pcode import PSEUDOCODE_STRINGS as RDNA4_PCODE
|
||||
from extra.assembly.amd.autogen.cdna.str_pcode import PSEUDOCODE_STRINGS as CDNA_PCODE
|
||||
@@ -230,8 +231,15 @@ def _pp(stmt, indent=0) -> str:
|
||||
def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||
ok, fail, match, void_ok, void_bad = 0, 0, 0, 0, 0
|
||||
errs: dict[str, list[str]] = {}
|
||||
|
||||
# test in random order
|
||||
triples = []
|
||||
for cls, ops in pcode_strings.items():
|
||||
for op, pc in ops.items():
|
||||
for op, pc in ops.items(): triples.append((cls, op, pc))
|
||||
random.shuffle(triples)
|
||||
|
||||
if True:
|
||||
for cls, op, pc in triples:
|
||||
try:
|
||||
ast = parse(pc)
|
||||
ok += 1
|
||||
@@ -251,13 +259,15 @@ def _test_arch(test, pcode_strings, min_parse=98, min_roundtrip=98):
|
||||
if DEBUG >= 2:
|
||||
print(f"{'='*60}\n\033[32m{op.name}\033[0m\n{'='*60}")
|
||||
print(pc)
|
||||
for stmt in ast: print(_pp(stmt))
|
||||
if DEBUG >= 3:
|
||||
ast_pt = parse_transform(pc)
|
||||
for stmt in ast_pt: print(_pp(stmt))
|
||||
elif DEBUG:
|
||||
orig_lines = [l for l in _norm(pc, keep_structure=True).split('\n') if l.strip()]
|
||||
rend_lines = [l for l in rendered.split('\n') if l.strip()]
|
||||
max_lines = max(len(orig_lines), len(rend_lines))
|
||||
print(f"{'='*60}\n{op.name}\n{'='*60}")
|
||||
w = 50
|
||||
print(f"{'='*60}\n\033[31m{op.name}\033[0m\n{'='*60}")
|
||||
w = 60
|
||||
for i in range(max_lines):
|
||||
oline = orig_lines[i] if i < len(orig_lines) else ''
|
||||
rline = rend_lines[i] if i < len(rend_lines) else ''
|
||||
|
||||
Reference in New Issue
Block a user