diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 298cc83e2f..7d66d4aac8 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -56,7 +56,7 @@ jobs: - name: Run pytest -nauto run: | source /tmp/tinygrad_pytest_ci/bin/activate - pytest -nauto --ignore test/amd/ --durations=20 + pytest -nauto --durations=20 testmacbenchmark: name: Mac Benchmark diff --git a/test/amd/helpers.py b/test/amd/helpers.py index 143f16bd07..15fc3fc7dc 100644 --- a/test/amd/helpers.py +++ b/test/amd/helpers.py @@ -1,6 +1,9 @@ -"""Shared test helpers for RDNA3 tests.""" -import shutil +"""Shared test helpers for AMD tests.""" +import ctypes from dataclasses import dataclass +from tinygrad.helpers import unwrap +from tinygrad.runtime.autogen import llvm +from tinygrad.runtime.support.elf import elf_loader @dataclass class KernelInfo: @@ -11,19 +14,6 @@ class KernelInfo: buf_idxs: list[int] # indices into shared buffer pool buf_sizes: list[int] # sizes for each buffer index -# LLVM tool detection (shared across test files) -def get_llvm_mc(): - """Find llvm-mc executable, preferring newer versions.""" - for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']: - if shutil.which(p): return p - raise FileNotFoundError("llvm-mc not found") - -def get_llvm_objdump(): - """Find llvm-objdump executable, preferring newer versions.""" - for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']: - if shutil.which(p): return p - raise FileNotFoundError("llvm-objdump not found") - ARCH_TO_TARGET:dict[str, list[str]] = { "rdna3":["gfx1100"], "rdna4":["gfx1200"], @@ -35,4 +25,107 @@ TARGET_TO_ARCH:dict[str, str] = {t:arch for arch,targets in ARCH_TO_TARGET.items def get_target(arch:str) -> str: return ARCH_TO_TARGET[arch][0] def get_mattr(arch:str) -> str: - return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch] \ No newline at end of file + return {"rdna3":"+real-true16,+wavefrontsize32", "rdna4":"+real-true16,+wavefrontsize32", "cdna":"+wavefrontsize64"}[arch] + +# LLVM in-process assembler/disassembler (replaces llvm-mc and llvm-objdump subprocesses) +_SENTINEL = b'\xde\xad\xbe\xef' +_SENTINEL_ASM = '.byte 0xde, 0xad, 0xbe, 0xef' + +def _cerr(): return ctypes.pointer(ctypes.pointer(ctypes.c_char())) +def _expect(x, err, ret=None): + if x: raise RuntimeError(unwrap(ctypes.cast(err.contents, ctypes.c_char_p).value).decode() if not isinstance(err, str) else err) + return ret + +def _init_llvm(): + for component in ['Target', 'TargetInfo', 'TargetMC', 'AsmParser', 'AsmPrinter', 'Disassembler']: + getattr(llvm, f'LLVMInitializeAMDGPU{component}')() + +def _create_target_machine(mcpu:str, mattr:str) -> llvm.LLVMTargetMachineRef: + target = _expect(llvm.LLVMGetTargetFromTriple(b'amdgcn-amd-amdhsa', ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=_cerr()), err, tgt) + return llvm.LLVMCreateTargetMachine(target, b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(), + llvm.LLVMCodeGenLevelDefault, llvm.LLVMRelocDefault, llvm.LLVMCodeModelDefault) + +def _emit_obj(asm_text:str, mcpu:str, mattr:str, diag_errors:list[str]|None=None) -> bytes: + """Assemble raw asm text into an ELF object using LLVM in-process.""" + _init_llvm() + tm = _create_target_machine(mcpu, mattr) + ctx = llvm.LLVMContextCreate() + try: + errors = diag_errors if diag_errors is not None else [] + @llvm.LLVMDiagnosticHandler + def handle_diag(diag_ref, _arg): + if llvm.LLVMGetDiagInfoSeverity(diag_ref) == llvm.LLVMDSError: + errors.append(ctypes.string_at(llvm.LLVMGetDiagInfoDescription(diag_ref)).decode()) + llvm.LLVMContextSetDiagnosticHandler(ctx, handle_diag, None) + mod = llvm.LLVMModuleCreateWithNameInContext(b'asm', ctx) + llvm.LLVMSetTarget(mod, b'amdgcn-amd-amdhsa') + asm_bytes = asm_text.encode() + llvm.LLVMSetModuleInlineAsm2(mod, asm_bytes, len(asm_bytes)) + buf = llvm.LLVMMemoryBufferRef() + _expect(llvm.LLVMTargetMachineEmitToMemoryBuffer(tm, mod, llvm.LLVMObjectFile, err:=_cerr(), ctypes.pointer(buf)), err) + obj = ctypes.string_at(llvm.LLVMGetBufferStart(buf), llvm.LLVMGetBufferSize(buf)) + llvm.LLVMDisposeMemoryBuffer(buf) + llvm.LLVMDisposeModule(mod) + return obj + finally: + llvm.LLVMContextDispose(ctx) + llvm.LLVMDisposeTargetMachine(tm) + +def _extract_text(obj:bytes) -> bytes: + """Extract .text section from ELF object bytes.""" + return next(s.content for s in elf_loader(obj)[1] if s.name == ".text") + +def llvm_assemble(instrs:list[str], mcpu:str, mattr:str) -> list[bytes]: + """Assemble instructions in one LLVM emission, return per-instruction bytes.""" + if not instrs: return [] + parts = [] + for instr in instrs: + parts.append(instr) + parts.append(_SENTINEL_ASM) + text = _extract_text(_emit_obj('.text\n' + '\n'.join(parts) + '\n', mcpu, mattr)) + results, start = [], 0 + for _ in instrs: + idx = text.find(_SENTINEL, start) + assert idx != -1, "sentinel not found in .text section" + results.append(bytes(text[start:idx])) + start = idx + len(_SENTINEL) + return results + +def llvm_disasm(code:bytes, mcpu:str, mattr:str) -> list[str]: + """Disassemble raw bytes into instruction strings using LLVM.""" + _init_llvm() + dc = llvm.LLVMCreateDisasmCPUFeatures(b'amdgcn-amd-amdhsa', mcpu.encode(), mattr.encode(), None, 0, + llvm.LLVMOpInfoCallback(0), llvm.LLVMSymbolLookupCallback(0)) + if not dc: raise RuntimeError(f"failed to create disasm context for {mcpu}") + llvm.LLVMSetDisasmOptions(dc, 2 | 4) # PrintImmHex | AsmPrinterVariant + try: + buf = ctypes.create_string_buffer(256) + arr = (ctypes.c_uint8 * len(code)).from_buffer_copy(code) + results, offset = [], 0 + while offset < len(code): + size = llvm.LLVMDisasmInstruction(dc, ctypes.cast(ctypes.addressof(arr) + offset, ctypes.POINTER(ctypes.c_uint8)), + len(code) - offset, 0, buf, 256) + if size == 0: break + results.append(buf.value.decode().strip()) + offset += size + return results + finally: + llvm.LLVMDisasmDispose(dc) + +def llvm_filter_valid_asm(tests:list[tuple[str, bytes]], mcpu:str, mattr:str) -> list[tuple[str, bytes]]: + """Filter out tests where original ASM isn't valid on target, and where LLVM roundtrip doesn't match.""" + if not tests: return [] + # Assemble all instructions at once with sentinels and diagnostic handler to detect failures + parts, diag_errors = [], [] # type: ignore[var-annotated] + for asm, _ in tests: + parts.append(asm) + parts.append(_SENTINEL_ASM) + text = _extract_text(_emit_obj('.text\n' + '\n'.join(parts) + '\n', mcpu, mattr, diag_errors)) + results, start = [], 0 + for _ in tests: + idx = text.find(_SENTINEL, start) + assert idx != -1, "sentinel not found in .text section" + results.append(bytes(text[start:idx])) + start = idx + len(_SENTINEL) + # Invalid instructions produce 0 bytes; also filter where LLVM roundtrip doesn't match original + return [(asm, data) for (asm, data), chunk in zip(tests, results) if len(chunk) > 0 and chunk == data] diff --git a/test/amd/test_integration.py b/test/amd/test_integration.py index 970ef3469a..880ce34c84 100644 --- a/test/amd/test_integration.py +++ b/test/amd/test_integration.py @@ -1,55 +1,23 @@ #!/usr/bin/env python3 -"""Integration test: round-trip RDNA3 assembly through AMD toolchain.""" -import unittest, io, sys +"""Integration test: round-trip RDNA3 assembly through LLVM toolchain.""" +import unittest from tinygrad.runtime.autogen.amd.rdna3.ins import * +from test.amd.helpers import llvm_assemble, llvm_disasm def waitcnt(vmcnt: int = 0x3f, expcnt: int = 0x7, lgkmcnt: int = 0x3f) -> int: return (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10) -def disassemble(lib: bytes, arch: str = "gfx1100") -> str: - """Disassemble ELF binary using tinygrad's compiler, return raw output.""" - from tinygrad.runtime.support.compiler_amd import HIPCompiler - old_stdout = sys.stdout - sys.stdout = io.StringIO() - HIPCompiler(arch).disassemble(lib) - output = sys.stdout.getvalue() - sys.stdout = old_stdout - return output - -def parse_disassembly(raw: str) -> list[str]: - """Parse disassembly output to list of instruction mnemonics.""" - lines = [] - for line in raw.splitlines(): - if line.startswith('\t'): - instr = line.split('//')[0].strip() - if instr: lines.append(instr) - return lines - -def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[str]: - """Assemble instructions with our DSL, then disassemble with AMD toolchain.""" - from tinygrad.runtime.support.compiler_amd import HIPCompiler - - # Generate bytes from our DSL +def assemble_and_disassemble(instructions: list, mcpu: str = "gfx1100", mattr: str = "+real-true16,+wavefrontsize32") -> list[str]: + """Assemble instructions with our DSL, then disassemble with LLVM.""" code_bytes = b''.join(inst.to_bytes() for inst in instructions) - - # Wrap in minimal ELF-compatible assembly with .byte directives - byte_str = ', '.join(f'0x{b:02x}' for b in code_bytes) - asm_src = f".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n.byte {byte_str}\n" - - # Assemble with AMD COMGR and disassemble - lib = HIPCompiler(arch).compile(asm_src) - return parse_disassembly(disassemble(lib, arch)) + return llvm_disasm(code_bytes, mcpu, mattr) class TestIntegration(unittest.TestCase): """Test our DSL output matches LLVM disassembly.""" def test_simple_sop1(self): """Test SOP1 instructions round-trip.""" - instructions = [ - s_mov_b32(s[0], s[1]), - s_mov_b32(s[2], 0), - s_not_b32(s[3], s[4]), - ] + instructions = [s_mov_b32(s[0], s[1]), s_mov_b32(s[2], 0), s_not_b32(s[3], s[4])] disasm = assemble_and_disassemble(instructions) self.assertIn('s_mov_b32', disasm[0]) self.assertIn('s_mov_b32', disasm[1]) @@ -57,11 +25,7 @@ class TestIntegration(unittest.TestCase): def test_simple_sop2(self): """Test SOP2 instructions round-trip.""" - instructions = [ - s_add_u32(s[0], s[1], s[2]), - s_sub_u32(s[3], s[4], 10), - s_and_b32(s[5], s[6], s[7]), - ] + instructions = [s_add_u32(s[0], s[1], s[2]), s_sub_u32(s[3], s[4], 10), s_and_b32(s[5], s[6], s[7])] disasm = assemble_and_disassemble(instructions) self.assertIn('s_add_u32', disasm[0]) self.assertIn('s_sub_u32', disasm[1]) @@ -69,33 +33,22 @@ class TestIntegration(unittest.TestCase): def test_simple_vop2(self): """Test VOP2 instructions round-trip.""" - instructions = [ - v_add_f32_e32(v[0], v[1], v[2]), - v_mul_f32_e32(v[3], 1.0, v[4]), # 1.0 is inline constant - v_and_b32_e32(v[5], 10, v[6]), # small inline constant - ] + instructions = [v_add_f32_e32(v[0], v[1], v[2]), v_mul_f32_e32(v[3], 1.0, v[4]), v_and_b32_e32(v[5], 10, v[6])] disasm = assemble_and_disassemble(instructions) self.assertIn('v_add_f32', disasm[0]) self.assertIn('v_mul_f32', disasm[1]) def test_control_flow(self): """Test control flow instructions.""" - instructions = [ - s_waitcnt(simm16=waitcnt(lgkmcnt=0)), - s_endpgm(), - ] + instructions = [s_waitcnt(simm16=waitcnt(lgkmcnt=0)), s_endpgm()] disasm = assemble_and_disassemble(instructions) self.assertIn('s_waitcnt', disasm[0]) self.assertIn('s_endpgm', disasm[1]) def test_memory_ops(self): """Test memory instructions.""" - instructions = [ - s_load_b32(s[0], s[0:1], NULL), - s_waitcnt(simm16=waitcnt(lgkmcnt=0)), - global_store_b32(addr=v[0:1], data=v[2], saddr=OFF), - s_endpgm(), - ] + instructions = [s_load_b32(s[0], s[0:1], NULL), s_waitcnt(simm16=waitcnt(lgkmcnt=0)), global_store_b32(addr=v[0:1], data=v[2], saddr=OFF), + s_endpgm()] disasm = assemble_and_disassemble(instructions) self.assertIn('s_load_b32', disasm[0]) self.assertIn('s_waitcnt', disasm[1]) @@ -103,156 +56,62 @@ class TestIntegration(unittest.TestCase): def test_full_kernel(self): """Test a complete kernel similar to tinygrad output.""" - # Simple kernel: load value, add 1, store back - instructions = [ - # Get thread ID - v_mov_b32_e32(v[0], s[0]), # base addr low - v_mov_b32_e32(v[1], s[1]), # base addr high - # Load value - global_load_b32(vdst=v[2], addr=v[0:1], saddr=OFF), - s_waitcnt(simm16=waitcnt(vmcnt=0)), - # Add 1.0 - v_add_f32_e32(v[2], 1.0, v[2]), - # Store result - global_store_b32(addr=v[0:1], data=v[2], saddr=OFF), - s_endpgm(), - ] + instructions = [v_mov_b32_e32(v[0], s[0]), v_mov_b32_e32(v[1], s[1]), global_load_b32(vdst=v[2], addr=v[0:1], saddr=OFF), + s_waitcnt(simm16=waitcnt(vmcnt=0)), v_add_f32_e32(v[2], 1.0, v[2]), global_store_b32(addr=v[0:1], data=v[2], saddr=OFF), + s_endpgm()] disasm = assemble_and_disassemble(instructions) - # Verify key instructions are present self.assertTrue(any('global_load' in d for d in disasm)) self.assertTrue(any('v_add_f32' in d for d in disasm)) self.assertTrue(any('global_store' in d for d in disasm)) self.assertTrue(any('s_endpgm' in d for d in disasm)) def test_bytes_roundtrip(self): - """Test that our bytes match what AMD assembler produces.""" - from tinygrad.runtime.support.compiler_amd import HIPCompiler - - # Simple instruction + """Test that our bytes match what LLVM assembler produces.""" inst = s_mov_b32(s[0], s[1]) our_bytes = inst.to_bytes() - - # Assemble same instruction with AMD toolchain - asm_src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\ns_mov_b32 s0, s1\n" - compiler = HIPCompiler("gfx1100") - lib = compiler.compile(asm_src) - raw = disassemble(lib) - - for line in raw.splitlines(): - if 's_mov_b32' in line and '//' in line: - # Extract hex bytes from comment: "// 000000001300: BE800001" - comment = line.split('//')[1].strip() - hex_str = comment.split(':')[1].strip() - # Convert big-endian hex string to little-endian bytes - amd_bytes = bytes.fromhex(hex_str)[::-1] # reverse for little-endian - self.assertEqual(our_bytes, amd_bytes, f"Bytes mismatch: ours={our_bytes.hex()} AMD={amd_bytes.hex()}") - return - self.fail("Could not find s_mov_b32 in disassembly") + llvm_bytes = llvm_assemble(["s_mov_b32 s0, s1"], "gfx1100", "+real-true16,+wavefrontsize32")[0] + self.assertEqual(our_bytes, llvm_bytes, f"Bytes mismatch: ours={our_bytes.hex()} LLVM={llvm_bytes.hex()}") class TestTinygradIntegration(unittest.TestCase): - """Test that we can parse disassembled tinygrad kernels.""" + """Test that we can parse tinygrad kernel disassembly.""" + + def _get_kernel_code(self, op_fn) -> bytes: + from tinygrad import Tensor + from tinygrad.codegen import get_program + from tinygrad.renderer.llvmir import AMDLLVMRenderer + from tinygrad.runtime.support.compiler_amd import AMDLLVMCompiler + from tinygrad.runtime.support.elf import elf_loader + from tinygrad.uop.ops import Ops + + result = op_fn(Tensor) + schedule = result.schedule() + sink_items = [si for si in schedule if si.ast.op == Ops.SINK] + assert len(sink_items) > 0, "No SINK in schedule" + renderer = AMDLLVMRenderer('gfx1100') + prg = get_program(sink_items[0].ast, renderer) + lib = AMDLLVMCompiler('gfx1100').compile(prg.src) + return next(s.content for s in elf_loader(lib)[1] if s.name == ".text") def test_simple_add_kernel(self): """Generate a simple add kernel from tinygrad and verify disassembly.""" - from tinygrad import Tensor - from tinygrad.codegen import get_program - from tinygrad.renderer.cstyle import AMDHIPRenderer - from tinygrad.runtime.support.compiler_amd import HIPCompiler - from tinygrad.uop.ops import Ops - - # Create a computation that generates a real kernel - a = Tensor([1.0, 2.0, 3.0, 4.0]).realize() - b = Tensor([5.0, 6.0, 7.0, 8.0]).realize() - c = a + b - - # Get schedule and find SINK - schedule = c.schedule() - sink_items = [si for si in schedule if si.ast.op == Ops.SINK] - self.assertTrue(len(sink_items) > 0, "No SINK in schedule") - - # Generate program - renderer = AMDHIPRenderer('gfx1100') - prg = get_program(sink_items[0].ast, renderer) - self.assertIsNotNone(prg.src) - - # Compile and disassemble - compiler = HIPCompiler('gfx1100') - lib = compiler.compile(prg.src) - raw_disasm = disassemble(lib) - instrs = parse_disassembly(raw_disasm) - - # Verify we got some instructions + code = self._get_kernel_code(lambda T: T([1.0, 2.0, 3.0, 4.0]).realize() + T([5.0, 6.0, 7.0, 8.0]).realize()) + instrs = llvm_disasm(code, "gfx1100", "+real-true16,+wavefrontsize32") self.assertTrue(len(instrs) > 0, "No instructions in disassembly") - # Should have an endpgm self.assertTrue(any('s_endpgm' in i for i in instrs), "Missing s_endpgm") def test_matmul_kernel(self): """Generate a matmul kernel and verify disassembly has expected patterns.""" - from tinygrad import Tensor - from tinygrad.codegen import get_program - from tinygrad.renderer.cstyle import AMDHIPRenderer - from tinygrad.runtime.support.compiler_amd import HIPCompiler - from tinygrad.uop.ops import Ops - - # Create a small matmul - a = Tensor.rand(4, 4).realize() - b = Tensor.rand(4, 4).realize() - c = a @ b - - # Get schedule - schedule = c.schedule() - sink_items = [si for si in schedule if si.ast.op == Ops.SINK] - self.assertTrue(len(sink_items) > 0) - - # Generate and compile - renderer = AMDHIPRenderer('gfx1100') - prg = get_program(sink_items[0].ast, renderer) - compiler = HIPCompiler('gfx1100') - lib = compiler.compile(prg.src) - raw_disasm = disassemble(lib) - instrs = parse_disassembly(raw_disasm) - - # Matmul should have multiply and add instructions + code = self._get_kernel_code(lambda T: T.rand(4, 4).realize() @ T.rand(4, 4).realize()) + instrs = llvm_disasm(code, "gfx1100", "+real-true16,+wavefrontsize32") has_mul = any('mul' in i.lower() for i in instrs) has_add = any('add' in i.lower() for i in instrs) self.assertTrue(has_mul or has_add, "Matmul should have mul/add ops") def test_disasm_to_bytes_roundtrip(self): - """Parse disassembled instructions and verify we can re-encode some of them.""" - from tinygrad import Tensor - from tinygrad.codegen import get_program - from tinygrad.renderer.cstyle import AMDHIPRenderer - from tinygrad.runtime.support.compiler_amd import HIPCompiler - from tinygrad.uop.ops import Ops - - # Simple kernel - a = Tensor([1.0, 2.0, 3.0, 4.0]).realize() - b = (a * 2.0) - - schedule = b.schedule() - sink_items = [si for si in schedule if si.ast.op == Ops.SINK] - if not sink_items: return # skip if no kernel - - renderer = AMDHIPRenderer('gfx1100') - prg = get_program(sink_items[0].ast, renderer) - compiler = HIPCompiler('gfx1100') - lib = compiler.compile(prg.src) - raw_disasm = disassemble(lib) - - # Find s_endpgm and verify we can encode it - for line in raw_disasm.splitlines(): - if 's_endpgm' in line and '//' in line: - # Extract bytes from comment - comment = line.split('//')[1].strip() - hex_str = comment.split(':')[1].strip() - amd_bytes = bytes.fromhex(hex_str)[::-1] - - # Our encoding - our_inst = s_endpgm() - our_bytes = our_inst.to_bytes() - - self.assertEqual(our_bytes, amd_bytes, f"s_endpgm mismatch: ours={our_bytes.hex()} AMD={amd_bytes.hex()}") - return + """Verify s_endpgm encoding matches between our DSL and LLVM.""" + our_bytes = s_endpgm().to_bytes() + llvm_bytes = llvm_assemble(["s_endpgm"], "gfx1100", "+real-true16,+wavefrontsize32")[0] + self.assertEqual(our_bytes, llvm_bytes, f"s_endpgm mismatch: ours={our_bytes.hex()} LLVM={llvm_bytes.hex()}") if __name__ == "__main__": unittest.main() diff --git a/test/amd/test_llvm.py b/test/amd/test_llvm.py index 308c65081f..3ebbfed0a1 100644 --- a/test/amd/test_llvm.py +++ b/test/amd/test_llvm.py @@ -8,11 +8,11 @@ Only compute-relevant instruction formats are tested. Graphics-only formats not - VIMAGE/VSAMPLE: image sampling instructions (RDNA4) - VBUFFER: buffer instructions (RDNA4) """ -import unittest, re, subprocess, functools +import unittest, re, functools from tinygrad.helpers import fetch from test.amd.disasm import disasm from tinygrad.renderer.amd import decode_inst, detect_format -from test.amd.helpers import get_llvm_mc, get_target, get_mattr +from test.amd.helpers import llvm_assemble, llvm_filter_valid_asm, get_target, get_mattr LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/llvmorg-21.1.0/llvm/test/MC/AMDGPU" @@ -74,42 +74,13 @@ def _get_tests_uncached(f: str, arch: str) -> list[tuple[str, bytes]]: # Exclude v_interp_* (graphics-only, not on CDNA) if arch == "cdna": tests = [(asm, data) for asm, data in tests if not asm.startswith('v_interp_')] # Filter out tests where original ASM isn't valid on target (e.g., gfx9 tests with gfx942/gfx950 constraints) - if arch == "cdna" and not ('gfx942' in f or 'gfx950' in f or 'gfx90a' in f): tests = _filter_valid_asm(tests, arch) + if arch == "cdna" and not ('gfx942' in f or 'gfx950' in f or 'gfx90a' in f): + tests = llvm_filter_valid_asm(tests, get_target(arch), get_mattr(arch)) return tests @functools.cache def _get_tests(f: str, arch: str) -> list[tuple[str, bytes]]: return _get_tests_uncached(f, arch) -def _compile_asm_batch(instrs: list[str], arch: str = "rdna3", mcpu: str|None = None) -> list[bytes]: - if not instrs: return [] - mcpu, mattr = mcpu or get_target(arch), get_mattr(arch) - result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-show-encoding'], - input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30) - if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}") - return [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', '')) - for line in result.stdout.split('\n') if 'encoding:' in line] - -def _filter_valid_asm(tests: list[tuple[str, bytes]], arch: str) -> list[tuple[str, bytes]]: - """Filter out tests where the original ASM isn't valid on the target (e.g., gfx9 tests with gfx942/gfx950 constraints).""" - if not tests: return [] - mcpu = get_target(arch) - # Batch assemble all instructions, parse stderr to find which lines failed - instrs = [asm for asm, _ in tests] - result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-show-encoding'], - input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True, timeout=30) - # Parse error lines from stderr (format: ":N:..." where N is 1-indexed, line 1 is ".text") - failed_lines = set() - for line in result.stderr.split('\n'): - if m := re.match(r':(\d+):', line): failed_lines.add(int(m.group(1)) - 1) # -1 for .text, so line 2 -> index 1 -> tests[0] - # Also filter out tests where LLVM roundtrip doesn't match original (reserved bits set in original) - valid = [(asm, data) for i, (asm, data) in enumerate(tests) if (i + 1) not in failed_lines] - if not valid: return [] - llvm_result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-show-encoding'], - input=".text\n" + "\n".join(asm for asm, _ in valid) + "\n", capture_output=True, text=True, timeout=30) - llvm_bytes = [bytes.fromhex(line.split('encoding:')[1].strip()[1:-1].replace('0x', '').replace(',', '').replace(' ', '')) - for line in llvm_result.stdout.split('\n') if 'encoding:' in line] - return [(asm, data) for (asm, data), lb in zip(valid, llvm_bytes) if lb == data] - def _make_test(f: str, arch: str, test_type: str): def test(self): tests = _get_tests(f, arch) @@ -160,7 +131,7 @@ def _make_test(f: str, arch: str, test_type: str): print(f"{name}: {len(to_test)} passed, {skipped} skipped") self.assertEqual(skipped, 0, f"{name}: {skipped} tests skipped, expected 0") # Compare disasm->reassemble with original encoding (filter reserved bit cases where LLVM can't reproduce) - llvm_bytes = _compile_asm_batch([t[1] for t in to_test], arch, mcpu) + llvm_bytes = llvm_assemble([t[1] for t in to_test], mcpu, get_mattr(arch)) valid = [(enc, d, llvm) for (enc, d), llvm in zip(to_test, llvm_bytes) if llvm == enc] print(f"{name}: {len(valid)}/{len(to_test)} matched LLVM encoding") for enc, _, llvm in valid: self.assertEqual(llvm, enc) diff --git a/test/amd/test_mockgpu_invalid.py b/test/amd/test_mockgpu_invalid.py index d85954b9a8..d1d86e7ff5 100644 --- a/test/amd/test_mockgpu_invalid.py +++ b/test/amd/test_mockgpu_invalid.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """Test that invalid instructions raise exceptions through the mock GPU stack.""" -import unittest, subprocess, os, time +import unittest, subprocess, os, sys, time class TestMockGPUInvalidInstruction(unittest.TestCase): def test_unsupported_instruction_raises(self): @@ -43,7 +43,7 @@ dev.synchronize() env["HCQDEV_WAIT_TIMEOUT_MS"] = "10000" st = time.perf_counter() - result = subprocess.run(["python", "-c", test_code], env=env, capture_output=True, text=True, timeout=60) + result = subprocess.run([sys.executable, "-c", test_code], env=env, capture_output=True, text=True, timeout=60) elapsed = time.perf_counter() - st self.assertNotEqual(result.returncode, 0, "should have raised") diff --git a/test/amd/test_rdna3_asm.py b/test/amd/test_rdna3_asm.py index aa29781cb3..a3b451e11e 100644 --- a/test/amd/test_rdna3_asm.py +++ b/test/amd/test_rdna3_asm.py @@ -1,27 +1,14 @@ #!/usr/bin/env python3 -import unittest, subprocess +import unittest from tinygrad.runtime.autogen.amd.rdna3.ins import * -from test.amd.helpers import get_llvm_mc +from test.amd.helpers import llvm_assemble from test.amd.disasm import disasm -def llvm_assemble(asm: str) -> bytes: - """Assemble using llvm-mc and return bytes.""" - result = subprocess.run( - [get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"], - input=asm, capture_output=True, text=True - ) - out = b'' - for line in result.stdout.split('\n'): - if 'encoding:' in line: - enc = line.split('encoding:')[1].strip() - enc = enc.strip('[]').replace('0x', '').replace(',', '') - out += bytes.fromhex(enc) - if not out: raise ValueError(f"no encoding found: {result.stdout} {result.stderr}") - return out +def _asm(asm: str) -> bytes: return llvm_assemble([asm], 'gfx1100', '+real-true16,+wavefrontsize32')[0] class TestRDNA3Asm(unittest.TestCase): def test_full_program(self): - """Test the full program from rdna3fun.py matches llvm-mc output.""" + """Test the full program from rdna3fun.py matches LLVM output.""" program = [ v_bfe_u32(v[1], v[0], 10, 10), s_load_b128(s[4:7], s[0:1], NULL), @@ -45,52 +32,35 @@ class TestRDNA3Asm(unittest.TestCase): s_endpgm(), ] - asm = """ -v_bfe_u32 v1, v0, 10, 10 -s_load_b128 s[4:7], s[0:1], null -v_and_b32_e32 v0, 0x3FF, v0 -s_mulk_i32 s3, 0x87 -v_mad_u64_u32 v[1:2], null, s2, 3, v[1:2] -v_mul_u32_u24_e32 v0, 45, v0 -v_ashrrev_i32_e32 v2, 31, v1 -v_add3_u32 v0, v0, s3, v1 -v_lshlrev_b64 v[2:3], 2, v[1:2] -v_ashrrev_i32_e32 v1, 31, v0 -v_lshlrev_b64 v[0:1], 2, v[0:1] -s_waitcnt lgkmcnt(0) -v_add_co_u32 v2, vcc_lo, s6, v2 -v_add_co_ci_u32_e32 v3, vcc_lo, s7, v3, vcc_lo -v_add_co_u32 v0, vcc_lo, s4, v0 -global_load_b32 v2, v[2:3], off -v_add_co_ci_u32_e32 v1, vcc_lo, s5, v1, vcc_lo -s_waitcnt vmcnt(0) -global_store_b32 v[0:1], v2, off -s_endpgm -""" - expected = llvm_assemble(asm) - for inst,rt in zip(program, asm.strip().split("\n")): print(f"{disasm(inst):50s} {rt}") - actual = b''.join(inst.to_bytes() for inst in program) - self.assertEqual(actual, expected) + asm_lines = [ + "v_bfe_u32 v1, v0, 10, 10", "s_load_b128 s[4:7], s[0:1], null", "v_and_b32_e32 v0, 0x3FF, v0", + "s_mulk_i32 s3, 0x87", "v_mad_u64_u32 v[1:2], null, s2, 3, v[1:2]", "v_mul_u32_u24_e32 v0, 45, v0", + "v_ashrrev_i32_e32 v2, 31, v1", "v_add3_u32 v0, v0, s3, v1", "v_lshlrev_b64 v[2:3], 2, v[1:2]", + "v_ashrrev_i32_e32 v1, 31, v0", "v_lshlrev_b64 v[0:1], 2, v[0:1]", "s_waitcnt lgkmcnt(0)", + "v_add_co_u32 v2, vcc_lo, s6, v2", "v_add_co_ci_u32_e32 v3, vcc_lo, s7, v3, vcc_lo", + "v_add_co_u32 v0, vcc_lo, s4, v0", "global_load_b32 v2, v[2:3], off", + "v_add_co_ci_u32_e32 v1, vcc_lo, s5, v1, vcc_lo", "s_waitcnt vmcnt(0)", + "global_store_b32 v[0:1], v2, off", "s_endpgm", + ] + expected = llvm_assemble(asm_lines, 'gfx1100', '+real-true16,+wavefrontsize32') + for inst, rt in zip(program, asm_lines): print(f"{disasm(inst):50s} {rt}") + for inst, exp in zip(program, expected): self.assertEqual(inst.to_bytes(), exp) def test_sop2_s_add_u32(self): inst = SOP2(SOP2Op.S_ADD_U32, s[3], s[0], s[1]) - expected = llvm_assemble("s_add_u32 s3, s0, s1") - self.assertEqual(inst.to_bytes(), expected) + self.assertEqual(inst.to_bytes(), _asm("s_add_u32 s3, s0, s1")) def test_vop2_v_and_b32_inline_const(self): inst = v_and_b32_e32(v[0], 10, v[0]) - expected = llvm_assemble("v_and_b32_e32 v0, 10, v0") - self.assertEqual(inst.to_bytes(), expected) + self.assertEqual(inst.to_bytes(), _asm("v_and_b32_e32 v0, 10, v0")) def test_sopp_s_endpgm(self): inst = s_endpgm() - expected = llvm_assemble("s_endpgm") - self.assertEqual(inst.to_bytes(), expected) + self.assertEqual(inst.to_bytes(), _asm("s_endpgm")) def test_sop1_s_mov_b32(self): inst = s_mov_b32(s[0], s[1]) - expected = llvm_assemble("s_mov_b32 s0, s1") - self.assertEqual(inst.to_bytes(), expected) + self.assertEqual(inst.to_bytes(), _asm("s_mov_b32 s0, s1")) if __name__ == "__main__": unittest.main() diff --git a/test/amd/test_roundtrip.py b/test/amd/test_roundtrip.py index 6d0bcbd96d..d35f85a41a 100644 --- a/test/amd/test_roundtrip.py +++ b/test/amd/test_roundtrip.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 """Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match.""" -import unittest, io, sys, re, subprocess, os +import unittest, io, sys, re from tinygrad import Device from tinygrad.renderer.amd import detect_format -from test.amd.helpers import get_llvm_mc, get_llvm_objdump, get_target, get_mattr +from test.amd.helpers import llvm_assemble, llvm_disasm, get_target, get_mattr from test.amd.disasm import disasm def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: @@ -31,45 +31,18 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: def compile_asm(instr: str, arch: str = 'rdna3') -> bytes: """Compile a single instruction using LLVM.""" - return compile_asm_batch([instr], arch)[0] + return llvm_assemble([instr], get_target(arch), get_mattr(arch))[0] def compile_asm_batch(instrs: list[str], arch: str = 'rdna3') -> list[bytes]: - """Compile multiple instructions with a single llvm-mc call.""" - if not instrs: return [] - result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={get_target(arch)}', f'-mattr={get_mattr(arch)}', '-show-encoding'], - input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True) - if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}") - encodings = [] - for line in result.stdout.split('\n'): - if 'encoding:' in line: - enc = line.split('encoding:')[1].strip() - if enc.startswith('[') and enc.endswith(']'): - encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', ''))) - if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}") - return encodings + """Compile multiple instructions with a single LLVM emission.""" + return llvm_assemble(instrs, get_target(arch), get_mattr(arch)) def compile_and_disasm_batch(instrs: list[str], arch: str = 'rdna3') -> list[str]: """Compile instructions with LLVM and get LLVM's disassembly.""" - import tempfile if not instrs: return [] mcpu, mattr = get_target(arch), get_mattr(arch) - src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n" - with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f: - obj_path = f.name - try: - result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', f'-mattr={mattr}', '-filetype=obj', '-o', obj_path], - input=src, capture_output=True, text=True) - if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}") - result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True) - if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}") - results: list[str] = [] - for line in result.stdout.splitlines(): - if '//' not in line: continue - instr = line.split('//')[0].strip() - if instr: results.append(instr) - return results[:len(instrs)] - finally: - os.unlink(obj_path) + code = b''.join(llvm_assemble(instrs, mcpu, mattr)) + return llvm_disasm(code, mcpu, mattr)[:len(instrs)] @unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device") class TestTinygradKernelRoundtrip(unittest.TestCase): @@ -174,12 +147,12 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): if our_disasm is None: disasm_skipped += 1 elif idx in disasm_llvm_map: - llvm_disasm = disasm_llvm_map[idx] - if our_disasm == llvm_disasm: + llvm_disasm_str = disasm_llvm_map[idx] + if our_disasm == llvm_disasm_str: disasm_passed += 1 else: disasm_failed += 1 - disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'") + disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm_str}'") else: disasm_skipped += 1 diff --git a/test/amd/test_sqtt_examples.py b/test/amd/test_sqtt_examples.py index a2f511637a..58ad0926ed 100644 --- a/test/amd/test_sqtt_examples.py +++ b/test/amd/test_sqtt_examples.py @@ -89,7 +89,7 @@ def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str): try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None) except Exception as e: exc = e (t:=threading.Thread(target=worker, daemon=True)).start() - t.join(timeout=1) + t.join(timeout=5) if exc is not None: raise exc if t.is_alive(): raise RuntimeError("rocprof decoder timeout") return occupancy_records, wave_insts diff --git a/test/amd/test_sqtt_tables.py b/test/amd/test_sqtt_tables.py index e2fdced0c2..8ced6b4ead 100644 --- a/test/amd/test_sqtt_tables.py +++ b/test/amd/test_sqtt_tables.py @@ -80,6 +80,7 @@ def extract_packet_encodings(): def extract_cdna_packet_sizes(): """Extract CDNA pkt_fmt -> size mapping by running rocprof decoder to populate its hash table.""" + if not _load_lib(): return None from test.amd.test_sqtt_examples import run_rocprof_decoder if not (pkl_path := next((EXAMPLES_DIR / "gfx950").glob("*.pkl"), None)): return None @@ -119,8 +120,7 @@ class TestSQTTMatchesBinary(unittest.TestCase): def test_cdna_packet_sizes(self): """Extract and verify CDNA pkt_fmt -> size mapping from rocprof's hash table.""" if not (EXAMPLES_DIR / "gfx950").exists(): self.skipTest("no CDNA examples") - pkt_sizes = extract_cdna_packet_sizes() - self.assertIsNotNone(pkt_sizes, "failed to extract CDNA packet sizes") + if not (pkt_sizes := extract_cdna_packet_sizes()): self.skipTest("rocprof-trace-decoder not installed") for pkt_fmt, size in CDNA_PKT_SIZES.items(): with self.subTest(pkt_fmt=pkt_fmt): self.assertEqual(pkt_sizes.get(pkt_fmt), size) diff --git a/tinygrad/renderer/amd/emu.py b/tinygrad/renderer/amd/emu.py index db23b65e7a..08637cf578 100644 --- a/tinygrad/renderer/amd/emu.py +++ b/tinygrad/renderer/amd/emu.py @@ -9,37 +9,47 @@ from __future__ import annotations import ctypes, functools, re, platform, subprocess, tempfile from typing import Any, Callable -# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) in MXCSR to match RDNA3 default float mode +# Set/restore DAZ+FTZ (denormals-are-zero + flush-to-zero) to match RDNA3 default float mode +# x86: MXCSR bits DAZ(6)+FTZ(15), ARM64: FPCR bit FZ(24) # Only applied during emulator execution, restored afterward to avoid breaking hypothesis tests @functools.cache -def _get_mxcsr_lib(): - if platform.machine() not in ('x86_64', 'AMD64'): return None - try: +def _get_ftz_lib(): + machine = platform.machine() + if machine in ('x86_64', 'AMD64'): src = b''' -unsigned int get_mxcsr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;} -void set_mxcsr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));} +unsigned int get_fpcr(void){unsigned int m;__asm__ __volatile__("stmxcsr %0":"=m"(m));return m;} +void set_fpcr(unsigned int m){__asm__ __volatile__("ldmxcsr %0"::"m"(m));} ''' + ftz_bits = 0x8040 # DAZ (bit 6) + FTZ (bit 15) + elif machine in ('arm64', 'aarch64'): + src = b''' +unsigned int get_fpcr(void){unsigned long long v;__asm__ __volatile__("mrs %0,fpcr":"=r"(v));return(unsigned int)v;} +void set_fpcr(unsigned int m){unsigned long long v=m;__asm__ __volatile__("msr fpcr,%0"::"r"(v));} +''' + ftz_bits = 1 << 24 # FZ (bit 24) + else: return None, 0 + try: with tempfile.NamedTemporaryFile(suffix='.so', delete=False) as f: subprocess.check_output(['clang', '-shared', '-O2', '-x', 'c', '-', '-o', f.name], input=src) lib = ctypes.CDLL(f.name) - lib.get_mxcsr.restype = ctypes.c_uint32 - lib.set_mxcsr.argtypes = [ctypes.c_uint32] - return lib - except Exception: return None + lib.get_fpcr.restype = ctypes.c_uint32 + lib.set_fpcr.argtypes = [ctypes.c_uint32] + return lib, ftz_bits + except Exception: return None, 0 class _MXCSRContext: """Context manager to set DAZ+FTZ during emulator execution and restore afterward.""" __slots__ = ('_saved',) def __enter__(self): - lib = _get_mxcsr_lib() + lib, ftz_bits = _get_ftz_lib() if lib is None: return self - self._saved = lib.get_mxcsr() - lib.set_mxcsr(self._saved | 0x8040) # DAZ (bit 6) + FTZ (bit 15) + self._saved = lib.get_fpcr() + lib.set_fpcr(self._saved | ftz_bits) return self def __exit__(self, *args): - lib = _get_mxcsr_lib() + lib, _ = _get_ftz_lib() if lib is None or not hasattr(self, '_saved'): return - lib.set_mxcsr(self._saved) + lib.set_fpcr(self._saved) from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType from tinygrad.dtype import dtypes