mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
133 lines
6.4 KiB
Python
133 lines
6.4 KiB
Python
"""Shared test helpers for AMD tests."""
|
|
import ctypes
|
|
from tinygrad.helpers import unwrap
|
|
from tinygrad.runtime.autogen import llvm
|
|
from tinygrad.runtime.support.elf import elf_loader
|
|
|
|
ARCH_TO_TARGET:dict[str, list[str]] = {
|
|
"rdna3":["gfx1100", "gfx1151"],
|
|
"rdna4":["gfx1200", "gfx1201"],
|
|
"cdna":["gfx950", "gfx942"],
|
|
}
|
|
|
|
TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items() for t in targets}
|
|
|
|
_DPP16_RANGE_OPS = {0x100: "row_shl", 0x110: "row_shr", 0x120: "row_ror", 0x150: "row_newbcast", 0x160: "row_share", 0x170: "row_xmask"}
|
|
_DPP16_EXACT_OPS = {0x130: ("wave_shl", 1), 0x134: ("wave_rol", 1), 0x138: ("wave_shr", 1), 0x13c: ("wave_ror", 1),
|
|
0x140: ("row_mirror", 0), 0x141: ("row_half_mirror", 0), 0x142: ("row_bcast", 15), 0x143: ("row_bcast", 31)}
|
|
|
|
def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0]
|
|
|
|
def decode_dpp16(dpp: int) -> tuple[str, int | tuple[int, int, int, int]]:
|
|
"""Decode a DPP16 control word into a symbolic operation and argument."""
|
|
if dpp < 0x100: return "quad_perm", ((dpp >> 0) & 0x3, (dpp >> 2) & 0x3, (dpp >> 4) & 0x3, (dpp >> 6) & 0x3)
|
|
if dpp in _DPP16_EXACT_OPS: return _DPP16_EXACT_OPS[dpp]
|
|
if (base := dpp & 0x1f0) in _DPP16_RANGE_OPS: return _DPP16_RANGE_OPS[base], dpp & 0xf
|
|
return "dpp", dpp
|
|
|
|
def get_mattr(arch:str) -> str:
|
|
return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch]
|
|
|
|
# LLVM in-process assembler/disassembler (replaces llvm-mc and llvm-objdump subprocesses)
|
|
_SENTINEL = b'\xde\xad\xbe\xef'
|
|
_SENTINEL_ASM = '.byte 0xde, 0xad, 0xbe, 0xef'
|
|
|
|
def _cerr(): return ctypes.pointer(ctypes.pointer(ctypes.c_char()))
|
|
def _expect(x, err, ret=None):
|
|
if x: raise RuntimeError(unwrap(ctypes.cast(err.contents, ctypes.c_char_p).value).decode() if not isinstance(err, str) else err)
|
|
return ret
|
|
|
|
def _init_llvm():
|
|
for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter', 'Disassembler']:
|
|
getattr(llvm, f'LLVMInitializeAMDGPU{component}')()
|
|
|
|
def _create_target_machine(mcpu:str, mattr:str) -> llvm.LLVMTargetMachineRef:
|
|
target = _expect(llvm.LLVMGetTargetFromTriple(b'amdgcn-amd-amdhsa', ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=_cerr()), err, tgt)
|
|
return llvm.LLVMCreateTargetMachine(target, b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(),
|
|
llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocDefault, llvm.LLVMCodeModelDefault)
|
|
|
|
def _emit_obj(asm_text:str, mcpu:str, mattr:str, diag_errors:list[str]|None=None) -> bytes:
|
|
"""Assemble raw asm text into an ELF object using LLVM in-process."""
|
|
_init_llvm()
|
|
tm = _create_target_machine(mcpu, mattr)
|
|
ctx = llvm.LLVMContextCreate()
|
|
try:
|
|
errors = diag_errors if diag_errors is not None else []
|
|
@llvm.LLVMDiagnosticHandler
|
|
def handle_diag(diag_ref, _arg):
|
|
if llvm.LLVMGetDiagInfoSeverity(diag_ref) == llvm.LLVMDSError:
|
|
errors.append(ctypes.string_at(llvm.LLVMGetDiagInfoDescription(diag_ref)).decode())
|
|
llvm.LLVMContextSetDiagnosticHandler(ctx, handle_diag, None)
|
|
mod = llvm.LLVMModuleCreateWithNameInContext(b'asm', ctx)
|
|
llvm.LLVMSetTarget(mod, b'amdgcn-amd-amdhsa')
|
|
asm_bytes = asm_text.encode()
|
|
llvm.LLVMSetModuleInlineAsm2(mod, asm_bytes, len(asm_bytes))
|
|
buf = llvm.LLVMMemoryBufferRef()
|
|
_expect(llvm.LLVMTargetMachineEmitToMemoryBuffer(tm, mod, llvm.LLVMObjectFile, err:=_cerr(), ctypes.pointer(buf)), err)
|
|
obj = ctypes.string_at(llvm.LLVMGetBufferStart(buf), llvm.LLVMGetBufferSize(buf))
|
|
llvm.LLVMDisposeMemoryBuffer(buf)
|
|
llvm.LLVMDisposeModule(mod)
|
|
return obj
|
|
finally:
|
|
llvm.LLVMContextDispose(ctx)
|
|
llvm.LLVMDisposeTargetMachine(tm)
|
|
|
|
def _extract_text(obj:bytes) -> bytes:
|
|
"""Extract .text section from ELF object bytes."""
|
|
return next(s.content for s in elf_loader(obj)[1] if s.name == ".text")
|
|
|
|
def llvm_assemble(instrs:list[str], mcpu:str, mattr:str) -> list[bytes]:
|
|
"""Assemble instructions in one LLVM emission, return per-instruction bytes."""
|
|
if not instrs: return []
|
|
parts = []
|
|
for instr in instrs:
|
|
parts.append(instr)
|
|
parts.append(_SENTINEL_ASM)
|
|
text = _extract_text(_emit_obj('.text\n' + '\n'.join(parts) + '\n', mcpu, mattr))
|
|
results, start = [], 0
|
|
for _ in instrs:
|
|
idx = text.find(_SENTINEL, start)
|
|
assert idx != -1, "sentinel not found in .text section"
|
|
results.append(bytes(text[start:idx]))
|
|
start = idx + len(_SENTINEL)
|
|
return results
|
|
|
|
def llvm_disasm(code:bytes, mcpu:str, mattr:str) -> list[str]:
|
|
"""Disassemble raw bytes into instruction strings using LLVM."""
|
|
_init_llvm()
|
|
dc = llvm.LLVMCreateDisasmCPUFeatures(b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(), None, 0,
|
|
llvm.LLVMOpInfoCallback(0), llvm.LLVMSymbolLookupCallback(0))
|
|
if not dc: raise RuntimeError(f"failed to create disasm context for {mcpu}")
|
|
llvm.LLVMSetDisasmOptions(dc, 2 | 4) # PrintImmHex | AsmPrinterVariant
|
|
try:
|
|
buf = ctypes.create_string_buffer(256)
|
|
arr = (ctypes.c_uint8 * len(code)).from_buffer_copy(code)
|
|
results, offset = [], 0
|
|
while offset < len(code):
|
|
size = llvm.LLVMDisasmInstruction(dc, ctypes.cast(ctypes.addressof(arr) + offset, ctypes.POINTER(ctypes.c_uint8)),
|
|
len(code) - offset, 0, buf, 256)
|
|
if size == 0: break
|
|
results.append(buf.value.decode().strip())
|
|
offset += size
|
|
return results
|
|
finally:
|
|
llvm.LLVMDisasmDispose(dc)
|
|
|
|
def llvm_filter_valid_asm(tests:list[tuple[str, bytes]], mcpu:str, mattr:str) -> list[tuple[str, bytes]]:
|
|
"""Filter out tests where original ASM isn't valid on target, and where LLVM roundtrip doesn't match."""
|
|
if not tests: return []
|
|
# Assemble all instructions at once with sentinels and diagnostic handler to detect failures
|
|
parts, diag_errors = [], [] # type: ignore[var-annotated]
|
|
for asm, _ in tests:
|
|
parts.append(asm)
|
|
parts.append(_SENTINEL_ASM)
|
|
text = _extract_text(_emit_obj('.text\n' + '\n'.join(parts) + '\n', mcpu, mattr, diag_errors))
|
|
results, start = [], 0
|
|
for _ in tests:
|
|
idx = text.find(_SENTINEL, start)
|
|
assert idx != -1, "sentinel not found in .text section"
|
|
results.append(bytes(text[start:idx]))
|
|
start = idx + len(_SENTINEL)
|
|
# Invalid instructions produce 0 bytes; also filter where LLVM roundtrip doesn't match original
|
|
return [(asm, data) for (asm, data), chunk in zip(tests, results) if len(chunk) > 0 and chunk == data]
|