# Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00).
# Commit: add sqtt support to the emulator; more sqtt; cleanup; cleanups; simpler tests; some decent tests; test branch.
# File info: Python, 109 lines, 4.8 KiB.
#!/usr/bin/env python3
"""Tests for the SQTT encoder: verify that the emulator produces correct SQTT traces for known kernels.

Run with: AMD=1 MOCKGPU=1 python -m pytest test/amd/test_sqtt_encoder.py -v
"""
|
|
import ctypes, unittest
|
|
from tinygrad.helpers import Context
|
|
from tinygrad.renderer.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp
|
|
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
|
|
|
def _run_kernel(instructions: list, lx=1, ly=1, lz=1, gx=1, gy=1, gz=1, args_ptr=0) -> bytes:
  """Assemble `instructions`, run them once on the mock-GPU emulator under PROFILE=1 and return the captured SQTT blob.

  lx/ly/lz and gx/gy/gz are forwarded to run_asm together with args_ptr (the names suggest workgroup vs. grid sizes;
  lx=64 yields two 32-wide waves). Raises AssertionError unless exactly one SQTT trace was recorded.
  """
  from test.mockgpu.amd.emu import run_asm, sqtt_traces
  program = b''.join(ins.to_bytes() for ins in instructions)
  # run_asm only receives the raw address, so this ctypes buffer must stay referenced until the call returns
  code_buf = (ctypes.c_char * len(program))(*program)
  # drop traces left over from earlier runs so the count check below is exact
  sqtt_traces.clear()
  with Context(PROFILE=1):
    run_asm(ctypes.addressof(code_buf), len(program), gx, gy, gz, lx, ly, lz, args_ptr)
  n_traces = len(sqtt_traces)
  assert n_traces == 1, f"expected 1 trace, got {n_traces}"
  return sqtt_traces.pop()
|
|
|
|
class TestSQTTEncoder(unittest.TestCase):
  """End-to-end checks of the SQTT packet stream the emulator produces for small hand-assembled kernels."""

  @staticmethod
  def _count(packets, pkt_type) -> int:
    """Return how many items of the iterable `packets` are instances of `pkt_type`."""
    return sum(1 for p in packets if isinstance(p, pkt_type))

  def test_simple_salu(self):
    """A simple s_mov + s_endpgm kernel emits a single INST packet with op SALU."""
    blob = _run_kernel([s_mov_b32(s[0], 42), s_endpgm()])
    inst_pkts = [p for p in decode(blob) if isinstance(p, INST)]
    self.assertEqual(len(inst_pkts), 1)
    self.assertEqual(inst_pkts[0].op, InstOp.SALU)

  def test_valu_emits_valuinst(self):
    """Regular VALU ops emit VALUINST packets (two ops -> two packets) and no INST packets."""
    blob = _run_kernel([v_mov_b32_e32(v[0], 0), v_add_f32_e32(v[1], v[0], v[0]), s_endpgm()])
    packets = list(decode(blob))
    self.assertEqual(self._count(packets, VALUINST), 2)
    # no INST packets for regular VALU
    self.assertEqual(self._count(packets, INST), 0)

  def test_waitcnt_emits_immediate(self):
    """s_nop and s_waitcnt each emit an IMMEDIATE packet."""
    blob = _run_kernel([s_nop(simm16=0), s_waitcnt(simm16=0), s_endpgm()])
    self.assertEqual(self._count(decode(blob), IMMEDIATE), 2)  # s_nop + s_waitcnt

  def test_endpgm_skipped(self):
    """s_endpgm itself emits no instruction packet: no INST, VALUINST or IMMEDIATE appears for an endpgm-only kernel."""
    packets = list(decode(_run_kernel([s_endpgm()])))
    for pkt_type in (INST, VALUINST, IMMEDIATE):
      self.assertEqual(self._count(packets, pkt_type), 0, f"unexpected {pkt_type.__name__} packet")

  def test_wave_lifecycle(self):
    """Every WAVESTART has a matching WAVEEND, and the wave is actually traced."""
    packets = list(decode(_run_kernel([s_mov_b32(s[0], 0), s_endpgm()])))
    n_start = self._count(packets, WAVESTART)
    # without this check the equality below would also pass if no wave packets were emitted at all
    self.assertGreater(n_start, 0)
    self.assertEqual(n_start, self._count(packets, WAVEEND))

  def test_layout_header(self):
    """First packet is LAYOUT_HEADER with layout=3."""
    packets = list(decode(_run_kernel([s_endpgm()])))
    self.assertIsInstance(packets[0], LAYOUT_HEADER)
    self.assertEqual(packets[0].layout, 3)

  def test_blob_32byte_aligned(self):
    """SQTT blob length is a multiple of 32 bytes."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_endpgm()])
    self.assertEqual(len(blob) % 32, 0)

  def test_multiple_waves(self):
    """Multiple wavefronts each get their own WAVESTART/WAVEEND."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()], lx=64)  # 64 threads = 2 waves (WAVE_SIZE=32)
    packets = list(decode(blob))
    self.assertEqual(self._count(packets, WAVESTART), 2)
    self.assertEqual(self._count(packets, WAVEEND), 2)

  def test_branch_taken_and_not_taken(self):
    """A loop with s_cbranch_scc1 emits JUMP when taken, JUMP_NO on final iteration."""
    # s[0] = 2; loop: s[0] -= 1; cmp s[0] != 0 (SCC=1 if true); cbranch_scc1 loop; endpgm
    # iteration 1: s[0]=2→1, SCC=1 (1!=0), branch taken (JUMP)
    # iteration 2: s[0]=1→0, SCC=0 (0==0), branch not taken (JUMP_NO)
    blob = _run_kernel([s_mov_b32(s[0], 2), s_sub_u32(s[0], s[0], 1), s_cmp_lg_u32(s[0], 0), s_cbranch_scc1(simm16=-3), s_endpgm()])
    ops = [p.op for p in decode(blob) if isinstance(p, INST)]
    # the branch executes exactly twice, so pin the counts rather than mere presence
    self.assertEqual(ops.count(InstOp.JUMP), 1)
    self.assertEqual(ops.count(InstOp.JUMP_NO), 1)

  def test_timestamps_monotonic(self):
    """Timestamps are monotonically non-decreasing."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_mov_b32(s[2], 2), s_endpgm()])
    times = [p._time for p in decode(blob)]
    self.assertEqual(times, sorted(times))

  def test_no_trace_without_profile(self):
    """No SQTT trace is emitted when PROFILE=0."""
    # _run_kernel can't be reused here: it forces PROFILE=1 and asserts exactly one trace
    from test.mockgpu.amd.emu import run_asm, sqtt_traces
    code = s_endpgm().to_bytes()
    buf = (ctypes.c_char * len(code))(*code)
    sqtt_traces.clear()
    with Context(PROFILE=0):
      run_asm(ctypes.addressof(buf), len(code), 1, 1, 1, 1, 1, 1, 0)
    self.assertEqual(len(sqtt_traces), 0)
|
|
|
|
if __name__ == "__main__": unittest.main()
|