assembly/amd: test_roundtrip for cdna/rdna4 (#14066)

This commit is contained in:
qazal
2026-01-08 07:03:13 -05:00
committed by GitHub
parent 15a056715d
commit 309197bca5

View File

@@ -1,12 +1,17 @@
#!/usr/bin/env python3
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
import unittest, io, sys, re, subprocess, os
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.dsl import Inst
from extra.assembly.amd.asm import asm
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.asm import asm, detect_format
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
# Per-architecture LLVM settings, keyed by the `arch` name the helpers and tests in this file accept.
# Each value is (mcpu, mattr):
#   mcpu  - forwarded as `-mcpu=` to llvm-mc, `--mcpu=` to llvm-objdump, and to HIPCompiler
#   mattr - forwarded as `-mattr=` to llvm-mc
ARCH_CONFIG: dict[str, tuple[str, str]] = {
  'rdna3': ('gfx1100', '+real-true16,+wavefrontsize32'),
  'rdna4': ('gfx1200', '+real-true16,+wavefrontsize32'),
  'cdna': ('gfx942', '+wavefrontsize64'),
}
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
old_stdout = sys.stdout
@@ -30,14 +35,15 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
continue
return results
def compile_asm(instr: str, arch: str = 'rdna3') -> bytes:
  """Compile a single instruction using LLVM.

  Thin wrapper over compile_asm_batch: wraps `instr` in a one-element list and
  returns the first (only) encoding.

  Args:
    instr: one line of assembly text.
    arch: architecture name; must be a key of ARCH_CONFIG ('rdna3' by default).

  Returns:
    The encoded bytes for `instr` as produced by compile_asm_batch.

  Raises:
    KeyError: if `arch` is not a key of ARCH_CONFIG (raised by compile_asm_batch).
    RuntimeError: if the llvm-mc invocation in compile_asm_batch fails.
  """
  # NOTE(review): this span previously held both the pre-commit (mcpu=...) and post-commit (arch=...)
  # signature/return lines; only the post-commit form is kept, since compile_asm_batch now takes `arch`.
  return compile_asm_batch([instr], arch)[0]
def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
def compile_asm_batch(instrs: list[str], arch: str = 'rdna3') -> list[bytes]:
"""Compile multiple instructions with a single llvm-mc call."""
if not instrs: return []
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
mcpu, mattr = ARCH_CONFIG[arch]
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'],
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
encodings = []
@@ -49,15 +55,16 @@ def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
return encodings
def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[str]:
def compile_and_disasm_batch(instrs: list[str], arch: str = 'rdna3') -> list[str]:
"""Compile instructions with LLVM and get LLVM's disassembly."""
import tempfile
if not instrs: return []
mcpu, mattr = ARCH_CONFIG[arch]
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
obj_path = f.name
try:
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-filetype=obj', '-o', obj_path],
input=src, capture_output=True, text=True)
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
@@ -73,6 +80,7 @@ def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[s
class TestTinygradKernelRoundtrip(unittest.TestCase):
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
arch = 'rdna3'
def _test_kernel_roundtrip(self, op_fn):
"""Generate kernel from op_fn, test:
@@ -80,11 +88,14 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
2. asm(disasm()) matches LLVM output
3. our disasm() matches LLVM's disassembly string exactly
"""
arch = self.arch
mcpu, mattr = ARCH_CONFIG[arch]
from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad
from tinygrad.runtime.support.compiler_amd import HIPCompiler
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
compiler = HIPCompiler('gfx1100')
compiler = HIPCompiler(mcpu)
# First pass: decode all instructions and collect info
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
@@ -92,7 +103,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
offset = 0
while offset < len(kernel.code):
remaining = kernel.code[offset:]
fmt = detect_format(remaining)
fmt = detect_format(remaining, arch)
if fmt is None:
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
offset += 4
@@ -129,11 +140,11 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
disasm_test_instrs.append((idx, our_disasm))
# Batch compile for asm test
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], arch)
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
# Batch compile+disasm for disasm comparison test
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs])
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], arch)
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
# Now evaluate results
@@ -184,9 +195,9 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
else:
disasm_skipped += 1
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
print(f"[{arch}] decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
print(f"[{arch}] asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
print(f"[{arch}] disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
@@ -236,5 +247,11 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
# Fused ops
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
# Re-runs every inherited roundtrip test with arch='rdna4' (selects the 'rdna4' entry of ARCH_CONFIG).
# The whole class is skipped for now; see the skip reason.
@unittest.skip("no asm support for RDNA4")
class TestTinygradKernelRoundtripRDNA4(TestTinygradKernelRoundtrip):
  arch = 'rdna4'
# Re-runs every inherited roundtrip test with arch='cdna' (selects the 'cdna' entry of ARCH_CONFIG).
# The whole class is skipped for now; see the skip reason.
@unittest.skip("no asm support for CDNA")
class TestTinygradKernelRoundtripCDNA(TestTinygradKernelRoundtrip):
  arch = 'cdna'
# Script entry point: run all test cases in this module when executed directly.
if __name__ == "__main__":
  unittest.main()