diff --git a/test/amd/test_sqtt_encoder.py b/test/amd/test_sqtt_encoder.py new file mode 100644 index 0000000000..d3044e7dfa --- /dev/null +++ b/test/amd/test_sqtt_encoder.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +"""Tests for SQTT encoder: verifies the emulator produces correct SQTT traces for known kernels. + +Run with: AMD=1 MOCKGPU=1 python -m pytest test/amd/test_sqtt_encoder.py -v +""" +import ctypes, unittest +from tinygrad.helpers import Context +from tinygrad.renderer.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp +from tinygrad.runtime.autogen.amd.rdna3.ins import * + +def _run_kernel(instructions: list, lx=1, ly=1, lz=1, gx=1, gy=1, gz=1, args_ptr=0) -> bytes: + """Assemble instructions, run on emulator with PROFILE=1, return the SQTT blob.""" + from test.mockgpu.amd.emu import run_asm, sqtt_traces + code = b''.join(inst.to_bytes() for inst in instructions) + buf = (ctypes.c_char * len(code))(*code) + lib = ctypes.addressof(buf) + sqtt_traces.clear() + with Context(PROFILE=1): + run_asm(lib, len(code), gx, gy, gz, lx, ly, lz, args_ptr) + assert len(sqtt_traces) == 1, f"expected 1 trace, got {len(sqtt_traces)}" + return sqtt_traces.pop() + +class TestSQTTEncoder(unittest.TestCase): + + def test_simple_salu(self): + """A simple s_mov + s_endpgm kernel emits SALU INST packet.""" + blob = _run_kernel([s_mov_b32(s[0], 42), s_endpgm()]) + packets = list(decode(blob)) + inst_pkts = [p for p in packets if isinstance(p, INST)] + self.assertEqual(len(inst_pkts), 1) + self.assertEqual(inst_pkts[0].op, InstOp.SALU) + + def test_valu_emits_valuinst(self): + """Regular VALU ops emit VALUINST packets.""" + blob = _run_kernel([v_mov_b32_e32(v[0], 0), v_add_f32_e32(v[1], v[0], v[0]), s_endpgm()]) + packets = list(decode(blob)) + valu_pkts = [p for p in packets if isinstance(p, VALUINST)] + self.assertEqual(len(valu_pkts), 2) + # no INST packets for regular VALU + self.assertEqual(len([p for p in packets if isinstance(p, 
INST)]), 0) + + def test_waitcnt_emits_immediate(self): + """s_waitcnt and s_nop emit IMMEDIATE packets.""" + blob = _run_kernel([s_nop(simm16=0), s_waitcnt(simm16=0), s_endpgm()]) + imm_pkts = [p for p in decode(blob) if isinstance(p, IMMEDIATE)] + self.assertEqual(len(imm_pkts), 2) # s_nop + s_waitcnt + + def test_endpgm_skipped(self): + """s_endpgm does not emit any packet.""" + blob = _run_kernel([s_endpgm()]) + packets = list(decode(blob)) + self.assertEqual(len([p for p in packets if isinstance(p, INST)]), 0) + self.assertEqual(len([p for p in packets if isinstance(p, IMMEDIATE)]), 0) + + def test_wave_lifecycle(self): + """Every WAVESTART has a matching WAVEEND.""" + blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()]) + packets = list(decode(blob)) + self.assertEqual(sum(1 for p in packets if isinstance(p, WAVESTART)), sum(1 for p in packets if isinstance(p, WAVEEND))) + + def test_layout_header(self): + """First packet is LAYOUT_HEADER with layout=3.""" + blob = _run_kernel([s_endpgm()]) + packets = list(decode(blob)) + self.assertIsInstance(packets[0], LAYOUT_HEADER) + self.assertEqual(packets[0].layout, 3) + + def test_blob_32byte_aligned(self): + """SQTT blob is 32-byte aligned.""" + blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_endpgm()]) + self.assertEqual(len(blob) % 32, 0) + + def test_multiple_waves(self): + """Multiple wavefronts each get their own WAVESTART/WAVEEND.""" + blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()], lx=64) # 64 threads = 2 waves (WAVE_SIZE=32) + packets = list(decode(blob)) + self.assertEqual(sum(1 for p in packets if isinstance(p, WAVESTART)), 2) + self.assertEqual(sum(1 for p in packets if isinstance(p, WAVEEND)), 2) + + def test_branch_taken_and_not_taken(self): + """A loop with s_cbranch_scc1 emits JUMP when taken, JUMP_NO on final iteration.""" + # s[0] = 2; loop: s[0] -= 1; cmp s[0] != 0 (SCC=1 if true); cbranch_scc1 loop; endpgm + # iteration 1: s[0]=2→1, SCC=1 (1!=0), branch taken (JUMP) + # 
iteration 2: s[0]=1→0, SCC=0 (0==0), branch not taken (JUMP_NO) + blob = _run_kernel([s_mov_b32(s[0], 2), s_sub_u32(s[0], s[0], 1), s_cmp_lg_u32(s[0], 0), s_cbranch_scc1(simm16=-3), s_endpgm()]) + inst_pkts = [p for p in decode(blob) if isinstance(p, INST)] + ops = [p.op for p in inst_pkts] + self.assertIn(InstOp.JUMP, ops) + self.assertIn(InstOp.JUMP_NO, ops) + + def test_timestamps_monotonic(self): + """Timestamps are monotonically non-decreasing.""" + blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_mov_b32(s[2], 2), s_endpgm()]) + times = [p._time for p in decode(blob)] + self.assertEqual(times, sorted(times)) + + def test_no_trace_without_profile(self): + """No SQTT trace is emitted when PROFILE=0.""" + from test.mockgpu.amd.emu import run_asm, sqtt_traces + code = s_endpgm().to_bytes() + buf = (ctypes.c_char * len(code))(*code) + sqtt_traces.clear() + with Context(PROFILE=0): + run_asm(ctypes.addressof(buf), len(code), 1, 1, 1, 1, 1, 1, 0) + self.assertEqual(len(sqtt_traces), 0) + +if __name__ == "__main__": + unittest.main() diff --git a/test/mockgpu/amd/amdgpu.py b/test/mockgpu/amd/amdgpu.py index 6a15392a72..8a73fd8824 100644 --- a/test/mockgpu/amd/amdgpu.py +++ b/test/mockgpu/amd/amdgpu.py @@ -228,13 +228,21 @@ class PM4Executor(AMDQueue): event_dw = self._next_dword() match (event_dw & 0xFF): # event type case SQTT_EVENTS.THREAD_TRACE_FINISH: + # Take the oldest queued trace from the emulator (FIFO: pop(0) removes the first entry), if available + from test.mockgpu.amd.emu import sqtt_traces + blob = sqtt_traces.pop(0) if sqtt_traces else b'' old_idx = self.gpu.regs.grbm_index for se in range(self.gpu.regs.n_se): self.gpu.regs.grbm_index = 0b011 << 29 | se << 16 # select se, broadcast sa and instance self.gpu.regs[regSQ_THREAD_TRACE_STATUS] = 1 << 12 # FINISH_PENDING==0 FINISH_DONE==1 BUSY==0 - buf = ((self.gpu.regs[regSQ_THREAD_TRACE_BUF0_SIZE]&0xf)<<32|self.gpu.regs[regSQ_THREAD_TRACE_BUF0_BASE])<<12 # per page addressing - fake_used = 0x1000 # fake one page long trace - 
self.gpu.regs[regSQ_THREAD_TRACE_WPTR] = ((buf+fake_used)//32) & 0x1FFFFFFF + buf_addr = ((self.gpu.regs[regSQ_THREAD_TRACE_BUF0_SIZE]&0xf)<<32|self.gpu.regs[regSQ_THREAD_TRACE_BUF0_BASE])<<12 + + # Use real trace blob for SE 0 (which has itrace enabled), empty blob for other SEs + se_blob = blob if se == 0 else b'' + + # Write blob to trace buffer + if se_blob: ctypes.memmove(buf_addr, se_blob, len(se_blob)) + self.gpu.regs[regSQ_THREAD_TRACE_WPTR] = ((buf_addr + len(se_blob)) // 32) & 0x1FFFFFFF self.gpu.regs.grbm_index = old_idx case _: pass # NOTE: for now most events aren't emulated diff --git a/test/mockgpu/amd/emu.py b/test/mockgpu/amd/emu.py index 89f96e2f70..af0fdb687c 100644 --- a/test/mockgpu/amd/emu.py +++ b/test/mockgpu/amd/emu.py @@ -55,7 +55,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType from tinygrad.dtype import dtypes from tinygrad.device import Buffer, BufferSpec from tinygrad.runtime.autogen import hsa -from tinygrad.helpers import Context, DEBUG, colored +from tinygrad.helpers import Context, DEBUG, PROFILE, colored from tinygrad.engine.realize import get_runner from tinygrad.renderer.amd import decode_inst @@ -71,6 +71,125 @@ from test.mockgpu.amd.pcode import parse_block, _FUNCS MASK32 = 0xFFFFFFFF +# ═══════════════════════════════════════════════════════════════════════════════ +# SQTT TRACE COLLECTION +# ═══════════════════════════════════════════════════════════════════════════════ + +# Global trace storage: populated by run_asm as raw SQTT blobs, consumed by amdgpu.py +sqtt_traces: list[bytes] = [] + +# Encoder primitives +from tinygrad.renderer.amd.sqtt import _build_decode_tables, PACKET_TYPES_RDNA3, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp + +_NIB_COUNTS: dict = {cls: nc for _, (cls, nc, *_) in _build_decode_tables(PACKET_TYPES_RDNA3)[0].items()} + +def _encode_raw(pkt_cls, **kwargs) -> tuple[int, int]: + raw = pkt_cls.encoding.default + for k, v in kwargs.items(): raw = 
pkt_cls.__dict__[k].set(raw, v) + return raw, _NIB_COUNTS[pkt_cls] + +def _emit_nibbles(nibbles: list[int], pkt_cls, **kwargs): + raw, nc = _encode_raw(pkt_cls, **kwargs) + for i in range(nc): nibbles.append((raw >> (i * 4)) & 0xF) + +def _nibbles_to_bytes(nibbles: list[int]) -> bytes: + result = bytearray() + for i in range(0, len(nibbles), 2): result.append(nibbles[i] | ((nibbles[i + 1] if i + 1 < len(nibbles) else 0) << 4)) + return bytes(result) + +def _init_sqtt_encoder(): + """Initialize and return SQTT encoder state. Called once per dispatch with tracing enabled.""" + from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp as SOPPOp3 + from tinygrad.runtime.autogen.amd.rdna4.enum import SOPPOp as SOPPOp4 + import re + + _SOPP = (ir3.SOPP, ir4.SOPP, irc.SOPP) + _SMEM = (ir3.SMEM, ir4.SMEM, irc.SMEM) + _VALU = (ir3.VOP1, ir3.VOP2, ir3.VOP3, ir3.VOP3P, ir3.VOPC, ir3.VOPD, ir3.VOP3SD, ir3.VOP3_SDST, ir3.VOP1_SDST, + ir4.VOP1, ir4.VOP2, ir4.VOP3, ir4.VOP3P, ir4.VOPC, ir4.VOPD, ir4.VOP3SD, ir4.VOP3_SDST, ir4.VOP1_SDST, + irc.VOP1, irc.VOP2, irc.VOP3, irc.VOP3P, irc.VOPC, irc.VOP3SD, irc.VOP3_SDST) + _DS = (ir3.DS, ir4.DS, irc.DS) + _GLOBAL = (ir3.GLOBAL, ir4.VGLOBAL, irc.GLOBAL) + _FLAT = (ir3.FLAT, ir4.VFLAT, irc.FLAT) + _SCRATCH = (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH) + + # SOPP classification sets + _SOPP_SKIP = {SOPPOp3.S_ENDPGM.value, SOPPOp3.S_ENDPGM_SAVED.value, SOPPOp3.S_ENDPGM_ORDERED_PS_DONE.value, + SOPPOp3.S_DELAY_ALU.value} + _SOPP_IMMEDIATE = {SOPPOp3.S_NOP.value, SOPPOp3.S_CLAUSE.value, SOPPOp3.S_WAITCNT.value, SOPPOp3.S_WAITCNT_DEPCTR.value, + SOPPOp3.S_WAIT_IDLE.value, SOPPOp3.S_WAIT_EVENT.value, SOPPOp3.S_SLEEP.value, + SOPPOp3.S_SET_INST_PREFETCH_DISTANCE.value} + for _op in (SOPPOp4.S_WAIT_ALU, SOPPOp4.S_WAIT_LOADCNT, SOPPOp4.S_WAIT_STORECNT, SOPPOp4.S_WAIT_SAMPLECNT, + SOPPOp4.S_WAIT_BVHCNT, SOPPOp4.S_WAIT_EXPCNT, SOPPOp4.S_WAIT_DSCNT, SOPPOp4.S_WAIT_KMCNT, + SOPPOp4.S_WAIT_LOADCNT_DSCNT, SOPPOp4.S_WAIT_STORECNT_DSCNT): + 
_SOPP_IMMEDIATE.add(_op.value) + _SOPP_BARRIER = {SOPPOp3.S_BARRIER.value} + if hasattr(SOPPOp4, 'S_BARRIER_WAIT'): _SOPP_BARRIER.add(SOPPOp4.S_BARRIER_WAIT.value) + if hasattr(SOPPOp4, 'S_BARRIER_LEAVE'): _SOPP_BARRIER.add(SOPPOp4.S_BARRIER_LEAVE.value) + _SOPP_BRANCH = {SOPPOp3.S_BRANCH.value, SOPPOp3.S_CBRANCH_SCC0.value, SOPPOp3.S_CBRANCH_SCC1.value, + SOPPOp3.S_CBRANCH_VCCZ.value, SOPPOp3.S_CBRANCH_VCCNZ.value, + SOPPOp3.S_CBRANCH_EXECZ.value, SOPPOp3.S_CBRANCH_EXECNZ.value} + + # VALU sub-classification patterns + _VALU_TRANS_RE = re.compile(r'V_(EXP|LOG|RCP|RSQ|SQRT|SIN|COS|CEIL|FLOOR|TRUNC|RNDNE|FRACT|FREXP)_') + _VALU_64_SHIFT_RE = re.compile(r'V_(LSHLREV|LSHRREV|ASHRREV)_(B|I)64') + _VALU_MAD64_RE = re.compile(r'V_MAD_(U|I)64') + _VALU_64_RE = re.compile(r'V_\w+_F64') + + def _valu_op(op_name: str) -> InstOp|None: + if 'CMPX' in op_name: return InstOp.VALU_CMPX + if _VALU_64_SHIFT_RE.search(op_name): return InstOp.VALU_64_SHIFT + if _VALU_MAD64_RE.search(op_name): return InstOp.VALU_MAD64 + if _VALU_64_RE.search(op_name): return InstOp.VALU_64 + if _VALU_TRANS_RE.search(op_name): return InstOp.VALU_TRANS + return None + + def _mem_op(t, op_name: str) -> InstOp: + is_store = "STORE" in op_name + if issubclass(t, _DS): return InstOp.LDS_STORE if is_store else InstOp.LDS_LOAD + if issubclass(t, _GLOBAL): return InstOp.GLOBAL_STORE if is_store else InstOp.GLOBAL_LOAD + if issubclass(t, _FLAT): return InstOp.FLAT_STORE if is_store else InstOp.FLAT_LOAD + if issubclass(t, _SCRATCH): return InstOp.FLAT_STORE if is_store else InstOp.FLAT_LOAD + return InstOp.SALU + + nibbles: list[int] = [] + started: set[int] = set() + _emit_nibbles(nibbles, LAYOUT_HEADER, layout=3, sel_a=6) + + def emit(wave_id: int, inst, branch_taken: bool|None): + """Emit an SQTT packet for one executed instruction.""" + w = wave_id & 0x1F + if wave_id not in started: + _emit_nibbles(nibbles, WAVESTART, delta=1, simd=0, cu_lo=0, wave=w, id7=wave_id) + started.add(wave_id) + inst_type, 
inst_op, op_name = type(inst), inst.op.value if hasattr(inst, 'op') else 0, inst.op.name if hasattr(inst, 'op') else "" + if issubclass(inst_type, _SOPP): + if inst_op in _SOPP_SKIP: return + elif inst_op in _SOPP_IMMEDIATE: _emit_nibbles(nibbles, IMMEDIATE, delta=1, wave=w) + elif inst_op in _SOPP_BARRIER: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.BARRIER) + elif inst_op in _SOPP_BRANCH: + _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.JUMP if branch_taken else InstOp.JUMP_NO) + else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.SALU) + elif issubclass(inst_type, _VALU): + op = _valu_op(op_name) + if op is None: _emit_nibbles(nibbles, VALUINST, delta=1, wave=w) + else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=op) + elif issubclass(inst_type, _SMEM): _emit_nibbles(nibbles, INST, delta=1, wave=w, op=InstOp.SMEM) + else: _emit_nibbles(nibbles, INST, delta=1, wave=w, op=_mem_op(inst_type, op_name)) + + def finish(wave_id: int): + """Emit WAVEEND for a completed wave.""" + if wave_id in started: _emit_nibbles(nibbles, WAVEEND, delta=1, simd=0, cu_lo=0, wave=wave_id & 0x1F) + + def finalize() -> bytes: + """Pad and return the encoded SQTT blob.""" + while len(nibbles) % 2 != 0: nibbles.append(0) + nibbles.extend([0] * 32) + while len(nibbles) % 64 != 0: nibbles.append(0) + return _nibbles_to_bytes(nibbles) + + return emit, finish, finalize + def _c(val, dtype=dtypes.uint32): return UOp.const(dtype, val) def _u64(lo: UOp, hi: UOp) -> UOp: @@ -1231,14 +1350,16 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"): canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}" sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1) - # NOTE: renderer output is not reproducible because of _MXCSRContext - with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0): + # NOTE: renderer output is not reproducible because of _MXCSRContext. 
PROFILE=0 prevents emulator instruction runners from polluting profiling. + with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0, PROFILE=0): runner = get_runner('CPU', sink) _canonical_runner_cache.append((base, mask, size, runner)) return runner _BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER} if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'): _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT) +_BRANCH_OPS: set[int] = {op.value for op in (ir3.SOPPOp.S_BRANCH, ir3.SOPPOp.S_CBRANCH_SCC0, ir3.SOPPOp.S_CBRANCH_SCC1, + ir3.SOPPOp.S_CBRANCH_VCCZ, ir3.SOPPOp.S_CBRANCH_VCCNZ, ir3.SOPPOp.S_CBRANCH_EXECZ, ir3.SOPPOp.S_CBRANCH_EXECNZ)} def _decode_at(pc: int, arch: str): """Decode and compile instruction at absolute address pc. Returns (runner, decoded_inst).""" @@ -1295,8 +1416,8 @@ class WaveState: # ═══════════════════════════════════════════════════════════════════════════════ def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int, - scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None) -> tuple[WaveState, list]: - """Initialize a single wavefront and return (WaveState, c_bufs placeholder). c_bufs filled in by caller.""" + scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None) -> WaveState: + """Initialize a single wavefront and return WaveState.""" n_lanes = min(WAVE_SIZE, total_threads - wave_start) st = WaveState(n_lanes) st.pc = lib @@ -1324,7 +1445,8 @@ def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c, scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int: """Execute AMD assembly program. 
scratch_size is private_segment_fixed_size from kernel descriptor (per-lane).""" - program: dict[int, tuple[Callable, list[int], bool]] = {} # pc -> (fxn, globals, is_barrier) + from tinygrad.renderer.amd.dsl import Inst + program: dict[int, tuple[Callable, list[int], bool, Inst]] = {} # pc -> (fxn, globals, is_barrier, inst) lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512 total_threads = lx * ly * lz @@ -1333,18 +1455,24 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lds_buf = Buffer('CPU', max(lds_size // 4, 1), dtypes.uint32).ensure_allocated() scratch_buf = Buffer('CPU', scratch_size * WAVE_SIZE, dtypes.uint8).ensure_allocated() if scratch_size else None - def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool]: + # Initialize SQTT encoder — emits packets inline as instructions execute (only when profiling) + if PROFILE: + sqtt_emit, sqtt_finish, sqtt_finalize = _init_sqtt_encoder() + + def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool, Inst]: if pc not in program: prev_len = len(_canonical_runner_cache) runner, inst = _decode_at(pc, arch) is_barrier = isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS - program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier) + program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier, inst) if DEBUG >= 3: msg = f"[emu] PC={pc - lib}: {inst!r}" print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg) return program[pc] # Set DAZ+FTZ during emulator execution, restore afterward to avoid breaking hypothesis tests + # Only trace the first workgroup (like real HW traces one CU/SIMD), subsequent workgroups run but don't add to trace + tracing = bool(PROFILE) with _MXCSRContext(): for gidz in range(gz): for gidy in range(gy): @@ -1370,14 +1498,21 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, 
pc = st.pc if pc == ENDPGM_PC: done[wi] = True + if tracing: sqtt_finish(wi) break - fxn, globals_list, is_barrier = _ensure_compiled(pc) + fxn, globals_list, is_barrier, inst = _ensure_compiled(pc) fxn(*[c_bufs[g] for g in globals_list]) + if tracing: + inst_op = inst.op.value if hasattr(inst, 'op') else 0 + sqtt_emit(wi, inst, (st.pc != ENDPGM_PC and st.pc != pc + inst.size()) if inst_op in _BRANCH_OPS else None) if is_barrier: break # s_barrier hit: PC already advanced past it, pause this wave else: raise RuntimeError("exceeded 1M instructions in single wave, likely infinite loop") # All waves have either hit barrier or endpgm — release barrier waves for next round else: raise RuntimeError("exceeded 10M total scheduling rounds") + tracing = False # only trace the first workgroup # Reset LDS for next workgroup if lds_size > 0: ctypes.memset(lds_buf._buf.va_addr, 0, max(lds_size, 4)) + + if PROFILE: sqtt_traces.append(sqtt_finalize()) return 0