mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
rdna3 test cleanups (#13878)
* rdna3 test cleanups * cleanups * ugh DONT SKIP
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# autogenerated by pcode.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.rdna3.pcode
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
|
||||
from extra.assembly.rdna3.pcode import *
|
||||
|
||||
@@ -12675,27 +12676,11 @@ VOPCOp_FUNCTIONS = {
|
||||
}
|
||||
|
||||
|
||||
# Manually implemented lane instructions (require special vgpr_write handling)
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_WRITELANE_B32: request a write of S0 into one lane's destination VGPR.

  The lane is selected by the low 5 bits of s1 (wave32) and the value is s0
  truncated to 32 bits. d0 and scc are passed through unchanged; the write is
  described by the 'vgpr_write' entry as (lane, vdst_idx, value).
  NOTE(review): the caller is presumably responsible for applying 'vgpr_write' - confirm in the emulator.
  """
  target_lane = s1 & 0x1f  # 5-bit lane select
  payload = s0 & 0xffffffff  # keep only the b32 value
  result = {'d0': d0, 'scc': scc}
  result['vgpr_write'] = (target_lane, vdst_idx, payload)
  return result
|
||||
|
||||
def _VOP3Op_V_READLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_READLANE_B32: D0 = VGPR[lane][SRC0], with the lane taken from the low 5 bits of s1.

  Falls back to s0 when VGPR is None or when the lane / register index is past
  the end of the VGPR table. The result is truncated to 32 bits; scc is passed through.
  """
  source_lane = s1 & 0x1f  # 5-bit lane select (wave32)
  value = s0  # fallback when the register file cannot be indexed
  if VGPR is not None and source_lane < len(VGPR):
    lane_regs = VGPR[source_lane]
    if src0_idx < len(lane_regs):
      value = lane_regs[src0_idx]
  return {'d0': value & 0xffffffff, 'scc': scc}
|
||||
|
||||
def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
  """V_READFIRSTLANE_B32: D0 = VGPR[first_active_lane][SRC0].

  The first active lane is the lowest set bit among bits 0..31 of exec_mask;
  lane 0 is used when none of those bits is set. Falls back to s0 when VGPR is
  None or the lane / register index is out of range. The result is truncated
  to 32 bits; scc is passed through.
  """
  # lowest set bit of the 32-bit exec mask, defaulting to lane 0
  lowest_active = next((bit for bit in range(32) if exec_mask & (1 << bit)), 0)
  value = s0
  if VGPR is not None and lowest_active < len(VGPR):
    lane_regs = VGPR[lowest_active]
    if src0_idx < len(lane_regs):
      value = lane_regs[src0_idx]
  return {'d0': value & 0xffffffff, 'scc': scc}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
|
||||
COMPILED_FUNCTIONS = {
|
||||
SOP1Op: SOP1Op_FUNCTIONS,
|
||||
@@ -12711,9 +12696,4 @@ COMPILED_FUNCTIONS = {
|
||||
VOPCOp: VOPCOp_FUNCTIONS,
|
||||
}
|
||||
|
||||
# Add lane instructions to their respective dicts
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_READLANE_B32] = _VOP3Op_V_READLANE_B32
|
||||
VOP1Op_FUNCTIONS[VOP1Op.V_READFIRSTLANE_B32] = _VOP1Op_V_READFIRSTLANE_B32
|
||||
|
||||
def get_compiled_functions():
  """Return the module-level COMPILED_FUNCTIONS table (op enum class -> its *_FUNCTIONS dict)."""
  return COMPILED_FUNCTIONS
|
||||
@@ -1,4 +1,5 @@
|
||||
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
import ctypes, os
|
||||
from extra.assembly.rdna3.lib import Inst, RawImm
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# library for RDNA3 assembly DSL
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from typing import overload, Annotated, TypeVar, Generic
|
||||
|
||||
@@ -806,6 +806,7 @@ def generate_gen_pcode(output_path: str = "extra/assembly/rdna3/autogen/gen_pcod
|
||||
lines = ['''# autogenerated by pcode.py - do not edit
|
||||
# to regenerate: python -m extra.assembly.rdna3.pcode
|
||||
# ruff: noqa: E501,F405,F403
|
||||
# mypy: ignore-errors
|
||||
from extra.assembly.rdna3.autogen import SOP1Op, SOP2Op, SOPCOp, SOPKOp, SOPPOp, VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp
|
||||
from extra.assembly.rdna3.pcode import *
|
||||
''']
|
||||
@@ -963,29 +964,13 @@ from extra.assembly.rdna3.pcode import *
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
|
||||
# Add manually implemented lane instructions
|
||||
# Add manually implemented V_WRITELANE_B32 (not in PDF pseudocode, requires special vgpr_write handling)
|
||||
lines.append('''
|
||||
# Manually implemented lane instructions (require special vgpr_write handling)
|
||||
# V_WRITELANE_B32: Write scalar to specific lane's VGPR (not in PDF pseudocode)
|
||||
def _VOP3Op_V_WRITELANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
# VGPR[lane][VDST] = S0.b32 - writes s0 to specified lane's VGPR
|
||||
wr_lane = s1 & 0x1f # lane select (5 bits for wave32)
|
||||
return {'d0': d0, 'scc': scc, 'vgpr_write': (wr_lane, vdst_idx, s0 & 0xffffffff)}
|
||||
|
||||
def _VOP3Op_V_READLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
# D0 = VGPR[lane][SRC0] - reads from specified lane's VGPR
|
||||
rd_lane = s1 & 0x1f # lane select (5 bits for wave32)
|
||||
val = VGPR[rd_lane][src0_idx] if VGPR is not None and rd_lane < len(VGPR) and src0_idx < len(VGPR[rd_lane]) else s0
|
||||
return {'d0': val & 0xffffffff, 'scc': scc}
|
||||
|
||||
def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0):
|
||||
# D0 = VGPR[first_active_lane][SRC0] - reads from first active lane
|
||||
first_lane = 0
|
||||
for i in range(32):
|
||||
if exec_mask & (1 << i):
|
||||
first_lane = i
|
||||
break
|
||||
val = VGPR[first_lane][src0_idx] if VGPR is not None and first_lane < len(VGPR) and src0_idx < len(VGPR[first_lane]) else s0
|
||||
return {'d0': val & 0xffffffff, 'scc': scc}
|
||||
VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32
|
||||
''')
|
||||
|
||||
lines.append('COMPILED_FUNCTIONS = {')
|
||||
@@ -994,11 +979,6 @@ def _VOP1Op_V_READFIRSTLANE_B32(s0, s1, s2, d0, scc, vcc, lane, exec_mask, liter
|
||||
if by_cls.get(enum_cls): lines.append(f' {cls_name}: {cls_name}_FUNCTIONS,')
|
||||
lines.append('}')
|
||||
lines.append('')
|
||||
lines.append("# Add lane instructions to their respective dicts")
|
||||
lines.append("VOP3Op_FUNCTIONS[VOP3Op.V_WRITELANE_B32] = _VOP3Op_V_WRITELANE_B32")
|
||||
lines.append("VOP3Op_FUNCTIONS[VOP3Op.V_READLANE_B32] = _VOP3Op_V_READLANE_B32")
|
||||
lines.append("VOP1Op_FUNCTIONS[VOP1Op.V_READFIRSTLANE_B32] = _VOP1Op_V_READFIRSTLANE_B32")
|
||||
lines.append('')
|
||||
lines.append('def get_compiled_functions(): return COMPILED_FUNCTIONS')
|
||||
|
||||
Path(output_path).write_text('\n'.join(lines))
|
||||
|
||||
24
extra/assembly/rdna3/test/helpers.py
Normal file
24
extra/assembly/rdna3/test/helpers.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Shared test helpers for RDNA3 tests."""
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
class KernelInfo:
  """Plain record describing one kernel launch used by the RDNA3 tests.

  NOTE(review): `code` is presumably the kernel's machine code and the sizes its
  launch dimensions - confirm against the code that builds these records.
  """
  code: bytes
  global_size: tuple[int, int, int]
  local_size: tuple[int, int, int]
  # indices into shared buffer pool
  buf_idxs: list[int]
  # sizes for each buffer index
  buf_sizes: list[int]
|
||||
|
||||
# LLVM tool detection (shared across test files)
|
||||
def get_llvm_mc():
  """Return the name of the first llvm-mc executable that shutil.which can locate.

  Candidates are probed in this order: 'llvm-mc', 'llvm-mc-21', 'llvm-mc-20'.
  Raises FileNotFoundError when none of them is found.
  """
  candidates = ('llvm-mc', 'llvm-mc-21', 'llvm-mc-20')
  found = next((name for name in candidates if shutil.which(name)), None)
  if found is None:
    raise FileNotFoundError("llvm-mc not found")
  return found
|
||||
|
||||
def get_llvm_objdump():
  """Return the name of the first llvm-objdump executable that shutil.which can locate.

  Candidates are probed in this order: 'llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20'.
  Raises FileNotFoundError when none of them is found.
  """
  candidates = ('llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20')
  found = next((name for name in candidates if shutil.which(name)), None)
  if found is None:
    raise FileNotFoundError("llvm-objdump not found")
  return found
|
||||
@@ -7,19 +7,13 @@ from pathlib import Path
|
||||
# This allows generating AMD GPU kernels without requiring real hardware
|
||||
os.environ["AMD"] = "1"
|
||||
os.environ["MOCKGPU"] = "1"
|
||||
os.environ["PYTHON_REMU"] = "1"
|
||||
|
||||
from extra.assembly.rdna3.emu import WaveState, decode_program, step_wave, WAVE_SIZE
|
||||
from extra.assembly.rdna3.test.helpers import KernelInfo
|
||||
|
||||
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
|
||||
|
||||
@dataclass
class KernelInfo:
  """Record of one kernel to run through the emulators: its code bytes, launch sizes and buffer layout.

  NOTE(review): field meanings beyond the inline comments are inferred from names - confirm with the producer.
  """
  code: bytes
  global_size: tuple[int, int, int]
  local_size: tuple[int, int, int]
  # indices into shared buffer pool
  buf_idxs: list[int]
  # sizes for each buffer index
  buf_sizes: list[int]
|
||||
|
||||
def _is_f32_nan(bits: int) -> bool:
  """Return True when the 32-bit pattern `bits` encodes a float32 NaN.

  A NaN has every exponent bit set and at least one mantissa bit set; an
  all-ones exponent with a zero mantissa is an infinity and yields False.
  """
  exponent_mask, mantissa_mask = 0x7f800000, 0x007fffff
  exponent_all_ones = (bits & exponent_mask) == exponent_mask
  mantissa_nonzero = (bits & mantissa_mask) != 0
  return exponent_all_ones and mantissa_nonzero
|
||||
@@ -206,6 +200,7 @@ def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: t
|
||||
for i in range(128): python.set_sgpr(i, rust_after.sgpr[i])
|
||||
for lane in range(n_lanes):
|
||||
for i in range(256): python.set_vgpr(lane, i, rust_after.vgpr[lane][i])
|
||||
assert python.state is not None
|
||||
python.state.pc, python.state.scc, python.state.vcc, python.state.exec_mask = rust_after.pc, rust_after.scc, rust_after.vcc, rust_after.exec_mask
|
||||
|
||||
if rust_result == -1:
|
||||
@@ -347,7 +342,6 @@ def get_kernel_from_tinygrad(op_fn) -> tuple[bytes, tuple[int, int, int], tuple[
|
||||
k = kernels[-1]
|
||||
return k.code, k.global_size, k.local_size, k.buf_sizes
|
||||
|
||||
@unittest.skipUnless(REMU_PATH.exists(), "libremu.so not found")
|
||||
class TestTinygradKernels(unittest.TestCase):
|
||||
"""Compare emulators on real tinygrad-compiled kernels."""
|
||||
|
||||
|
||||
@@ -3,16 +3,7 @@
|
||||
import unittest, re, io, sys, subprocess
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.asm import waitcnt, asm
|
||||
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
|
||||
|
||||
def get_amd_toolchain():
  """Report whether the AMD toolchain is usable.

  Probes by importing tinygrad's HIPCompiler and compiling a minimal gfx1100
  program; any exception raised by the import or the compile counts as
  "not available" and yields False.
  """
  try:
    from tinygrad.runtime.support.compiler_amd import HIPCompiler
    HIPCompiler("gfx1100").compile(".text\ns_endpgm")
  except Exception:
    # deliberate best-effort probe: a failure of any kind means the toolchain cannot be used
    return False
  return True
|
||||
from extra.assembly.rdna3.test.helpers import get_llvm_mc
|
||||
|
||||
def disassemble(lib: bytes, arch: str = "gfx1100") -> str:
|
||||
"""Disassemble ELF binary using tinygrad's compiler, return raw output."""
|
||||
@@ -48,7 +39,6 @@ def assemble_and_disassemble(instructions: list, arch: str = "gfx1100") -> list[
|
||||
lib = HIPCompiler(arch).compile(asm_src)
|
||||
return parse_disassembly(disassemble(lib, arch))
|
||||
|
||||
@unittest.skipUnless(get_amd_toolchain(), "AMD toolchain not available")
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Test our assembler output matches LLVM disassembly."""
|
||||
|
||||
@@ -158,7 +148,6 @@ class TestIntegration(unittest.TestCase):
|
||||
return
|
||||
self.fail("Could not find s_mov_b32 in disassembly")
|
||||
|
||||
@unittest.skipUnless(get_amd_toolchain(), "AMD toolchain not available")
|
||||
class TestAsm(unittest.TestCase):
|
||||
"""Test asm() string parsing."""
|
||||
|
||||
@@ -214,7 +203,7 @@ class TestAsm(unittest.TestCase):
|
||||
def test_asm_vop3_modifiers(self):
|
||||
"""Test asm() with VOP3 modifiers (neg, abs, clamp)."""
|
||||
def get_llvm_encoding(instr: str) -> str:
|
||||
result = subprocess.run([_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'],
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-show-encoding'],
|
||||
input=instr, capture_output=True, text=True)
|
||||
if m := re.search(r'encoding:\s*\[(.*?)\]', result.stdout):
|
||||
return m.group(1).replace('0x','').replace(',','').replace(' ','')
|
||||
@@ -232,7 +221,6 @@ class TestAsm(unittest.TestCase):
|
||||
llvm_hex = get_llvm_encoding(t)
|
||||
self.assertEqual(our_hex, llvm_hex, f"mismatch for: {t}")
|
||||
|
||||
@unittest.skipUnless(get_amd_toolchain(), "AMD toolchain not available")
|
||||
class TestTinygradIntegration(unittest.TestCase):
|
||||
"""Test that we can parse disassembled tinygrad kernels."""
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import unittest, re, subprocess
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.asm import asm
|
||||
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
|
||||
from extra.assembly.rdna3.test.helpers import get_llvm_mc
|
||||
|
||||
LLVM_BASE = "https://raw.githubusercontent.com/llvm/llvm-project/main/llvm/test/MC/AMDGPU"
|
||||
|
||||
@@ -83,7 +83,7 @@ def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
if not instrs: return []
|
||||
asm_text = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=asm_text, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings from output
|
||||
@@ -132,7 +132,7 @@ def _make_disasm_test(name):
|
||||
undocumented = {'smem': {34, 35}, 'sopk': {22, 23}, 'sopp': {8, 58, 59}}
|
||||
|
||||
# First pass: decode all instructions and collect disasm strings
|
||||
to_test = [] # list of (asm_text, data, disasm_str)
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = [] # (asm_text, data, disasm_str, error)
|
||||
skipped = 0
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
if len(data) > fmt_cls._size(): continue
|
||||
@@ -172,14 +172,15 @@ def _make_disasm_test(name):
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
|
||||
# Match results back
|
||||
passed, failed, failures = 0, 0, []
|
||||
passed, failed = 0, 0
|
||||
failures: list[str] = []
|
||||
for idx, (asm_text, data, disasm_str, error) in enumerate(to_test):
|
||||
if error:
|
||||
failed += 1; failures.append(f"{error} for {data.hex()}")
|
||||
elif disasm_str is not None and idx in llvm_map:
|
||||
llvm_bytes = llvm_map[idx]
|
||||
if llvm_bytes == data: passed += 1
|
||||
else: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
if llvm_bytes is not None and llvm_bytes == data: passed += 1
|
||||
elif llvm_bytes is not None: failed += 1; failures.append(f"'{disasm_str}': expected={data.hex()} got={llvm_bytes.hex()}")
|
||||
|
||||
print(f"{name.upper()} disasm: {passed} passed, {failed} failed" + (f", {skipped} skipped" if skipped else ""))
|
||||
if failures[:10]: print(" " + "\n ".join(failures[:10]))
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
import unittest, subprocess
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.test.test_roundtrip import _get_llvm_mc
|
||||
from extra.assembly.rdna3.test.helpers import get_llvm_mc
|
||||
|
||||
def llvm_assemble(asm: str) -> bytes:
|
||||
"""Assemble using llvm-mc and return bytes."""
|
||||
result = subprocess.run(
|
||||
[_get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
|
||||
[get_llvm_mc(), "-triple=amdgcn", "-mcpu=gfx1100", "-show-encoding"],
|
||||
input=asm, capture_output=True, text=True
|
||||
)
|
||||
out = b''
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re, subprocess, shutil
|
||||
import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.rdna3.autogen import *
|
||||
from extra.assembly.rdna3.lib import Inst
|
||||
from extra.assembly.rdna3.asm import asm
|
||||
|
||||
def _get_llvm_mc():
  """Return the first of 'llvm-mc', 'llvm-mc-21', 'llvm-mc-20' that shutil.which finds.

  Raises FileNotFoundError when none of them is found.
  """
  located = next((exe for exe in ('llvm-mc', 'llvm-mc-21', 'llvm-mc-20') if shutil.which(exe)), None)
  if located is None:
    raise FileNotFoundError("llvm-mc not found")
  return located
|
||||
|
||||
def _get_llvm_objdump():
  """Return the first of 'llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20' that shutil.which finds.

  Raises FileNotFoundError when none of them is found.
  """
  located = next((exe for exe in ('llvm-objdump', 'llvm-objdump-21', 'llvm-objdump-20') if shutil.which(exe)), None)
  if located is None:
    raise FileNotFoundError("llvm-objdump not found")
  return located
|
||||
from extra.assembly.rdna3.test.helpers import get_llvm_mc, get_llvm_objdump
|
||||
|
||||
# Instruction format detection based on encoding bits
|
||||
def detect_format(data: bytes) -> type[Inst] | None:
|
||||
@@ -78,7 +69,7 @@ def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
|
||||
def compile_asm(instr: str, compiler=None) -> bytes:
|
||||
"""Compile a single instruction with llvm-mc and return the machine code bytes."""
|
||||
llvm_mc = _get_llvm_mc()
|
||||
llvm_mc = get_llvm_mc()
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=f".text\n{instr}\n", capture_output=True, text=True)
|
||||
@@ -95,7 +86,7 @@ def compile_asm(instr: str, compiler=None) -> bytes:
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
llvm_mc = _get_llvm_mc()
|
||||
llvm_mc = get_llvm_mc()
|
||||
src = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
@@ -112,7 +103,7 @@ def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
|
||||
return encodings
|
||||
|
||||
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str | None]:
|
||||
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
|
||||
"""Compile instructions with LLVM and get LLVM's disassembly."""
|
||||
import tempfile, os
|
||||
if not instrs: return []
|
||||
@@ -124,14 +115,14 @@ def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str | None]:
|
||||
obj_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[_get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
# Disassemble with llvm-objdump
|
||||
result = subprocess.run([_get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
|
||||
# Parse disassembly output
|
||||
results = []
|
||||
results: list[str] = []
|
||||
for line in result.stdout.splitlines():
|
||||
if '//' not in line: continue
|
||||
instr = line.split('//')[0].strip()
|
||||
@@ -156,7 +147,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
|
||||
# First pass: decode all instructions and collect info
|
||||
decoded_instrs = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
|
||||
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
|
||||
for ki, kernel in enumerate(kernels):
|
||||
offset = 0
|
||||
while offset < len(kernel.code):
|
||||
@@ -178,7 +169,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
reencoded = decoded.to_bytes()
|
||||
our_disasm = decoded.disasm()
|
||||
decode_ok = reencoded == orig_bytes
|
||||
decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
|
||||
except Exception as e:
|
||||
decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e)))
|
||||
@@ -187,8 +178,8 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
offset += size
|
||||
|
||||
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
|
||||
asm_test_instrs = [] # (idx, our_disasm) for asm test
|
||||
disasm_test_instrs = [] # (idx, our_disasm) for disasm comparison test
|
||||
asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test
|
||||
disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test
|
||||
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
if our_disasm is None: continue
|
||||
@@ -209,7 +200,9 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
decode_passed, decode_failed, decode_skipped = 0, 0, 0
|
||||
asm_passed, asm_failed, asm_skipped = 0, 0, 0
|
||||
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
|
||||
decode_failures, asm_failures, disasm_failures = [], [], []
|
||||
decode_failures: list[str] = []
|
||||
asm_failures: list[str] = []
|
||||
disasm_failures: list[str] = []
|
||||
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
# Decode test
|
||||
@@ -226,18 +219,15 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
asm_skipped += 1
|
||||
elif idx in asm_llvm_map:
|
||||
llvm_bytes = asm_llvm_map[idx]
|
||||
if llvm_bytes is None:
|
||||
try:
|
||||
our_bytes = asm(our_disasm).to_bytes()
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
|
||||
asm_passed += 1
|
||||
else:
|
||||
asm_failed += 1
|
||||
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
|
||||
except Exception:
|
||||
asm_skipped += 1
|
||||
else:
|
||||
try:
|
||||
our_bytes = asm(our_disasm).to_bytes()
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
|
||||
asm_passed += 1
|
||||
else:
|
||||
asm_failed += 1
|
||||
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
|
||||
except Exception:
|
||||
asm_skipped += 1
|
||||
else:
|
||||
asm_skipped += 1
|
||||
|
||||
@@ -246,9 +236,7 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
disasm_skipped += 1
|
||||
elif idx in disasm_llvm_map:
|
||||
llvm_disasm = disasm_llvm_map[idx]
|
||||
if llvm_disasm is None:
|
||||
disasm_skipped += 1
|
||||
elif our_disasm == llvm_disasm:
|
||||
if our_disasm == llvm_disasm:
|
||||
disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
|
||||
Reference in New Issue
Block a user