mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add sqtt support to the emulator (#14791)
* add sqtt support to the emulator * more sqtt * cleanup * cleanups * simpler tests * some decent tests * test branch
This commit is contained in:
108
test/amd/test_sqtt_encoder.py
Normal file
108
test/amd/test_sqtt_encoder.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT encoder: verifies the emulator produces correct SQTT traces for known kernels.
|
||||
|
||||
Run with: AMD=1 MOCKGPU=1 python -m pytest test/amd/test_sqtt_encoder.py -v
|
||||
"""
|
||||
import ctypes, unittest
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.renderer.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
|
||||
def _run_kernel(instructions: list, lx=1, ly=1, lz=1, gx=1, gy=1, gz=1, args_ptr=0) -> bytes:
  """Assemble `instructions`, execute them on the emulator under PROFILE=1 and return the SQTT blob.

  lx/ly/lz and gx/gy/gz are forwarded to run_asm as the local and global dimensions; args_ptr is forwarded unchanged.
  The global trace list is cleared first, and exactly one trace is expected after the run.
  """
  from test.mockgpu.amd.emu import run_asm, sqtt_traces
  machine_code = b''.join(inst.to_bytes() for inst in instructions)
  # the ctypes array owns the code bytes; it stays referenced for the whole run_asm call
  code_buf = (ctypes.c_char * len(machine_code))(*machine_code)
  code_addr = ctypes.addressof(code_buf)
  sqtt_traces.clear()
  with Context(PROFILE=1):
    run_asm(code_addr, len(machine_code), gx, gy, gz, lx, ly, lz, args_ptr)
  assert len(sqtt_traces) == 1, f"expected 1 trace, got {len(sqtt_traces)}"
  return sqtt_traces.pop()
|
||||
|
||||
class TestSQTTEncoder(unittest.TestCase):
  """Checks the packets the emulator's SQTT encoder produces for small hand-assembled kernels."""

  def test_simple_salu(self):
    """A simple s_mov + s_endpgm kernel emits SALU INST packet."""
    blob = _run_kernel([s_mov_b32(s[0], 42), s_endpgm()])
    packets = list(decode(blob))
    inst_pkts = [p for p in packets if isinstance(p, INST)]
    self.assertEqual(len(inst_pkts), 1)
    self.assertEqual(inst_pkts[0].op, InstOp.SALU)

  def test_valu_emits_valuinst(self):
    """Regular VALU ops emit VALUINST packets."""
    blob = _run_kernel([v_mov_b32_e32(v[0], 0), v_add_f32_e32(v[1], v[0], v[0]), s_endpgm()])
    packets = list(decode(blob))
    valu_pkts = [p for p in packets if isinstance(p, VALUINST)]
    self.assertEqual(len(valu_pkts), 2)
    # no INST packets for regular VALU
    self.assertEqual(len([p for p in packets if isinstance(p, INST)]), 0)

  def test_waitcnt_emits_immediate(self):
    """s_waitcnt and s_nop emit IMMEDIATE packets."""
    blob = _run_kernel([s_nop(simm16=0), s_waitcnt(simm16=0), s_endpgm()])
    imm_pkts = [p for p in decode(blob) if isinstance(p, IMMEDIATE)]
    self.assertEqual(len(imm_pkts), 2)  # s_nop + s_waitcnt

  def test_endpgm_skipped(self):
    """s_endpgm does not emit any packet."""
    blob = _run_kernel([s_endpgm()])
    packets = list(decode(blob))
    self.assertEqual(len([p for p in packets if isinstance(p, INST)]), 0)
    self.assertEqual(len([p for p in packets if isinstance(p, IMMEDIATE)]), 0)

  def test_wave_lifecycle(self):
    """Every WAVESTART has a matching WAVEEND, and at least one wave is traced."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()])
    packets = list(decode(blob))
    n_start = sum(1 for p in packets if isinstance(p, WAVESTART))
    n_end = sum(1 for p in packets if isinstance(p, WAVEEND))
    # without this, a trace with no wave packets at all (0 == 0) would pass the equality check below
    self.assertGreater(n_start, 0)
    self.assertEqual(n_start, n_end)

  def test_layout_header(self):
    """First packet is LAYOUT_HEADER with layout=3."""
    blob = _run_kernel([s_endpgm()])
    packets = list(decode(blob))
    self.assertIsInstance(packets[0], LAYOUT_HEADER)
    self.assertEqual(packets[0].layout, 3)

  def test_blob_32byte_aligned(self):
    """SQTT blob is 32-byte aligned."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_endpgm()])
    self.assertEqual(len(blob) % 32, 0)

  def test_multiple_waves(self):
    """Multiple wavefronts each get their own WAVESTART/WAVEEND."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_endpgm()], lx=64)  # 64 threads = 2 waves (WAVE_SIZE=32)
    packets = list(decode(blob))
    self.assertEqual(sum(1 for p in packets if isinstance(p, WAVESTART)), 2)
    self.assertEqual(sum(1 for p in packets if isinstance(p, WAVEEND)), 2)

  def test_branch_taken_and_not_taken(self):
    """A loop with s_cbranch_scc1 emits JUMP when taken, JUMP_NO on final iteration."""
    # s[0] = 2; loop: s[0] -= 1; cmp s[0] != 0 (SCC=1 if true); cbranch_scc1 loop; endpgm
    # iteration 1: s[0]=2→1, SCC=1 (1!=0), branch taken (JUMP)
    # iteration 2: s[0]=1→0, SCC=0 (0==0), branch not taken (JUMP_NO)
    blob = _run_kernel([s_mov_b32(s[0], 2), s_sub_u32(s[0], s[0], 1), s_cmp_lg_u32(s[0], 0), s_cbranch_scc1(simm16=-3), s_endpgm()])
    inst_pkts = [p for p in decode(blob) if isinstance(p, INST)]
    ops = [p.op for p in inst_pkts]
    # pin the order as well as the presence: the taken branch must be recorded before the fall-through
    branch_ops = [op for op in ops if op in (InstOp.JUMP, InstOp.JUMP_NO)]
    self.assertEqual(branch_ops, [InstOp.JUMP, InstOp.JUMP_NO])

  def test_timestamps_monotonic(self):
    """Timestamps are monotonically non-decreasing."""
    blob = _run_kernel([s_mov_b32(s[0], 0), s_mov_b32(s[1], 1), s_mov_b32(s[2], 2), s_endpgm()])
    times = [p._time for p in decode(blob)]
    self.assertEqual(times, sorted(times))

  def test_no_trace_without_profile(self):
    """No SQTT trace is emitted when PROFILE=0."""
    # _run_kernel cannot be reused here: it forces PROFILE=1 and asserts that exactly one trace exists
    from test.mockgpu.amd.emu import run_asm, sqtt_traces
    code = s_endpgm().to_bytes()
    buf = (ctypes.c_char * len(code))(*code)
    sqtt_traces.clear()
    with Context(PROFILE=0):
      run_asm(ctypes.addressof(buf), len(code), 1, 1, 1, 1, 1, 1, 0)
    self.assertEqual(len(sqtt_traces), 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -228,13 +228,21 @@ class PM4Executor(AMDQueue):
|
||||
event_dw = self._next_dword()
|
||||
match (event_dw & 0xFF): # event type
|
||||
case SQTT_EVENTS.THREAD_TRACE_FINISH:
|
||||
# Get the most recent trace from the emulator (if available)
|
||||
from test.mockgpu.amd.emu import sqtt_traces
|
||||
blob = sqtt_traces.pop(0) if sqtt_traces else b''
|
||||
old_idx = self.gpu.regs.grbm_index
|
||||
for se in range(self.gpu.regs.n_se):
|
||||
self.gpu.regs.grbm_index = 0b011 << 29 | se << 16 # select se, broadcast sa and instance
|
||||
self.gpu.regs[regSQ_THREAD_TRACE_STATUS] = 1 << 12 # FINISH_PENDING==0 FINISH_DONE==1 BUSY==0
|
||||
buf = ((self.gpu.regs[regSQ_THREAD_TRACE_BUF0_SIZE]&0xf)<<32|self.gpu.regs[regSQ_THREAD_TRACE_BUF0_BASE])<<12 # per page addressing
|
||||
fake_used = 0x1000 # fake one page long trace
|
||||
self.gpu.regs[regSQ_THREAD_TRACE_WPTR] = ((buf+fake_used)//32) & 0x1FFFFFFF
|
||||
buf_addr = ((self.gpu.regs[regSQ_THREAD_TRACE_BUF0_SIZE]&0xf)<<32|self.gpu.regs[regSQ_THREAD_TRACE_BUF0_BASE])<<12
|
||||
|
||||
# Use real trace blob for SE 0 (which has itrace enabled), empty blob for other SEs
|
||||
se_blob = blob if se == 0 else b''
|
||||
|
||||
# Write blob to trace buffer
|
||||
if se_blob: ctypes.memmove(buf_addr, se_blob, len(se_blob))
|
||||
self.gpu.regs[regSQ_THREAD_TRACE_WPTR] = ((buf_addr + len(se_blob)) // 32) & 0x1FFFFFFF
|
||||
self.gpu.regs.grbm_index = old_idx
|
||||
case _: pass # NOTE: for now most events aren't emulated
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.runtime.autogen import hsa
|
||||
from tinygrad.helpers import Context, DEBUG, colored
|
||||
from tinygrad.helpers import Context, DEBUG, PROFILE, colored
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
from tinygrad.renderer.amd import decode_inst
|
||||
@@ -71,6 +71,125 @@ from test.mockgpu.amd.pcode import parse_block, _FUNCS
|
||||
|
||||
MASK32 = 0xFFFFFFFF
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# SQTT TRACE COLLECTION
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Global trace storage: populated by run_asm as raw SQTT blobs, consumed by amdgpu.py
|
||||
sqtt_traces: list[bytes] = []
|
||||
|
||||
# Encoder primitives
|
||||
from tinygrad.renderer.amd.sqtt import _build_decode_tables, PACKET_TYPES_RDNA3, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, IMMEDIATE, VALUINST, InstOp
|
||||
|
||||
_NIB_COUNTS: dict = {cls: nc for _, (cls, nc, *_) in _build_decode_tables(PACKET_TYPES_RDNA3)[0].items()}
|
||||
|
||||
def _encode_raw(pkt_cls, **kwargs) -> tuple[int, int]:
  """Encode one packet of type `pkt_cls` and return `(raw_value, nibble_count)`.

  Starts from the packet class's default encoding and applies each keyword argument through the
  same-named field descriptor found in the class's own `__dict__`. The nibble count is looked up in `_NIB_COUNTS`.
  """
  encoded = pkt_cls.encoding.default
  for field_name, field_value in kwargs.items():
    field = pkt_cls.__dict__[field_name]
    encoded = field.set(encoded, field_value)
  return encoded, _NIB_COUNTS[pkt_cls]
|
||||
|
||||
def _emit_nibbles(nibbles: list[int], pkt_cls, **kwargs):
  """Encode a `pkt_cls` packet with the given fields and append its nibbles to `nibbles`, lowest nibble first."""
  value, count = _encode_raw(pkt_cls, **kwargs)
  nibbles.extend((value >> shift) & 0xF for shift in range(0, count * 4, 4))
|
||||
|
||||
def _nibbles_to_bytes(nibbles: list[int]) -> bytes:
|
||||
result = bytearray()
|
||||
for i in range(0, len(nibbles), 2): result.append(nibbles[i] | ((nibbles[i + 1] if i + 1 < len(nibbles) else 0) << 4))
|
||||
return bytes(result)
|
||||
|
||||
def _init_sqtt_encoder():
  """Create a fresh SQTT encoder and return its `(emit, finish, finalize)` closures.

  The closures share one nibble stream, which is seeded with a LAYOUT_HEADER packet:
    emit(wave_id, inst, branch_taken) appends the packet for one executed instruction,
    finish(wave_id) appends WAVEEND for a wave that has emitted at least once,
    finalize() pads the stream and returns it as bytes.
  """
  import re
  from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp as SOPPOp3
  from tinygrad.runtime.autogen.amd.rdna4.enum import SOPPOp as SOPPOp4

  # instruction classes grouped by category, across the ir3 / ir4 / irc instruction modules
  sopp_types = (ir3.SOPP, ir4.SOPP, irc.SOPP)
  smem_types = (ir3.SMEM, ir4.SMEM, irc.SMEM)
  shared_valu_names = ("VOP1", "VOP2", "VOP3", "VOP3P", "VOPC", "VOP3SD", "VOP3_SDST")
  valu_types = tuple(getattr(mod, name) for mod in (ir3, ir4, irc) for name in shared_valu_names) + \
               tuple(getattr(mod, name) for mod in (ir3, ir4) for name in ("VOPD", "VOP1_SDST"))
  ds_types = (ir3.DS, ir4.DS, irc.DS)
  global_types = (ir3.GLOBAL, ir4.VGLOBAL, irc.GLOBAL)
  flat_types = (ir3.FLAT, ir4.VFLAT, irc.FLAT)
  scratch_types = (ir3.SCRATCH, ir4.VSCRATCH, irc.SCRATCH)

  # SOPP opcode values, bucketed by the kind of packet they produce
  skip_ops = {op.value for op in (SOPPOp3.S_ENDPGM, SOPPOp3.S_ENDPGM_SAVED, SOPPOp3.S_ENDPGM_ORDERED_PS_DONE, SOPPOp3.S_DELAY_ALU)}
  immediate_ops = {op.value for op in (
    SOPPOp3.S_NOP, SOPPOp3.S_CLAUSE, SOPPOp3.S_WAITCNT, SOPPOp3.S_WAITCNT_DEPCTR, SOPPOp3.S_WAIT_IDLE, SOPPOp3.S_WAIT_EVENT,
    SOPPOp3.S_SLEEP, SOPPOp3.S_SET_INST_PREFETCH_DISTANCE,
    SOPPOp4.S_WAIT_ALU, SOPPOp4.S_WAIT_LOADCNT, SOPPOp4.S_WAIT_STORECNT, SOPPOp4.S_WAIT_SAMPLECNT, SOPPOp4.S_WAIT_BVHCNT,
    SOPPOp4.S_WAIT_EXPCNT, SOPPOp4.S_WAIT_DSCNT, SOPPOp4.S_WAIT_KMCNT, SOPPOp4.S_WAIT_LOADCNT_DSCNT, SOPPOp4.S_WAIT_STORECNT_DSCNT)}
  barrier_ops = {SOPPOp3.S_BARRIER.value}
  for optional_name in ('S_BARRIER_WAIT', 'S_BARRIER_LEAVE'):
    # these two are only added when the rdna4 enum actually defines them
    if hasattr(SOPPOp4, optional_name): barrier_ops.add(getattr(SOPPOp4, optional_name).value)
  branch_ops = {op.value for op in (SOPPOp3.S_BRANCH, SOPPOp3.S_CBRANCH_SCC0, SOPPOp3.S_CBRANCH_SCC1, SOPPOp3.S_CBRANCH_VCCZ,
                                    SOPPOp3.S_CBRANCH_VCCNZ, SOPPOp3.S_CBRANCH_EXECZ, SOPPOp3.S_CBRANCH_EXECNZ)}

  # opcode-name patterns that select a dedicated VALU InstOp
  trans_re = re.compile(r'V_(EXP|LOG|RCP|RSQ|SQRT|SIN|COS|CEIL|FLOOR|TRUNC|RNDNE|FRACT|FREXP)_')
  shift64_re = re.compile(r'V_(LSHLREV|LSHRREV|ASHRREV)_(B|I)64')
  mad64_re = re.compile(r'V_MAD_(U|I)64')
  f64_re = re.compile(r'V_\w+_F64')

  def classify_valu(op_name: str) -> InstOp|None:
    """Return the dedicated InstOp for a VALU opcode name, or None for a regular VALU op. First match wins."""
    if 'CMPX' in op_name: return InstOp.VALU_CMPX
    if shift64_re.search(op_name) is not None: return InstOp.VALU_64_SHIFT
    if mad64_re.search(op_name) is not None: return InstOp.VALU_MAD64
    if f64_re.search(op_name) is not None: return InstOp.VALU_64
    if trans_re.search(op_name) is not None: return InstOp.VALU_TRANS
    return None

  def classify_mem(t, op_name: str) -> InstOp:
    """Map an instruction class to its load/store InstOp; anything unrecognized is reported as SALU."""
    if "STORE" in op_name:
      if issubclass(t, ds_types): return InstOp.LDS_STORE
      if issubclass(t, global_types): return InstOp.GLOBAL_STORE
      # scratch accesses are reported with the FLAT ops
      if issubclass(t, flat_types) or issubclass(t, scratch_types): return InstOp.FLAT_STORE
    else:
      if issubclass(t, ds_types): return InstOp.LDS_LOAD
      if issubclass(t, global_types): return InstOp.GLOBAL_LOAD
      if issubclass(t, flat_types) or issubclass(t, scratch_types): return InstOp.FLAT_LOAD
    return InstOp.SALU

  nibbles: list[int] = []
  started: set[int] = set()
  _emit_nibbles(nibbles, LAYOUT_HEADER, layout=3, sel_a=6)

  def emit(wave_id: int, inst, branch_taken: bool|None):
    """Emit an SQTT packet for one executed instruction."""
    slot = wave_id & 0x1F
    # the first instruction seen for a wave opens it, even when that instruction itself emits nothing
    if wave_id not in started:
      _emit_nibbles(nibbles, WAVESTART, delta=1, simd=0, cu_lo=0, wave=slot, id7=wave_id)
      started.add(wave_id)
    kind = type(inst)
    opcode = inst.op.value if hasattr(inst, 'op') else 0
    op_name = inst.op.name if hasattr(inst, 'op') else ""

    # decide the packet class and, for INST packets, the op field; a single write happens at the end
    pkt_cls, extra = INST, {}
    if issubclass(kind, sopp_types):
      if opcode in skip_ops: return
      if opcode in immediate_ops: pkt_cls = IMMEDIATE
      elif opcode in barrier_ops: extra = {"op": InstOp.BARRIER}
      elif opcode in branch_ops: extra = {"op": InstOp.JUMP if branch_taken else InstOp.JUMP_NO}
      else: extra = {"op": InstOp.SALU}
    elif issubclass(kind, valu_types):
      special = classify_valu(op_name)
      if special is None: pkt_cls = VALUINST
      else: extra = {"op": special}
    elif issubclass(kind, smem_types): extra = {"op": InstOp.SMEM}
    else: extra = {"op": classify_mem(kind, op_name)}
    _emit_nibbles(nibbles, pkt_cls, delta=1, wave=slot, **extra)

  def finish(wave_id: int):
    """Emit WAVEEND for a completed wave."""
    if wave_id not in started: return
    _emit_nibbles(nibbles, WAVEEND, delta=1, simd=0, cu_lo=0, wave=wave_id & 0x1F)

  def finalize() -> bytes:
    """Pad and return the encoded SQTT blob."""
    if len(nibbles) % 2: nibbles.append(0)   # finish the current byte
    nibbles.extend([0] * 32)                 # 32 zero nibbles after the last packet
    nibbles.extend([0] * (-len(nibbles) % 64))  # round up to a multiple of 64 nibbles (32 bytes)
    return _nibbles_to_bytes(nibbles)

  return emit, finish, finalize
|
||||
|
||||
def _c(val, dtype=dtypes.uint32):
  """Shorthand for a constant UOp holding `val`, typed `dtype` (uint32 unless overridden)."""
  return UOp.const(dtype, val)
|
||||
|
||||
def _u64(lo: UOp, hi: UOp) -> UOp:
|
||||
@@ -1231,14 +1350,16 @@ def _get_runner(inst_bytes: bytes, arch: str = "rdna3"):
|
||||
canonical_name = f"{_op_name(inst).lower()}_{base.to_bytes(size, 'little').hex()}"
|
||||
sink = sink.replace(arg=KernelInfo(name=canonical_name)).rtag(1)
|
||||
|
||||
# NOTE: renderer output is not reproducible because of _MXCSRContext
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0):
|
||||
# NOTE: renderer output is not reproducible because of _MXCSRContext. PROFILE=0 prevents emulator instruction runners from polluting profiling.
|
||||
with Context(NOOPT=1, CHECK_OOB=0, TUPLE_ORDER=0, EMULATED_DTYPES="", CAPTURE_PROCESS_REPLAY=0, PROFILE=0):
|
||||
runner = get_runner('CPU', sink)
|
||||
_canonical_runner_cache.append((base, mask, size, runner))
|
||||
return runner
|
||||
|
||||
# SOPP barrier opcodes, stored as enum members and compared against inst.op directly
_BARRIER_OPS = {ir3.SOPPOp.S_BARRIER, irc.SOPPOp.S_BARRIER}
if hasattr(ir4.SOPPOp, 'S_BARRIER_WAIT'):
  _BARRIER_OPS.add(ir4.SOPPOp.S_BARRIER_WAIT)

# SOPP branch opcodes, stored as raw integer values (note: unlike _BARRIER_OPS, which holds enum members)
_BRANCH_OPS: set[int] = {
  getattr(ir3.SOPPOp, branch_name).value
  for branch_name in ("S_BRANCH", "S_CBRANCH_SCC0", "S_CBRANCH_SCC1", "S_CBRANCH_VCCZ", "S_CBRANCH_VCCNZ",
                      "S_CBRANCH_EXECZ", "S_CBRANCH_EXECNZ")
}
|
||||
|
||||
def _decode_at(pc: int, arch: str):
|
||||
"""Decode and compile instruction at absolute address pc. Returns (runner, decoded_inst)."""
|
||||
@@ -1295,8 +1416,8 @@ class WaveState:
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int,
|
||||
scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None) -> tuple[WaveState, list]:
|
||||
"""Initialize a single wavefront and return (WaveState, c_bufs placeholder). c_bufs filled in by caller."""
|
||||
scratch_size: int, arch: str, gidx: int, gidy: int, gidz: int, user_data: list[int]|None) -> WaveState:
|
||||
"""Initialize a single wavefront and return WaveState."""
|
||||
n_lanes = min(WAVE_SIZE, total_threads - wave_start)
|
||||
st = WaveState(n_lanes)
|
||||
st.pc = lib
|
||||
@@ -1324,7 +1445,8 @@ def _init_wave(lib: int, wave_start: int, total_threads: int, lx: int, ly: int,
|
||||
def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int, lz: int, args_ptr: int, rsrc2: int = 0x19c,
|
||||
scratch_size: int = 0, arch: str = "rdna3", user_data: list[int]|None = None) -> int:
|
||||
"""Execute AMD assembly program. scratch_size is private_segment_fixed_size from kernel descriptor (per-lane)."""
|
||||
program: dict[int, tuple[Callable, list[int], bool]] = {} # pc -> (fxn, globals, is_barrier)
|
||||
from tinygrad.renderer.amd.dsl import Inst
|
||||
program: dict[int, tuple[Callable, list[int], bool, Inst]] = {} # pc -> (fxn, globals, is_barrier, inst)
|
||||
lds_size = ((rsrc2 & hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE) >> hsa.AMD_COMPUTE_PGM_RSRC_TWO_GRANULATED_LDS_SIZE_SHIFT) * 512
|
||||
total_threads = lx * ly * lz
|
||||
|
||||
@@ -1333,18 +1455,24 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
||||
lds_buf = Buffer('CPU', max(lds_size // 4, 1), dtypes.uint32).ensure_allocated()
|
||||
scratch_buf = Buffer('CPU', scratch_size * WAVE_SIZE, dtypes.uint8).ensure_allocated() if scratch_size else None
|
||||
|
||||
def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool]:
|
||||
# Initialize SQTT encoder — emits packets inline as instructions execute (only when profiling)
|
||||
if PROFILE:
|
||||
sqtt_emit, sqtt_finish, sqtt_finalize = _init_sqtt_encoder()
|
||||
|
||||
def _ensure_compiled(pc: int) -> tuple[Callable, list[int], bool, Inst]:
|
||||
if pc not in program:
|
||||
prev_len = len(_canonical_runner_cache)
|
||||
runner, inst = _decode_at(pc, arch)
|
||||
is_barrier = isinstance(inst, (ir3.SOPP, ir4.SOPP, irc.SOPP)) and inst.op in _BARRIER_OPS
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier)
|
||||
program[pc] = (runner._prg.fxn, runner.p.globals, is_barrier, inst)
|
||||
if DEBUG >= 3:
|
||||
msg = f"[emu] PC={pc - lib}: {inst!r}"
|
||||
print(colored(msg, 'green') if len(_canonical_runner_cache) > prev_len else msg)
|
||||
return program[pc]
|
||||
|
||||
# Set DAZ+FTZ during emulator execution, restore afterward to avoid breaking hypothesis tests
|
||||
# Only trace the first workgroup (like real HW traces one CU/SIMD), subsequent workgroups run but don't add to trace
|
||||
tracing = bool(PROFILE)
|
||||
with _MXCSRContext():
|
||||
for gidz in range(gz):
|
||||
for gidy in range(gy):
|
||||
@@ -1370,14 +1498,21 @@ def run_asm(lib: int, lib_sz: int, gx: int, gy: int, gz: int, lx: int, ly: int,
|
||||
pc = st.pc
|
||||
if pc == ENDPGM_PC:
|
||||
done[wi] = True
|
||||
if tracing: sqtt_finish(wi)
|
||||
break
|
||||
fxn, globals_list, is_barrier = _ensure_compiled(pc)
|
||||
fxn, globals_list, is_barrier, inst = _ensure_compiled(pc)
|
||||
fxn(*[c_bufs[g] for g in globals_list])
|
||||
if tracing:
|
||||
inst_op = inst.op.value if hasattr(inst, 'op') else 0
|
||||
sqtt_emit(wi, inst, (st.pc != ENDPGM_PC and st.pc != pc + inst.size()) if inst_op in _BRANCH_OPS else None)
|
||||
if is_barrier: break # s_barrier hit: PC already advanced past it, pause this wave
|
||||
else: raise RuntimeError("exceeded 1M instructions in single wave, likely infinite loop")
|
||||
# All waves have either hit barrier or endpgm — release barrier waves for next round
|
||||
else: raise RuntimeError("exceeded 10M total scheduling rounds")
|
||||
tracing = False # only trace the first workgroup
|
||||
|
||||
# Reset LDS for next workgroup
|
||||
if lds_size > 0: ctypes.memset(lds_buf._buf.va_addr, 0, max(lds_size, 4))
|
||||
|
||||
if PROFILE: sqtt_traces.append(sqtt_finalize())
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user