Files
tinygrad/extra/assembly/amd/test/test_compare_emulators.py
George Hotz dfb813b760 assembly/amd: add pcode ds ops (#13939)
* assembly/amd: add pcode ds ops

* refactors

* fix ds op

* update autogen

* fix flat bug

* more tests

* fix emu test

* that's a hack

* generic

* fix all tests

* two tests

* fix test failure

* better

* remove __all__
2026-01-01 16:24:13 -05:00

403 lines
19 KiB
Python

# Test to compare Python and Rust RDNA3 emulators by running real tinygrad kernels
import unittest, ctypes, os
from dataclasses import dataclass
from pathlib import Path
# Set environment before any tinygrad imports to use MOCKGPU
# This allows generating AMD GPU kernels without requiring real hardware
os.environ["AMD"] = "1"
os.environ["MOCKGPU"] = "1"
os.environ["PYTHON_REMU"] = "1"
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges, LDSMem
from extra.assembly.amd.test.helpers import KernelInfo
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
def _is_f32_nan(bits: int) -> bool:
"""Check if 32-bit value is a NaN (exponent all 1s, mantissa non-zero)."""
return (bits & 0x7f800000) == 0x7f800000 and (bits & 0x007fffff) != 0
def _vals_equal(a: int, b: int) -> bool:
"""Compare two 32-bit values, treating all NaN bit patterns as equal."""
if a == b: return True
return _is_f32_nan(a) and _is_f32_nan(b)
@dataclass
class StateSnapshot:
  """Copy of one wave's register state, used to compare the two emulators."""
  pc: int
  scc: int
  vcc: int
  exec_mask: int
  sgpr: list[int]
  vgpr: list[list[int]]  # indexed as vgpr[lane][register]

  def diff(self, other: 'StateSnapshot', n_lanes: int, arrow: str = " vs ") -> list[str]:
    """Describe every field in which `self` and `other` differ.

    Only the first `n_lanes` lanes of the vector registers are compared. `arrow` is the separator
    placed between this snapshot's value and the other one's in each message. Register values that
    are both float32 NaNs are treated as equal.
    """
    found: list[str] = []

    # (label, this value, other value, print as 8-digit hex?)
    scalars = (
      ("pc", self.pc, other.pc, False),
      ("scc", self.scc, other.scc, False),
      ("vcc", self.vcc, other.vcc, True),
      ("exec", self.exec_mask, other.exec_mask, True),
    )
    for label, mine, theirs, as_hex in scalars:
      if mine == theirs: continue
      if as_hex: found.append(f"{label}: 0x{mine:08x}{arrow}0x{theirs:08x}")
      else: found.append(f"{label}: {mine}{arrow}{theirs}")

    # VCC_LO/HI (106/107) and EXEC_LO/HI (126/127) alias vcc/exec_mask, which were already compared above
    aliased = (106, 107, 126, 127)
    found.extend(f"sgpr[{reg}]: 0x{mine:08x}{arrow}0x{theirs:08x}"
                 for reg, (mine, theirs) in enumerate(zip(self.sgpr, other.sgpr))
                 if reg not in aliased and not _vals_equal(mine, theirs))

    for lane in range(n_lanes):
      for reg, (mine, theirs) in enumerate(zip(self.vgpr[lane], other.vgpr[lane])):
        if _vals_equal(mine, theirs): continue
        found.append(f"vgpr[{lane}][{reg}]: 0x{mine:08x}{arrow}0x{theirs:08x}")
    return found
class CStateSnapshot(ctypes.Structure):
  """ctypes structure that `wave_get_snapshot` fills in with the Rust emulator's wave state."""
  # NOTE(review): field order and sizes presumably have to match the struct on the Rust side - confirm before changing
  _fields_ = [
    ("pc", ctypes.c_uint32),
    ("scc", ctypes.c_uint32),
    ("vcc", ctypes.c_uint32),
    ("exec_mask", ctypes.c_uint32),
    ("sgpr", ctypes.c_uint32 * 128),
    ("vgpr", (ctypes.c_uint32 * 256) * 32),  # 32 rows of 256 registers, read back as vgpr[lane][register]
  ]

  def to_snapshot(self) -> StateSnapshot:
    """Copy the C-side fields into a plain-Python StateSnapshot."""
    per_lane = [list(row) for row in self.vgpr]
    return StateSnapshot(pc=self.pc, scc=self.scc, vcc=self.vcc, exec_mask=self.exec_mask,
                         sgpr=list(self.sgpr), vgpr=per_lane)
class RustEmulator:
  """Thin ctypes wrapper around the `wave_*` functions exported by the shared library at REMU_PATH."""

  def __init__(self):
    """Load the shared library and declare the C signatures of the functions used below."""
    u32, handle = ctypes.c_uint32, ctypes.c_void_p
    lib = ctypes.CDLL(str(REMU_PATH))
    # Only wave_create and wave_step get an explicit restype; the others keep the ctypes default.
    lib.wave_create.argtypes = [handle, u32, u32]
    lib.wave_create.restype = handle
    lib.wave_step.argtypes = [handle]
    lib.wave_step.restype = ctypes.c_int32
    lib.wave_get_snapshot.argtypes = [handle, ctypes.POINTER(CStateSnapshot)]
    lib.wave_set_sgpr.argtypes = [handle, u32, u32]
    lib.wave_set_vgpr.argtypes = [handle, u32, u32, u32]
    lib.wave_init_lds.argtypes = [handle, u32]
    lib.wave_free.argtypes = [handle]
    self.lib = lib
    self.ctx = None

  def create(self, kernel: bytes, n_lanes: int):
    """Create a wave context for `kernel` with `n_lanes` lanes and remember its handle."""
    # The copy is stored on the instance; presumably this keeps the memory alive for the native side - confirm.
    self._kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
    self.ctx = self.lib.wave_create(ctypes.addressof(self._kernel_buf), len(kernel), n_lanes)

  def step(self) -> int:
    """Execute one step of the wave and return the library's status code."""
    return self.lib.wave_step(self.ctx)

  def set_sgpr(self, idx: int, val: int):
    """Write `val` into scalar register `idx`."""
    self.lib.wave_set_sgpr(self.ctx, idx, val)

  def set_vgpr(self, lane: int, idx: int, val: int):
    """Write `val` into vector register `idx` of `lane`."""
    self.lib.wave_set_vgpr(self.ctx, lane, idx, val)

  def init_lds(self, size: int):
    """Ask the library to initialise LDS with the given size."""
    self.lib.wave_init_lds(self.ctx, size)

  def get_snapshot(self) -> StateSnapshot:
    """Fetch the current wave state from the library as a StateSnapshot."""
    raw = CStateSnapshot()
    self.lib.wave_get_snapshot(self.ctx, ctypes.byref(raw))
    return raw.to_snapshot()

  def free(self):
    """Release the wave context if one exists; calling it again is a no-op."""
    if not self.ctx: return
    self.lib.wave_free(self.ctx)
    self.ctx = None
class PythonEmulator:
  """Adapter giving the Python emulator (WaveState / step_wave) the same methods as RustEmulator."""

  def __init__(self):
    """Start empty; `create` must be called before any other method."""
    self.state: WaveState | None = None
    self.program: dict | None = None
    self.lds: LDSMem | None = None
    self.n_lanes = 0

  def create(self, kernel: bytes, n_lanes: int):
    """Decode `kernel` and set up a fresh wave with the low `n_lanes` bits of the exec mask set."""
    self.n_lanes = n_lanes
    self.program = decode_program(kernel)
    wave = WaveState()
    wave.exec_mask = (1 << n_lanes) - 1
    self.state = wave
    self.lds = LDSMem(bytearray(65536))

  def step(self) -> int:
    """Run one step_wave call and return its result."""
    assert self.program is not None and self.state is not None and self.lds is not None
    return step_wave(self.program, self.state, self.lds, self.n_lanes)

  def set_sgpr(self, idx: int, val: int):
    """Store the low 32 bits of `val` in scalar register `idx`."""
    assert self.state is not None
    self.state.sgpr[idx] = val & 0xffffffff

  def set_vgpr(self, lane: int, idx: int, val: int):
    """Store the low 32 bits of `val` in vector register `idx` of `lane`."""
    assert self.state is not None
    self.state.vgpr[lane][idx] = val & 0xffffffff

  def get_snapshot(self) -> StateSnapshot:
    """Copy the current wave state into a StateSnapshot, truncating vcc and exec_mask to 32 bits."""
    assert self.state is not None
    wave = self.state
    per_lane = [list(wave.vgpr[lane]) for lane in range(WAVE_SIZE)]
    return StateSnapshot(pc=wave.pc, scc=wave.scc, vcc=wave.vcc & 0xffffffff,
                         exec_mask=wave.exec_mask & 0xffffffff, sgpr=list(wave.sgpr), vgpr=per_lane)
def _format_trace(trace, n_lanes: int) -> str:
  """Render the recent-instruction trace, listing what each emulator changed between consecutive entries."""
  entries = list(trace)
  lines = []
  for idx, (s, pc, d, rb, pb) in enumerate(entries):
    lines.append(f" step {s}: PC={pc:3d} {d}")
    # The newest entry holds the very snapshots being reported on, so there is no later state to diff it against.
    if idx == len(entries) - 1: continue
    next_rb, next_pb = entries[idx + 1][3:5]
    rust_diffs = rb.diff(next_rb, n_lanes, "->")
    python_diffs = pb.diff(next_pb, n_lanes, "->")
    if rust_diffs: lines.append(f" rust: {', '.join(rust_diffs[:5])}")
    if python_diffs: lines.append(f" python: {', '.join(python_diffs[:5])}")
    elif rust_diffs: lines.append(" python: (no changes)")
  return "\n".join(lines)

def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int],
                      program, max_steps: int, debug: bool, trace_len: int, kernel_idx: int = 0,
                      max_workgroups: int = 8) -> tuple[bool, str, int]:
  """Run a single kernel through both emulators in lockstep. Returns (success, message, total_steps).

  For each workgroup id in `global_size` (at most `max_workgroups` of them) a fresh Rust and Python
  emulator pair is created, s0/s1 receive the low/high halves of `args_ptr` and s13..s15 the
  workgroup ids. Before every instruction the two states are compared; the first mismatch, a
  mismatch in step return codes, an unexpected negative return code or hitting `max_steps` ends
  the run with success=False. `program` maps a PC to a decoded instruction and is only used for
  disassembly in messages. `trace_len` bounds how many recent instructions are kept for reports.

  Raises unittest.SkipTest when the Rust emulator reports an instruction as unsupported.
  """
  from collections import deque

  # Instructions with known Rust emulator bugs - sync Python to Rust after execution
  # v_div_scale/v_div_fixup: Rust has different VCC handling
  # v_cvt_f16_f32: Rust clears high 16 bits, but hardware (and Python) preserves them
  sync_insts = ('v_div_scale_f32', 'v_div_scale_f64', 'v_div_fixup_f32', 'v_div_fixup_f64', 'v_cvt_f16_f32')

  gx, gy, gz = global_size
  total_steps = 0
  wg_count = 0
  for gidz in range(gz):
    for gidy in range(gy):
      for gidx in range(gx):
        if wg_count >= max_workgroups: return True, f"Completed {wg_count} workgroups (limit reached)", total_steps
        wg_count += 1
        where = f"K{kernel_idx} WG({gidx},{gidy},{gidz})"

        rust = RustEmulator()
        python = PythonEmulator()
        # Setup happens inside the try so the Rust wave context is released even if setup itself fails.
        try:
          rust.create(kernel, n_lanes)
          python.create(kernel, n_lanes)

          # Initialize LDS (64KB, standard size for AMD GPUs)
          rust.init_lds(65536)

          for emu in (rust, python):
            emu.set_sgpr(0, args_ptr & 0xffffffff)
            emu.set_sgpr(1, (args_ptr >> 32) & 0xffffffff)
            emu.set_sgpr(13, gidx)
            emu.set_sgpr(14, gidy)
            emu.set_sgpr(15, gidz)

          step = 0
          # Entries are (step, pc, disassembly, rust snapshot, python snapshot); old ones fall off automatically.
          trace = deque(maxlen=max(trace_len, 0))
          while step < max_steps:
            rust_before = rust.get_snapshot()
            python_before = python.get_snapshot()

            inst = program.get(python_before.pc)
            inst_str = inst.disasm() if inst else f"unknown at PC={python_before.pc}"
            trace.append((step, python_before.pc, inst_str, rust_before, python_before))

            if debug: print(f"{where} Step {step}: PC={python_before.pc}, inst={inst_str}")

            diffs = rust_before.diff(python_before, n_lanes)
            if diffs:
              trace_str = _format_trace(trace, n_lanes)
              return False, f"{where} Step {step} before inst '{inst_str}': states differ (rust vs python):\n " + "\n ".join(diffs[:10]) + f"\n Recent instructions:\n{trace_str}", total_steps

            rust_result = rust.step()
            python_result = python.step()

            if rust_result != python_result:
              # Rust returns 1 for unsupported instructions - skip test
              if rust_result == 1 and python_result == 0:
                raise unittest.SkipTest(f"Rust emulator doesn't support instruction: {inst_str}")
              trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
              return False, f"{where} Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps

            # Sync Python state to Rust after instructions with known Rust emulator differences
            if any(name in inst_str for name in sync_insts):
              rust_after = rust.get_snapshot()
              for i in range(128): python.set_sgpr(i, rust_after.sgpr[i])
              for lane in range(n_lanes):
                for i in range(256): python.set_vgpr(lane, i, rust_after.vgpr[lane][i])
              assert python.state is not None
              python.state.pc, python.state.scc, python.state.vcc, python.state.exec_mask = rust_after.pc, rust_after.scc, rust_after.vcc, rust_after.exec_mask

            # Both -1 and 1 end this workgroup and count its steps.
            if rust_result in (-1, 1):
              total_steps += step + 1
              break
            # Any other negative code except -2 is reported as an error.
            if rust_result < 0 and rust_result != -2:
              return False, f"{where} Step {step}: error code {rust_result}", total_steps
            step += 1
          else:
            return False, f"{where} Max steps ({max_steps}) reached", total_steps
        finally:
          rust.free()

  return True, f"Completed {gx*gy*gz} workgroups", total_steps
