mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
* assembly/amd: refactor asm.py * assembly/amd: refactor asm.py to be simpler * multiple fxns * fast * more tests pass * regen * stop decode
202 lines
9.1 KiB
Python
202 lines
9.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Test RDNA3 assembler/disassembler against LLVM test vectors."""
|
|
import unittest, re, subprocess
|
|
from tinygrad.helpers import fetch
|
|
from extra.assembly.amd.autogen.rdna3 import *
|
|
from extra.assembly.amd.asm import asm
|
|
from extra.assembly.amd.test.helpers import get_llvm_mc
|
|
|
|
# Directory of LLVM's AMDGPU MC assembler tests; each vector file below is fetched from here.
# NOTE(review): this points at llvm-project `main`, so the vectors can change upstream without any change here.
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"

# Format info: (filename, format_class, op_enum)
# Key: short format name; TestLLVM gets one test_<key>_asm and one test_<key>_disasm method per entry.
#   filename     - LLVM MC test file under LLVM_BASE holding the instruction/encoding vectors
#   format_class - instruction class used by the disasm test to decode the expected bytes
#   op_enum      - opcode enum the disasm test validates the decoded 'op' against
#                  (entries whose class is VOP3/VOP3SD are validated against VOP3Op/VOP3SDOp instead)
# The asm test only uses `filename`; several keys point at the same file (e.g. vop3/vop3sd).
LLVM_TEST_FILES = {
  # Scalar ALU
  'sop1': ('gfx11_asm_sop1.s', SOP1, SOP1Op),
  'sop2': ('gfx11_asm_sop2.s', SOP2, SOP2Op),
  'sopp': ('gfx11_asm_sopp.s', SOPP, SOPPOp),
  'sopk': ('gfx11_asm_sopk.s', SOPK, SOPKOp),
  'sopc': ('gfx11_asm_sopc.s', SOPC, SOPCOp),
  # Vector ALU
  'vop1': ('gfx11_asm_vop1.s', VOP1, VOP1Op),
  'vop2': ('gfx11_asm_vop2.s', VOP2, VOP2Op),
  'vopc': ('gfx11_asm_vopc.s', VOPC, VOPCOp),
  'vop3': ('gfx11_asm_vop3.s', VOP3, VOP3Op),
  'vop3p': ('gfx11_asm_vop3p.s', VOP3P, VOP3POp),
  'vop3sd': ('gfx11_asm_vop3.s', VOP3SD, VOP3SDOp), # VOP3SD shares file with VOP3
  'vinterp': ('gfx11_asm_vinterp.s', VINTERP, VINTERPOp),
  'vopd': ('gfx11_asm_vopd.s', VOPD, VOPDOp),
  'vopcx': ('gfx11_asm_vopcx.s', VOPC, VOPCOp), # VOPCX uses VOPC format
  # VOP3 promotions (VOP1/VOP2/VOPC promoted to VOP3 encoding)
  'vop3_from_vop1': ('gfx11_asm_vop3_from_vop1.s', VOP3, VOP3Op),
  'vop3_from_vop2': ('gfx11_asm_vop3_from_vop2.s', VOP3, VOP3Op),
  'vop3_from_vopc': ('gfx11_asm_vop3_from_vopc.s', VOP3, VOP3Op),
  'vop3_from_vopcx': ('gfx11_asm_vop3_from_vopcx.s', VOP3, VOP3Op),
  # Memory
  'ds': ('gfx11_asm_ds.s', DS, DSOp),
  'smem': ('gfx11_asm_smem.s', SMEM, SMEMOp),
  'flat': ('gfx11_asm_flat.s', FLAT, FLATOp),
  'mubuf': ('gfx11_asm_mubuf.s', MUBUF, MUBUFOp),
  'mtbuf': ('gfx11_asm_mtbuf.s', MTBUF, MTBUFOp),
  'mimg': ('gfx11_asm_mimg.s', MIMG, MIMGOp),
  # WMMA (matrix multiply)
  'wmma': ('gfx11_asm_wmma.s', VOP3P, VOP3POp),
  # Additional features
  'vop3_features': ('gfx11_asm_vop3_features.s', VOP3, VOP3Op),
  'vop3p_features': ('gfx11_asm_vop3p_features.s', VOP3P, VOP3POp),
  'vopd_features': ('gfx11_asm_vopd_features.s', VOPD, VOPDOp),
  # Alias files (alternative mnemonics)
  'vop3_alias': ('gfx11_asm_vop3_alias.s', VOP3, VOP3Op),
  'vop3p_alias': ('gfx11_asm_vop3p_alias.s', VOP3P, VOP3POp),
  'vopc_alias': ('gfx11_asm_vopc_alias.s', VOPC, VOPCOp),
  'vopcx_alias': ('gfx11_asm_vopcx_alias.s', VOPC, VOPCOp),
  'vinterp_alias': ('gfx11_asm_vinterp_alias.s', VINTERP, VINTERPOp),
  'smem_alias': ('gfx11_asm_smem_alias.s', SMEM, SMEMOp),
  'mubuf_alias': ('gfx11_asm_mubuf_alias.s', MUBUF, MUBUFOp),
  'mtbuf_alias': ('gfx11_asm_mtbuf_alias.s', MTBUF, MTBUFOp),
}
|
|
|
|
def parse_llvm_tests(text: str) -> list[tuple[str, bytes]]:
  """Parse LLVM test format into (asm, expected_bytes) pairs.

  Every line that is not blank, a comment ('//' or ';') or a directive ('.') is an instruction.
  Its text is everything before a trailing '//' comment. Its expected encoding comes from the
  first GFX11/W32/W64 check found on the instruction line itself or on one of the next two lines.
  Two check forms are accepted:
    "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
    "// GFX11: [0x01,0x02,...]"   (used by DS and older files)
  The search stops at the next instruction line. An instruction with no encoding check (e.g. one
  that only has an error check) is therefore dropped instead of borrowing the following
  instruction's encoding. Instructions whose byte list is empty or not valid hex are omitted.
  """
  # Format 1: "// GFX11: v_foo ... ; encoding: [0x01,0x02,...]"
  enc_re = re.compile(r'(?:GFX11|W32|W64)[^:]*:.*?encoding:\s*\[(.*?)\]')
  # Format 2: "// GFX11: [0x01,0x02,...]" (used by DS, older files)
  bare_re = re.compile(r'(?:GFX11|W32|W64)[^:]*:\s*\[(0x[0-9a-fA-F,x\s]+)\]')
  lines = text.split('\n')
  stripped = [ln.strip() for ln in lines]
  is_instr = [bool(s) and not s.startswith(('//', '.', ';')) for s in stripped]
  tests: list[tuple[str, bytes]] = []
  for i, line in enumerate(stripped):
    if not is_instr[i]: continue
    # is_instr guarantees a non-empty prefix before any trailing '//' comment
    asm_text = line.split('//')[0].strip()
    # Match GFX11, W32, or W64 encodings (all valid for gfx11) on this line or the next two
    for j in range(i, min(i + 3, len(lines))):
      # checks after a later instruction belong to that instruction: never borrow its encoding
      if j > i and is_instr[j]: break
      if not (m := enc_re.search(lines[j]) or bare_re.search(lines[j])): continue
      hex_bytes = m.group(1).replace('0x', '').replace(',', '').replace(' ', '')
      if hex_bytes:
        # best effort: a malformed byte list just drops this vector
        try: tests.append((asm_text, bytes.fromhex(hex_bytes)))
        except ValueError: pass
      break
  return tests
|
|
|
|
def try_assemble(text: str):
  """Try to assemble instruction text, return bytes or None on failure.

  Any Exception raised by asm() or to_bytes() means "the assembler cannot handle this input" and
  yields None; callers count that as skipped. Only Exception subclasses are caught, so
  KeyboardInterrupt/SystemExit still stop the test run.
  """
  try: return asm(text).to_bytes()
  except Exception: return None
|
|
|
|
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
  """Assemble all of `instrs` in a single llvm-mc (gfx1100, wave32, real-true16) invocation.

  Returns the encoded bytes of each instruction, in input order; an empty input returns [] without
  running llvm-mc. Raises RuntimeError when llvm-mc exits non-zero, or when the number of
  encodings it prints differs from the number of instructions given.
  """
  if not instrs: return []
  source = ".text\n" + "".join(f"{ins}\n" for ins in instrs)
  cmd = [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding']
  proc = subprocess.run(cmd, input=source, capture_output=True, text=True, timeout=30)
  if proc.returncode != 0:
    raise RuntimeError(f"llvm-mc batch failed: {proc.stderr.strip()}")
  encodings: list[bytes] = []
  for out_line in proc.stdout.split('\n'):
    if 'encoding:' in out_line:
      enc = out_line.split('encoding:')[1].strip()
      # only a bracketed byte list, e.g. "[0x00,0x00,0x80,0xbf]", counts as an encoding
      if enc.startswith('[') and enc.endswith(']'):
        hex_digits = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
        encodings.append(bytes.fromhex(hex_digits))
  if len(encodings) != len(instrs):
    raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
  return encodings
|
|
|
|
class TestLLVM(unittest.TestCase):
  """Run the RDNA3 assembler and disassembler against every LLVM MC test vector file.

  The per-format test methods are attached to this class after its definition. They read the
  vectors that setUpClass stores in `tests`.
  """
  # format name (key of LLVM_TEST_FILES) -> list of (asm_text, expected_bytes) vectors
  tests: dict[str, list[tuple[str, bytes]]] = {}

  @classmethod
  def setUpClass(cls):
    """Fetch and parse every LLVM test file once, before any test in the class runs."""
    for fmt_name, (filename, _fmt_cls, _op_enum) in LLVM_TEST_FILES.items():
      url = f"{LLVM_BASE}/{filename}"
      try:
        vectors = parse_llvm_tests(fetch(url).read_bytes().decode('utf-8', errors='ignore'))
      except Exception as e:
        # best effort: a file that cannot be fetched/parsed leaves that format with no vectors
        print(f"Warning: couldn't fetch {filename}: {e}")
        vectors = []
      cls.tests[fmt_name] = vectors
|
|
|
|
# Generate test methods dynamically for each format
|
|
def _make_asm_test(name):
  """Build a test method checking the assembler against the LLVM vectors of format `name`.

  Each (asm_text, expected) vector is assembled with try_assemble:
    - vectors the assembler rejects (None) are counted as skipped;
    - matching bytes count as passed;
    - any other result is a failure.
  Like the disasm test, the first 10 failures are printed with expected/got bytes. The test
  fails if there is at least one failure.
  """
  def test(self):
    passed, skipped = 0, 0
    failures: list[str] = []
    for asm_text, expected in self.tests.get(name, []):
      result = try_assemble(asm_text)
      if result is None: skipped += 1
      elif result == expected: passed += 1
      else: failures.append(f"'{asm_text}': expected={expected.hex()} got={result.hex()}")
    failed = len(failures)
    print(f"{name.upper()} asm: {passed} passed, {failed} failed, {skipped} skipped")
    if failures: print(" " + "\n ".join(failures[:10]))
    self.assertEqual(failed, 0)
  return test
|
|
|
|
def _make_disasm_test(name):
  """Build a disassembler test method for the LLVM vectors of format `name`.

  Each (asm_text, data) vector goes through these steps:
    1. decode the bytes with the format's instruction class;
    2. validate the opcode against its enum;
    3. check that re-encoding reproduces the bytes;
    4. disassemble to text.
  All disassembly strings are then assembled in one llvm-mc call (compile_asm_batch). A vector
  passes when llvm-mc reproduces its original bytes. The first 10 failures are printed, and the
  test fails if any vector failed.
  """
  def test(self):
    _, fmt_cls, op_enum = LLVM_TEST_FILES[name]
    # VOP3SD opcodes that share encoding with VOP3 (only for vop3sd test, not vopc promotions)
    vop3sd_opcodes = {288, 289, 290, 764, 765, 766, 767, 768, 769, 770}
    # vectors from the VOPC-promotion files are always decoded as plain VOP3
    is_vopc_promotion = name in ('vop3_from_vopc', 'vop3_from_vopcx')
    # per-format opcodes whose vectors are skipped (counted in `skipped`)
    # NOTE(review): presumably opcodes missing from the ISA docs / autogen tables -- confirm
    undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}

    # First pass: decode all instructions and collect disasm strings
    # `error` is set when decoding/validation failed; otherwise `disasm_str` holds the decoded text
    to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
    skipped = 0
    for asm_text, data in self.tests.get(name, []):
      # vectors longer than the format's base size are dropped silently (NOT counted in `skipped`)
      # NOTE(review): presumably these carry trailing literal dwords -- confirm
      if len(data) > fmt_cls._size(): continue
      # NOTE(review): this decode is outside the try below, so an exception here aborts the whole
      # test instead of being recorded as a per-vector failure
      temp_inst = fmt_cls.from_bytes(data)
      temp_op = temp_inst._values.get('op', 0)
      # field values may be wrapped objects exposing `.val`; normalize to the raw value
      temp_op = temp_op.val if hasattr(temp_op, 'val') else temp_op
      if temp_op in undocumented.get(name, set()): skipped += 1; continue
      if name == 'sopp':
        simm16 = temp_inst._values.get('simm16', 0)
        simm16 = simm16.val if hasattr(simm16, 'val') else simm16
        # SOPP opcodes whose vectors are skipped when simm16 is non-zero
        # NOTE(review): presumably the disassembler prints no immediate for these, so non-zero
        # simm16 could not round-trip -- confirm
        sopp_no_imm = {48, 54, 53, 55, 60, 61, 62}
        if temp_op in sopp_no_imm and simm16 != 0: skipped += 1; continue
      try:
        if fmt_cls.__name__ in ('VOP3', 'VOP3SD'):
          # VOP3/VOP3SD entries: decode as VOP3 only to read the opcode, then choose the real class
          temp = VOP3.from_bytes(data)
          op_val = temp._values.get('op', 0)
          op_val = op_val.val if hasattr(op_val, 'val') else op_val
          is_vop3sd = (op_val in vop3sd_opcodes) and not is_vopc_promotion
          decoded = VOP3SD.from_bytes(data) if is_vop3sd else VOP3.from_bytes(data)
          # enum construction only validates the opcode; presumably it raises for unknown values,
          # which the except below records as a failure
          if is_vop3sd: VOP3SDOp(op_val)
          else: VOP3Op(op_val)
        else:
          decoded = fmt_cls.from_bytes(data)
          op_val = decoded._values.get('op', 0)
          op_val = op_val.val if hasattr(op_val, 'val') else op_val
          op_enum(op_val)
        # re-encoding must reproduce the vector's bytes (compared over the first len(data) bytes)
        if decoded.to_bytes()[:len(data)] != data:
          to_test.append((asm_text, data, None, "decode roundtrip failed"))
          continue
        to_test.append((asm_text, data, decoded.disasm(), None))
      except Exception as e:
        to_test.append((asm_text, data, None, f"exception: {e}"))

    # Batch compile all disasm strings with single llvm-mc call
    # NOTE: compile_asm_batch raises on a non-zero llvm-mc exit or an encoding-count mismatch,
    # which errors the whole test rather than failing individual vectors
    disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None] # (index into to_test, text)
    llvm_results = compile_asm_batch([s for _, s in disasm_strs]) if disasm_strs else []
    # index into to_test -> bytes llvm-mc produced from that vector's disassembly
    llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}

    # Match results back
    passed, failed = 0, 0
    failures: list[str] = []
    for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
      if error:
        failed += 1; failures.append(f"{error} for {data.hex()}")
      elif disasm_str is not None and idx in llvm_map:
        llvm_bytes = llvm_map[idx]
        # NOTE(review): compile_asm_batch only returns bytes objects, so these None checks look redundant
        if llvm_bytes is not None and llvm_bytes == data: passed += 1
        elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")

    print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
    if failures[:10]: print(" " + "\n ".join(failures[:10]))
    self.assertEqual(failed, 0)
  return test
|
|
|
|
# Attach test_<name>_asm and test_<name>_disasm to TestLLVM for every LLVM vector file.
for name in LLVM_TEST_FILES:
  for suffix, factory in (('asm', _make_asm_test), ('disasm', _make_disasm_test)):
    setattr(TestLLVM, f'test_{name}_{suffix}', factory(name))

if __name__ == "__main__":
  unittest.main()
|