diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fa15d50f90..4807cc465b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -672,7 +672,9 @@ jobs: wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-21 main" | sudo tee /etc/apt/sources.list.d/llvm.list sudo apt-get update - sudo apt-get install llvm-21 llvm-21-tools + sudo apt-get install llvm-21 llvm-21-tools cloc + - name: RDNA3 Line Count + run: cloc --by-file extra/assembly/rdna3/*.py - name: Run RDNA3 emulator tests run: python -m pytest -n=auto extra/assembly/rdna3/ --durations 20 - name: Install pdfplumber diff --git a/extra/assembly/rdna3/autogen/gen_pcode.py b/extra/assembly/rdna3/autogen/gen_pcode.py index c3fda6de19..9366cffdd0 100644 --- a/extra/assembly/rdna3/autogen/gen_pcode.py +++ b/extra/assembly/rdna3/autogen/gen_pcode.py @@ -1,6 +1,7 @@ # autogenerated by pcode.py - do not edit # to regenerate: python -m extra.assembly.rdna3.pcode # ruff: noqa: E501,F405,F403 +# mypy: ignore-errors from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp from extra.assembly.rdna3.pcode import * @@ -12675,27 +12676,11 @@ VOPCOp_FUNCTIONS = { } -# Manually implemented lane instructions (require special vgpr_write handling) +# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode) def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): - # VGPR[lane][VDST] = S0.b32 - writes s0 to specified lane's VGPR wr_lane = s1 & 0x1f # lane select (5 bits for wave32) return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)} - -def _VOP3Op_V_READLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): - # D0 = VGPR[lane][SRC0] - reads from specified lane's VGPR - rd_lane = s1 & 0x1f # lane select (5 bits for wave32) - val = VGPR[rd_lane][src0_idx] if VGPR is not None and rd_lane < len(VGPR) and src0_idx < len(VGPR[rd_lane]) else s0 - return {'d0': val & 0xffffffff, 'scc': scc} - -def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): - # D0 = VGPR[first_active_lane][SRC0] - reads from first active lane - first_lane = 0 - for i in range(32): - if exec_mask & (1 << i): - first_lane = i - break - val = VGPR[first_lane][src0_idx] if VGPR is not None and first_lane < len(VGPR) and src0_idx < len(VGPR[first_lane]) else s0 - return {'d0': val & 0xffffffff, 'scc': scc} +VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32 COMPILED_FUNCTIONS = { SOP1Op: SOP1Op_FUNCTIONS, @@ -12711,9 +12696,4 @@ COMPILED_FUNCTIONS = { VOPCOp: VOPCOp_FUNCTIONS, } -# Add lane instructions to their respective dicts -VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32 -VOP3Op_FUNCTIONS[VOP3Op.V_READLANE_B32] = _VOP3Op_V_READLANE_B32 -VOP1Op_FUNCTIONS[VOP1Op.V_READFIRSTLANE_B32] = _VOP1Op_V_READFIRSTLANE_B32 - def get_compiled_functions(): return COMPILED_FUNCTIONS \ No newline at end of file diff --git a/extra/assembly/rdna3/emu.py b/extra/assembly/rdna3/emu.py index 75d62da802..1e1ecebc35 100644 --- a/extra/assembly/rdna3/emu.py +++ b/extra/assembly/rdna3/emu.py @@ -1,4 +1,5 @@ # RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF +# mypy: ignore-errors from __future__ import annotations import ctypes, os from extra.assembly.rdna3.lib import Inst, RawImm diff --git a/extra/assembly/rdna3/lib.py b/extra/assembly/rdna3/lib.py index 7d07572520..bfa2400e33 100644 --- a/extra/assembly/rdna3/lib.py +++ b/extra/assembly/rdna3/lib.py @@ -1,4 +1,5 @@ # library for RDNA3 assembly DSL +# mypy: ignore-errors from __future__ import annotations from enum import IntEnum from typing import overload, Annotated, TypeVar, Generic diff --git a/extra/assembly/rdna3/pcode.py b/extra/assembly/rdna3/pcode.py index 1da20cc73d..cb0f1967c1 100644 --- a/extra/assembly/rdna3/pcode.py +++ b/extra/assembly/rdna3/pcode.py @@ -806,6 +806,7 @@ def generate_gen_pcode(output_path: str = "extra/assembly/rdna3/autogen/gen_pcod lines = ['''# autogenerated by pcode.py - do not edit # to regenerate: python -m extra.assembly.rdna3.pcode # ruff: noqa: E501,F405,F403 +# mypy: ignore-errors from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp from extra.assembly.rdna3.pcode import * '''] @@ -963,29 +964,13 @@ from extra.assembly.rdna3.pcode import * lines.append('}') lines.append('') - # Add manually implemented lane instructions + # Add manually implemented V_WRITELANE_B32 (not in PDF pseudocode, requires special vgpr_write handling) lines.append(''' -# Manually implemented lane instructions (require special vgpr_write handling) +# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode) def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): - # VGPR[lane][VDST] = S0.b32 - writes s0 to specified lane's VGPR wr_lane = s1 & 0x1f # lane select (5 bits for wave32) return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)} - -def _VOP3Op_V_READLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): - # D0 = VGPR[lane][SRC0] - reads from specified lane's VGPR - rd_lane = s1 & 0x1f # lane select (5 bits for wave32) - val = VGPR[rd_lane][src0_idx] if VGPR is not None and rd_lane < len(VGPR) and src0_idx < len(VGPR[rd_lane]) else s0 - return {'d0': val & 0xffffffff, 'scc': scc} - -def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0): - # D0 = VGPR[first_active_lane][SRC0] - reads from first active lane - first_lane = 0 - for i in range(32): - if exec_mask & (1 << i): - first_lane = i - break - val = VGPR[first_lane][src0_idx] if VGPR is not None and first_lane < len(VGPR) and src0_idx < len(VGPR[first_lane]) else s0 - return {'d0': val & 0xffffffff, 'scc': scc} +VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32 ''') lines.append('COMPILED_FUNCTIONS = {') @@ -994,11 +979,6 @@ def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, liter if by_cls.get(enum_cls): lines.append(f' {cls_name}: {cls_name}_FUNCTIONS,') lines.append('}') lines.append('') - lines.append("# Add lane instructions to their respective dicts") - lines.append("VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32") - lines.append("VOP3Op_FUNCTIONS[VOP3Op.V_READLANE_B32] = _VOP3Op_V_READLANE_B32") - lines.append("VOP1Op_FUNCTIONS[VOP1Op.V_READFIRSTLANE_B32] = _VOP1Op_V_READFIRSTLANE_B32") - lines.append('') lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS') Path(output_path).write_text('\n'.join(lines)) diff --git a/extra/assembly/rdna3/test/helpers.py b/extra/assembly/rdna3/test/helpers.py new file mode 100644 index 0000000000..cb98cd4968 --- /dev/null +++ b/extra/assembly/rdna3/test/helpers.py @@ -0,0 +1,24 @@ +"""Shared test helpers for RDNA3 tests.""" +import shutil +from dataclasses import dataclass + +@dataclass +class KernelInfo: + code: bytes + global_size: tuple[int, int, int] + local_size: tuple[int, int, int] + buf_idxs: list[int] # indices into shared buffer pool + buf_sizes: list[int] # sizes for each buffer index + +# LLVM tool detection (shared across test files) +def get_llvm_mc(): + """Find llvm-mc executable, preferring newer versions.""" + for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']: + if shutil.which(p): return p + raise FileNotFoundError("llvm-mc not found") + +def get_llvm_objdump(): + """Find llvm-objdump executable, preferring newer versions.""" + for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']: + if shutil.which(p): return p + raise FileNotFoundError("llvm-objdump not found") diff --git a/extra/assembly/rdna3/test/test_compare_emulators.py b/extra/assembly/rdna3/test/test_compare_emulators.py index dcb4da6d31..9455028e26 100644 --- a/extra/assembly/rdna3/test/test_compare_emulators.py +++ b/extra/assembly/rdna3/test/test_compare_emulators.py @@ -7,19 +7,13 @@ from pathlib import Path # This allows generating AMD GPU kernels without requiring real hardware os.environ["AMD"] = "1" os.environ["MOCKGPU"] = "1" +os.environ["PYTHON_REMU"] = "1" from extra.assembly.rdna3.emu import WaveState, decode_program, step_wave, WAVE_SIZE +from extra.assembly.rdna3.test.helpers import KernelInfo REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so" -@dataclass -class KernelInfo: - code: bytes - global_size: tuple[int, int, int] - local_size: tuple[int, int, int] - buf_idxs: list[int] # indices into shared buffer pool - buf_sizes: list[int] # sizes for each buffer index - def _is_f32_nan(bits: int) -> bool: """Check if 32-bit value is a NaN (exponent all 1s, mantissa non-zero).""" return (bits & 0x7f800000) == 0x7f800000 and (bits & 0x007fffff) != 0 @@ -206,6 +200,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t for i in range(128): python.set_sgpr(i, rust_after.sgpr[i]) for lane in range(n_lanes): for i in range(256): python.set_vgpr(lane, i, rust_after.vgpr[lane][i]) + assert python.state is not None python.state.pc, python.state.scc, python.state.vcc, python.state.exec_mask = rust_after.pc, rust_after.scc, rust_after.vcc, rust_after.exec_mask if rust_result == -1: @@ -347,7 +342,6 @@ def get_kernel_from_tinygrad(op_fn) -> tuple[bytes, tuple[int, int, int], tuple[ k = kernels[-1] return k.code, k.global_size, k.local_size, k.buf_sizes -@unittest.skipUnless(REMU_PATH.exists(), "libremu.so not found") class TestTinygradKernels(unittest.TestCase): """Compare emulators on real tinygrad-compiled kernels.""" diff --git a/extra/assembly/rdna3/test/test_integration.py b/extra/assembly/rdna3/test/test_integration.py index 1b0f8c6c64..443f5c923c 100644 --- a/extra/assembly/rdna3/test/test_integration.py +++ b/extra/assembly/rdna3/test/test_integration.py @@ -3,16 +3,7 @@ import unittest, re, io, sys, subprocess from extra.assembly.rdna3.autogen import * from extra.assembly.rdna3.asm import waitcnt, asm -from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc - -def get_amd_toolchain(): - """Check if AMD toolchain is available.""" - try: - from tinygrad.runtime.support.compiler_amd import HIPCompiler - HIPCompiler("gfx1100").compile(".text\ns_endpgm") - return True - except Exception: - return False +from extra.assembly.rdna3.test.helpers import get_llvm_mc def disassemble(lib: bytes, arch: str = "gfx1100") -> str: """Disassemble ELF binary using tinygrad's compiler, return raw output.""" @@ -48,7 +39,6 @@ def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[ lib = HIPCompiler(arch).compile(asm_src) return parse_disassembly(disassemble(lib, arch)) -@unittest.skipUnless(get_amd_toolchain(), "AMD toolchain not available") class TestIntegration(unittest.TestCase): """Test our assembler output matches LLVM disassembly.""" @@ -158,7 +148,6 @@ class TestIntegration(unittest.TestCase): return self.fail("Could not find s_mov_b32 in disassembly") -@unittest.skipUnless(get_amd_toolchain(), "AMD toolchain not available") class TestAsm(unittest.TestCase): """Test asm() string parsing.""" @@ -214,7 +203,7 @@ class TestAsm(unittest.TestCase): def test_asm_vop3_modifiers(self): """Test asm() with VOP3 modifiers (neg, abs, clamp).""" def get_llvm_encoding(instr: str) -> str: - result = subprocess.run([_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'], + result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'], input=instr, capture_output=True, text=True) if m := re.search(r'encoding:\s*\[(.*?)\]', result.stdout): return m.group(1).replace('0x','').replace(',','').replace(' ','') @@ -232,7 +221,6 @@ class TestAsm(unittest.TestCase): llvm_hex = get_llvm_encoding(t) self.assertEqual(our_hex, llvm_hex, f"mismatch for: {t}") -@unittest.skipUnless(get_amd_toolchain(), "AMD toolchain not available") class TestTinygradIntegration(unittest.TestCase): """Test that we can parse disassembled tinygrad kernels.""" diff --git a/extra/assembly/rdna3/test/test_llvm.py b/extra/assembly/rdna3/test/test_llvm.py index fe6b18aa05..fec84f02a0 100644 --- a/extra/assembly/rdna3/test/test_llvm.py +++ b/extra/assembly/rdna3/test/test_llvm.py @@ -4,7 +4,7 @@ import unittest, re, subprocess from tinygrad.helpers import fetch from extra.assembly.rdna3.autogen import * from extra.assembly.rdna3.asm import asm -from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc +from extra.assembly.rdna3.test.helpers import get_llvm_mc LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU" @@ -83,7 +83,7 @@ def compile_asm_batch(instrs: list[str]) -> list[bytes]: if not instrs: return [] asm_text = ".text\n" + "\n".join(instrs) + "\n" result = subprocess.run( - [_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], + [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], input=asm_text, capture_output=True, text=True, timeout=30) if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}") # Parse all encodings from output @@ -132,7 +132,7 @@ def _make_disasm_test(name): undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}} # First pass: decode all instructions and collect disasm strings - to_test = [] # list of (asm_text, data, disasm_str) + to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error) skipped = 0 for asm_text, data in self.tests.get(name, []): if len(data) > fmt_cls._size(): continue @@ -172,14 +172,15 @@ def _make_disasm_test(name): llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)} # Match results back - passed, failed, failures = 0, 0, [] + passed, failed = 0, 0 + failures: list[str] = [] for idx, (asm_text, data, disasm_str, error) in enumerate(to_test): if error: failed += 1; failures.append(f"{error} for {data.hex()}") elif disasm_str is not None and idx in llvm_map: llvm_bytes = llvm_map[idx] - if llvm_bytes == data: passed += 1 - else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}") + if llvm_bytes is not None and llvm_bytes == data: passed += 1 + elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}") print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else "")) if failures[:10]: print(" " + "\n ".join(failures[:10])) diff --git a/extra/assembly/rdna3/test/test_rdna3_asm.py b/extra/assembly/rdna3/test/test_rdna3_asm.py index 13369bb489..45e7948988 100644 --- a/extra/assembly/rdna3/test/test_rdna3_asm.py +++ b/extra/assembly/rdna3/test/test_rdna3_asm.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 import unittest, subprocess from extra.assembly.rdna3.autogen import * -from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc +from extra.assembly.rdna3.test.helpers import get_llvm_mc def llvm_assemble(asm: str) -> bytes: """Assemble using llvm-mc and return bytes.""" result = subprocess.run( - [_get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"], + [get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"], input=asm, capture_output=True, text=True ) out = b'' diff --git a/extra/assembly/rdna3/test/test_roundtrip.py b/extra/assembly/rdna3/test/test_roundtrip.py index 9867e71d44..c8a21bdb3c 100644 --- a/extra/assembly/rdna3/test/test_roundtrip.py +++ b/extra/assembly/rdna3/test/test_roundtrip.py @@ -1,19 +1,10 @@ #!/usr/bin/env python3 """Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match.""" -import unittest, io, sys, re, subprocess, shutil +import unittest, io, sys, re, subprocess, os from extra.assembly.rdna3.autogen import * from extra.assembly.rdna3.lib import Inst from extra.assembly.rdna3.asm import asm - -def _get_llvm_mc(): - for p in ['llvm-mc', 'llvm-mc-21', 'llvm-mc-20']: # prefer newer llvm-mc - if shutil.which(p): return p - raise FileNotFoundError("llvm-mc not found") - -def _get_llvm_objdump(): - for p in ['llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20']: - if shutil.which(p): return p - raise FileNotFoundError("llvm-objdump not found") +from extra.assembly.rdna3.test.helpers import get_llvm_mc, get_llvm_objdump # Instruction format detection based on encoding bits def detect_format(data: bytes) -> type[Inst] | None: @@ -78,7 +69,7 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]: def compile_asm(instr: str, compiler=None) -> bytes: """Compile a single instruction with llvm-mc and return the machine code bytes.""" - llvm_mc = _get_llvm_mc() + llvm_mc = get_llvm_mc() result = subprocess.run( [llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], input=f".text\n{instr}\n", capture_output=True, text=True) @@ -95,7 +86,7 @@ def compile_asm(instr: str, compiler=None) -> bytes: def compile_asm_batch(instrs: list[str]) -> list[bytes]: """Compile multiple instructions with a single llvm-mc call.""" if not instrs: return [] - llvm_mc = _get_llvm_mc() + llvm_mc = get_llvm_mc() src = ".text\n" + "\n".join(instrs) + "\n" result = subprocess.run( [llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'], @@ -112,7 +103,7 @@ def compile_asm_batch(instrs: list[str]) -> list[bytes]: if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}") return encodings -def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str | None]: +def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]: """Compile instructions with LLVM and get LLVM's disassembly.""" import tempfile, os if not instrs: return [] @@ -124,14 +115,14 @@ def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str | None]: obj_path = f.name try: result = subprocess.run( - [_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path], + [get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path], input=src, capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}") # Disassemble with llvm-objdump - result = subprocess.run([_get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True) + result = subprocess.run([get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True) if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}") # Parse disassembly output - results = [] + results: list[str] = [] for line in result.stdout.splitlines(): if '//' not in line: continue instr = line.split('//')[0].strip() @@ -156,7 +147,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): compiler = HIPCompiler('gfx1100') # First pass: decode all instructions and collect info - decoded_instrs = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) + decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) for ki, kernel in enumerate(kernels): offset = 0 while offset < len(kernel.code): @@ -178,7 +169,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): reencoded = decoded.to_bytes() our_disasm = decoded.disasm() decode_ok = reencoded == orig_bytes - decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}" + decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}" decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)) except Exception as e: decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e))) @@ -187,8 +178,8 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): offset += size # Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile - asm_test_instrs = [] # (idx, our_disasm) for asm test - disasm_test_instrs = [] # (idx, our_disasm) for disasm comparison test + asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test + disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs): if our_disasm is None: continue @@ -209,7 +200,9 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): decode_passed, decode_failed, decode_skipped = 0, 0, 0 asm_passed, asm_failed, asm_skipped = 0, 0, 0 disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0 - decode_failures, asm_failures, disasm_failures = [], [], [] + decode_failures: list[str] = [] + asm_failures: list[str] = [] + disasm_failures: list[str] = [] for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs): # Decode test @@ -226,18 +219,15 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): asm_skipped += 1 elif idx in asm_llvm_map: llvm_bytes = asm_llvm_map[idx] - if llvm_bytes is None: + try: + our_bytes = asm(our_disasm).to_bytes() + if our_bytes[:len(llvm_bytes)] == llvm_bytes: + asm_passed += 1 + else: + asm_failed += 1 + asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}") + except Exception: asm_skipped += 1 - else: - try: - our_bytes = asm(our_disasm).to_bytes() - if our_bytes[:len(llvm_bytes)] == llvm_bytes: - asm_passed += 1 - else: - asm_failed += 1 - asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}") - except Exception: - asm_skipped += 1 else: asm_skipped += 1 @@ -246,9 +236,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase): disasm_skipped += 1 elif idx in disasm_llvm_map: llvm_disasm = disasm_llvm_map[idx] - if llvm_disasm is None: - disasm_skipped += 1 - elif our_disasm == llvm_disasm: + if our_disasm == llvm_disasm: disasm_passed += 1 else: disasm_failed += 1