def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int, int], max_steps: int = 1000,
                                   debug: bool = False, trace_len: int = 10, buf_data: dict[int, bytes] | None = None) -> tuple[bool, str]:
  """Run all kernels through both emulators with a shared buffer pool.

  `buf_pool` maps a buffer id to its size in bytes and `buf_data` optionally maps a buffer id to its
  initial contents. Each kernel gets an argument array holding the addresses of the buffers named
  in its `buf_idxs`. Kernels run in order and the first failing one stops the run.
  Returns (success, message).
  """
  if buf_data is None: buf_data = {}

  # Allocate shared buffer pool with padding for over-reads (GPU loads up to 16 bytes at once)
  buf_id_to_ptr: dict[int, int] = {}
  buffers = []
  for buf_id, size in buf_pool.items():
    padded_size = ((size + 15) // 16) * 16 + 16  # round up to 16 bytes + extra padding
    # Initialize with data from COPY if available: truncate to the buffer, then zero-fill the remainder
    init_data = bytes(buf_data.get(buf_id, b''))[:padded_size].ljust(padded_size, b'\x00')
    buf = (ctypes.c_uint8 * padded_size).from_buffer_copy(init_data)
    buffers.append((buf, padded_size))  # holding the ctypes object keeps its memory alive
    buf_id_to_ptr[buf_id] = ctypes.addressof(buf)

  # Set up valid memory ranges
  ranges = {(ctypes.addressof(b), size) for b, size in buffers}

  total_steps = 0
  for ki, kernel in enumerate(kernels):
    # Create args array for this kernel's buffers
    args = (ctypes.c_uint64 * len(kernel.buf_idxs))(*[buf_id_to_ptr[bid] for bid in kernel.buf_idxs])
    args_ptr = ctypes.addressof(args)

    # Update valid ranges to include this args array
    set_valid_mem_ranges(ranges | {(args_ptr, ctypes.sizeof(args))})

    program = decode_program(kernel.code)
    n_lanes = kernel.local_size[0] * kernel.local_size[1] * kernel.local_size[2]
    # At most 32 lanes are passed on to the emulators.
    ok, msg, steps = run_single_kernel(
      kernel.code, min(n_lanes, 32), args_ptr, kernel.global_size,
      program, max_steps, debug, trace_len, ki
    )
    total_steps += steps
    if not ok:
      return False, msg

  return True, f"Completed {len(kernels)} kernels, {total_steps} total steps"
def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list, max_steps: int = 1000, debug: bool = False,
                                  global_size: tuple[int, int, int] = (1, 1, 1), trace_len: int = 10) -> tuple[bool, str]:
  """Run both emulators with memory set up for tinygrad kernels, executing all workgroups. Legacy wrapper.

  One zero-filled buffer is allocated per entry of `buf_sizes`, and the kernel's argument array
  holds their addresses in the same order. Returns (success, message).
  """
  # Allocate buffers (ctypes arrays are zero-initialised, so no explicit fill is needed)
  buffers = [(ctypes.c_uint8 * size)() for size in buf_sizes]

  # Create args array with buffer pointers
  args = (ctypes.c_uint64 * len(buffers))(*[ctypes.addressof(b) for b in buffers])
  args_ptr = ctypes.addressof(args)

  # Set up valid memory ranges for Python emulator
  ranges = {(ctypes.addressof(b), len(b)) for b in buffers}
  ranges.add((args_ptr, ctypes.sizeof(args)))
  set_valid_mem_ranges(ranges)

  program = decode_program(kernel)
  ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, program, max_steps, debug, trace_len)
  return ok, msg
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]:
  """Schedule the tensor built by `op_fn(Tensor)` and collect its kernels and buffers.

  Returns (kernels, buf_pool, buf_data): one KernelInfo per '.text' section of every compiled SINK
  item, a map from buffer id to size in bytes, and a map from buffer id to initial bytes gathered
  from COPY items. Buffers are identified by the `id()` of the buffer object.
  Raises RuntimeError when no kernel was found.
  """
  from tinygrad import Tensor
  from tinygrad.runtime.support.elf import elf_loader

  schedule = op_fn(Tensor).schedule()

  kernels = []
  buf_pool: dict[int, int] = {}  # buffer id -> size
  buf_data: dict[int, bytes] = {}  # buffer id -> initial data from COPY

  for item in schedule:
    lowered = item.lower()
    op_name = item.ast.op.name

    if op_name == 'COPY':
      # Handle COPY: extract source data to initialize destination buffer
      if len(lowered.bufs) < 2: continue
      dst_buf, src_buf = lowered.bufs[0], lowered.bufs[1]
      dst_id = id(dst_buf)
      if dst_id not in buf_pool:
        buf_pool[dst_id] = dst_buf.nbytes
      # Get source data if it's from numpy/CPU
      base = getattr(src_buf, 'base', None)
      if base is not None and hasattr(base, '_buf'):
        buf_data[dst_id] = bytes(base._buf)
      continue

    if op_name != 'SINK': continue
    if not (lowered.prg and lowered.prg.p.lib): continue

    _, sections, _ = elf_loader(bytes(lowered.prg.p.lib))
    for sec in sections:
      if sec.name != '.text': continue
      buf_idxs = []
      buf_sizes = []
      for b in lowered.bufs:
        buf_id = id(b)
        if buf_id not in buf_pool:
          buf_pool[buf_id] = b.nbytes
        buf_idxs.append(buf_id)
        buf_sizes.append(b.nbytes)
      kernels.append(KernelInfo(
        code=bytes(sec.content),
        global_size=tuple(lowered.prg.p.global_size),
        local_size=tuple(lowered.prg.p.local_size),
        buf_idxs=buf_idxs,
        buf_sizes=buf_sizes
      ))

  if not kernels: raise RuntimeError("No kernel found")
  return kernels, buf_pool, buf_data
