Files
tinygrad/test/amd/test_sqtt_encoder.py
George Hotz 47d39a6b8b add sqtt support to the emulator (#14791)
* add sqtt support to the emulator

* more sqtt

* cleanup

* cleanups

* simpler tests

* some decent tests

* test branch
2026-02-16 16:48:26 +08:00

109 lines
4.8 KiB
Python

#!/usr/bin/env python3
"""Tests for SQTT encoder: verifies the emulator produces correct SQTT traces for known kernels.
Run with: AMD=1 MOCKGPU=1 python -m pytest test/amd/test_sqtt_encoder.py -v
"""
import ctypes, unittest
from tinygrad.helpers import Context
from tinygrad.renderer.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp
from tinygrad.runtime.autogen.amd.rdna3.ins import *
def _run_kernel(instructions: list, lx=1, ly=1, lz=1, gx=1, gy=1, gz=1, args_ptr=0) -> bytes:
"""Assemble instructions, run on emulator with PROFILE=1, return the SQTT blob."""
from test.mockgpu.amd.emu import run_asm, sqtt_traces
code = b''.join(inst.to_bytes() for inst in instructions)
buf = (ctypes.c_char * len(code))(*code)
lib = ctypes.addressof(buf)
sqtt_traces.clear()
with Context(PROFILE=1):
run_asm(lib, len(code), gx, gy, gz, lx, ly, lz, args_ptr)
assert len(sqtt_traces) == 1, f"expected 1 trace, got {len(sqtt_traces)}"
return sqtt_traces.pop()
class TestSQTTEncoder(unittest.TestCase):
def test_simple_salu(self):
"""A simple s_mov + s_endpgm kernel emits SALU INST packet."""
blob = _run_kernel([s_mov_b32(s[0], 42), s_endpgm()])
packets = list(decode(blob))
inst_pkts = [p for p in packets if isinstance(p, INST)]
self.assertEqual(len(inst_pkts), 1)
self.assertEqual(inst_pkts[0].op, InstOp.SALU)
def test_valu_emits_valuinst(self):
"""Regular VALU ops emit VALUINST packets."""
blob = _run_kernel([v_mov_b32_e32(v[0], 0), v_add_f32_e32(v[1], v[0], v[0]), s_endpgm()])
packets = list(decode(blob))
valu_pkts = [p for p in packets if isinstance(p, VALUINST)]
self.assertEqual(len(valu_pkts), 2)
# no INST packets for regular VALU
self.assertEqual(len([p for p in packets if isinstance(p, INST)]), 0)
def test_waitcnt_emits_immediate(self):
"""s_waitcnt and s_nop emit IMMEDIATE packets."""
blob = _run_kernel([s_nop(simm16=0), s_waitcnt(simm16=0), s_endpgm()])
imm_pkts = [p for p in decode(blob) if isinstance(p, IMMEDIATE)]
self.assertEqual(len(imm_pkts), 2) # s_nop + s_waitcnt
def test_endpgm_skipped(self):
"""s_endpgm does not emit any packet."""
blob = _run_kernel([s_endpgm()])
packets = list(decode(blob))
self.assertEqual(len([p for p in packets if isinstance(p, INST)]), 0)
self.assertEqual(len([p for p in packets if isinstance(p, IMMEDIATE)]), 0)
def test_wave_lifecycle(self):
"""Every WAVESTART has a matching WAVEEND."""
blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()])
packets = list(decode(blob))
self.assertEqual(sum(1 for p in packets if isinstance(p, WAVESTART)), sum(1 for p in packets if isinstance(p, WAVEEND)))
def test_layout_header(self):
"""First packet is LAYOUT_HEADER with layout=3."""
blob = _run_kernel([s_endpgm()])
packets = list(decode(blob))
self.assertIsInstance(packets[0], LAYOUT_HEADER)
self.assertEqual(packets[0].layout, 3)
def test_blob_32byte_aligned(self):
"""SQTT blob is 32-byte aligned."""
blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_endpgm()])
self.assertEqual(len(blob) % 32, 0)
def test_multiple_waves(self):
"""Multiple wavefronts each get their own WAVESTART/WAVEEND."""
blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()], lx=64) # 64 threads = 2 waves (WAVE_SIZE=32)
packets = list(decode(blob))
self.assertEqual(sum(1 for p in packets if isinstance(p, WAVESTART)), 2)
self.assertEqual(sum(1 for p in packets if isinstance(p, WAVEEND)), 2)
def test_branch_taken_and_not_taken(self):
"""A loop with s_cbranch_scc1 emits JUMP when taken, JUMP_NO on final iteration."""
# s[0] = 2; loop: s[0] -= 1; cmp s[0] != 0 (SCC=1 if true); cbranch_scc1 loop; endpgm
# iteration 1: s[0]=2→1, SCC=1 (1!=0), branch taken (JUMP)
# iteration 2: s[0]=1→0, SCC=0 (0==0), branch not taken (JUMP_NO)
blob = _run_kernel([s_mov_b32(s[0], 2), s_sub_u32(s[0], s[0], 1), s_cmp_lg_u32(s[0], 0), s_cbranch_scc1(simm16=-3), s_endpgm()])
inst_pkts = [p for p in decode(blob) if isinstance(p, INST)]
ops = [p.op for p in inst_pkts]
self.assertIn(InstOp.JUMP, ops)
self.assertIn(InstOp.JUMP_NO, ops)
def test_timestamps_monotonic(self):
"""Timestamps are monotonically non-decreasing."""
blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_mov_b32(s[2], 2), s_endpgm()])
times = [p._time for p in decode(blob)]
self.assertEqual(times, sorted(times))
def test_no_trace_without_profile(self):
"""No SQTT trace is emitted when PROFILE=0."""
from test.mockgpu.amd.emu import run_asm, sqtt_traces
code = s_endpgm().to_bytes()
buf = (ctypes.c_char * len(code))(*code)
sqtt_traces.clear()
with Context(PROFILE=0):
run_asm(ctypes.addressof(buf), len(code), 1, 1, 1, 1, 1, 1, 0)
self.assertEqual(len(sqtt_traces), 0)
if __name__ == "__main__":
unittest.main()