mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
assembly/amd: test_roundtrip for cdna/rdna4 (#14066)
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.asm import asm, detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
|
||||
|
||||
# arch: (mcpu, mattr)
|
||||
ARCH_CONFIG = {
|
||||
'rdna3': ('gfx1100', '+real-true16,+wavefrontsize32'),
|
||||
'rdna4': ('gfx1200', '+real-true16,+wavefrontsize32'),
|
||||
'cdna': ('gfx942', '+wavefrontsize64'),
|
||||
}
|
||||
|
||||
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
|
||||
old_stdout = sys.stdout
|
||||
@@ -30,14 +35,15 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
continue
|
||||
return results
|
||||
|
||||
def compile_asm(instr: str, mcpu: str = 'gfx1100') -> bytes:
|
||||
def compile_asm(instr: str, arch: str = 'rdna3') -> bytes:
|
||||
"""Compile a single instruction using LLVM."""
|
||||
return compile_asm_batch([instr], mcpu)[0]
|
||||
return compile_asm_batch([instr], arch)[0]
|
||||
|
||||
def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
|
||||
def compile_asm_batch(instrs: list[str], arch: str = 'rdna3') -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
mcpu, mattr = ARCH_CONFIG[arch]
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
encodings = []
|
||||
@@ -49,15 +55,16 @@ def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
|
||||
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
|
||||
return encodings
|
||||
|
||||
def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[str]:
|
||||
def compile_and_disasm_batch(instrs: list[str], arch: str = 'rdna3') -> list[str]:
|
||||
"""Compile instructions with LLVM and get LLVM's disassembly."""
|
||||
import tempfile
|
||||
if not instrs: return []
|
||||
mcpu, mattr = ARCH_CONFIG[arch]
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
|
||||
obj_path = f.name
|
||||
try:
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
|
||||
@@ -73,6 +80,7 @@ def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[s
|
||||
|
||||
class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
|
||||
arch = 'rdna3'
|
||||
|
||||
def _test_kernel_roundtrip(self, op_fn):
|
||||
"""Generate kernel from op_fn, test:
|
||||
@@ -80,11 +88,14 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
2. asm(disasm()) matches LLVM output
|
||||
3. our disasm() matches LLVM's disassembly string exactly
|
||||
"""
|
||||
arch = self.arch
|
||||
mcpu, mattr = ARCH_CONFIG[arch]
|
||||
|
||||
from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
compiler = HIPCompiler(mcpu)
|
||||
|
||||
# First pass: decode all instructions and collect info
|
||||
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
|
||||
@@ -92,7 +103,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
offset = 0
|
||||
while offset < len(kernel.code):
|
||||
remaining = kernel.code[offset:]
|
||||
fmt = detect_format(remaining)
|
||||
fmt = detect_format(remaining, arch)
|
||||
if fmt is None:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
@@ -129,11 +140,11 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
disasm_test_instrs.append((idx, our_disasm))
|
||||
|
||||
# Batch compile for asm test
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], arch)
|
||||
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
|
||||
|
||||
# Batch compile+disasm for disasm comparison test
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs])
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], arch)
|
||||
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
|
||||
|
||||
# Now evaluate results
|
||||
@@ -184,9 +195,9 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
|
||||
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
print(f"[{arch}] decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"[{arch}] asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"[{arch}] disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
|
||||
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
|
||||
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
|
||||
@@ -236,5 +247,11 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
# Fused ops
|
||||
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
|
||||
|
||||
@unittest.skip("no asm support for RDNA4")
|
||||
class TestTinygradKernelRoundtripRDNA4(TestTinygradKernelRoundtrip): arch = 'rdna4'
|
||||
|
||||
@unittest.skip("no asm support for CDNA")
|
||||
class TestTinygradKernelRoundtripCDNA(TestTinygradKernelRoundtrip): arch = 'cdna'
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user