def get_kernel_from_tinygrad(op_fn) -> tuple[bytes, tuple[int, int, int], tuple[int, int, int], list]:
  """Compile a tinygrad operation and return the last kernel as (code, global_size, local_size, buf_sizes). Legacy wrapper."""
  # Only the kernel list is needed here; the buffer pool and initial data are discarded.
  last = get_kernels_from_tinygrad(op_fn)[0][-1]
  return last.code, last.global_size, last.local_size, last.buf_sizes
class TestTinygradKernels(unittest.TestCase):
  """Compare emulators on real tinygrad-compiled kernels."""

  def _test_kernel(self, op_fn, max_steps=10000):
    """Compile `op_fn`, run every resulting kernel through both emulators and fail with the report on mismatch."""
    kernels, buf_pool, buf_data = get_kernels_from_tinygrad(op_fn)
    passed, report = compare_emulators_multi_kernel(kernels, buf_pool, max_steps=max_steps, buf_data=buf_data)
    self.assertTrue(passed, report)

  # Basic ops - consolidated tests covering key instruction patterns
  def test_unary_ops(self):
    self._test_kernel(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu().exp().log().sqrt().reciprocal())

  def test_binary_ops(self):
    self._test_kernel(lambda T: (T([1.0, 2.0]) + T([3.0, 4.0])) * T([0.5, 0.5]) - T([1.0, 1.0]))

  def test_trig(self):
    self._test_kernel(lambda T: T([0.1, 1.0, 3.14, -1.0]*8).sin() + T([0.1, 1.0, 3.14, -1.0]*8).cos())

  def test_compare(self):
    self._test_kernel(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))

  def test_bitwise(self):
    self._test_kernel(lambda T: (T([0xF0, 0x0F, 0xFF]*11).int() & T([0x0F, 0x0F, 0x00]*11).int()) | T([1]*33).int())

  def test_int_ops(self):
    self._test_kernel(lambda T: ((T.empty(64).int() + T.empty(64).int()) * T.empty(64).int()).float())

  # Reductions
  def test_reduce(self):
    self._test_kernel(lambda T: T.empty(64).sum() + T.empty(64).max())

  def test_argmax(self):
    self._test_kernel(lambda T: T.empty(64).argmax())

  # Matmul
  def test_gemm(self):
    self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000)

  @unittest.skip("Rust emulator crashes on this kernel (assertion failure in thread.rs)")
  def test_gemm_fp16(self):
    self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000)

  # Complex ops
  def test_softmax(self):
    self._test_kernel(lambda T: T.empty(16).softmax())

  def test_layernorm(self):
    self._test_kernel(lambda T: T.empty(8, 8).layernorm())

  # Memory patterns
  def test_memory(self):
    self._test_kernel(lambda T: T.empty(4, 4).permute(1, 0).contiguous() + T.empty(4, 1).expand(4, 4))

  # Cast ops
  def test_cast(self):
    self._test_kernel(lambda T: T.empty(32).half().float() + T.empty(32).int().float())

  # Pooling - regression for VCC wave32 mode
  def test_pool2d(self):
    self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))

  # Convolution
  def test_conv2d(self):
    self._test_kernel(lambda T: T.empty(1, 2, 8, 8).conv2d(T.empty(2, 2, 3, 3)), max_steps=50000)

  # Regression tests
  def test_topk(self):
    self._test_kernel(lambda T: T.empty(64).topk(3)[0])

  def test_interpolate(self):
    self._test_kernel(lambda T: T.empty(1,2,16,16).relu().cast('uint8').interpolate((8,8), mode="linear"))

  def test_index_int64(self):
    from tinygrad import dtypes
    self._test_kernel(lambda T: T.empty(4, 4)[T.arange(4).cast(dtypes.int64), :])

  def test_gelu(self):
    self._test_kernel(lambda T: T.empty(32, 32).gelu())

  def test_cross_entropy(self):
    import numpy as np
    # Fixed seed so the generated inputs are the same on every run
    np.random.seed(0)
    classes = np.random.randint(0, 10, (16,), dtype=np.int32).tolist()
    x_np = np.random.randn(16, 10).astype(np.float32)
    self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0)))

  def test_isinf(self):
    self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
# Run the test suite when this file is executed directly
if __name__ == "__main__":
  unittest.main()