assembly/amd: start work on SQTT parsing/emulation

This commit is contained in:
George Hotz
2026-01-01 18:40:58 -05:00
parent a8bea4ec52
commit 8d43212bc6
8 changed files with 1462 additions and 3 deletions

View File

@@ -2,6 +2,7 @@
# mypy: ignore-errors
from __future__ import annotations
import ctypes
from enum import IntEnum
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format
@@ -90,6 +91,164 @@ class LDSMem:
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# ═══════════════════════════════════════════════════════════════════════════════
# TIMING MODEL - Instruction latencies and register dependency tracking
# ═══════════════════════════════════════════════════════════════════════════════
class InstType(IntEnum):
  """Coarse instruction category used by the timing model.

  Members are the keys of INST_LATENCY and the return values of classify_inst().
  """
  SALU = 0     # Scalar ALU
  VALU = 1     # Vector ALU (simple)
  TRANS32 = 2  # Transcendental (v_rcp, v_rsq, v_sqrt, v_log, v_exp, v_sin, v_cos)
  SMEM = 3     # Scalar memory
  VMEM = 4     # Vector memory (global/scratch)
  LDS = 5      # LDS operations
  BRANCH = 6   # Branch/control flow
  NOP = 7      # No operation
  OTHER = 8    # Other
# Latencies from s_delay_alu encoding (VALU_DEP_1-4, TRANS32_DEP_1-3, SALU_CYCLE_1-3)
# One fixed latency per InstType; every InstType member has an entry.
# NOTE(review): presumably the unit is cycles and the value is what callers pass to
# TimingState.schedule_write(latency=...) - the call site is not visible here, confirm.
INST_LATENCY = {
  InstType.SALU: 1,      # SALU_CYCLE_1
  InstType.VALU: 1,      # VALU_DEP_1 (with forwarding, no stall)
  InstType.TRANS32: 4,   # TRANS32_DEP_3 (conservative)
  InstType.SMEM: 4,      # Variable, use small fixed value
  InstType.VMEM: 1,      # Variable, skip for now (issue immediately)
  InstType.LDS: 1,       # Variable, skip for now
  InstType.BRANCH: 1,
  InstType.NOP: 1,
  InstType.OTHER: 1,
}
# Transcendental ops - higher latency
# classify_inst() tests each of these as a *substring* of the instruction's op_name, so an
# op_name that merely contains one of them is classified as TRANS32.
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
              'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
def classify_inst(inst: Inst) -> InstType:
  """Map an instruction object to the InstType bucket used by the timing model.

  The instruction's encoding class decides the bucket. SOPP is split further by opcode
  (S_ENDPGM/S_BARRIER -> OTHER, any op whose name contains 'BRANCH' -> BRANCH, else NOP),
  and the VOP encodings are split into TRANS32 vs VALU by matching _TRANS_OPS against the
  op name. Anything unrecognised is OTHER.
  """
  if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)): return InstType.SALU

  if isinstance(inst, SOPP):
    sopp_op = inst.op
    if sopp_op in (SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER): return InstType.OTHER
    # ops without a .name attribute are treated as having an empty name
    is_branch = 'BRANCH' in getattr(sopp_op, 'name', '')
    return InstType.BRANCH if is_branch else InstType.NOP

  if isinstance(inst, SMEM): return InstType.SMEM

  if isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)):
    mnemonic = getattr(inst, 'op_name', '')
    for trans_name in _TRANS_OPS:
      if trans_name in mnemonic: return InstType.TRANS32
    return InstType.VALU

  # remaining encodings map one-to-one onto a bucket; checked in this order
  if isinstance(inst, VOPD): return InstType.VALU
  if isinstance(inst, FLAT): return InstType.VMEM
  if isinstance(inst, DS): return InstType.LDS
  return InstType.OTHER
class TimingState:
  """Timing state for cycle-accurate emulation with SQTT output.

  Holds the current cycle, the cycle at which each SGPR/VGPR written so far becomes
  available, and the list of SQTT packets emitted so far. wave_id/simd/cu identify the
  wave and are folded into the emitted packets' fields.
  """
  __slots__ = ('cycle', 'sgpr_ready', 'vgpr_ready', 'packets', 'wave_id', 'simd', 'cu')
  def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
    self.cycle = 0  # Current cycle
    self.sgpr_ready: dict[int, int] = {}  # SGPR -> cycle when ready
    self.vgpr_ready: dict[int, int] = {}  # VGPR -> cycle when ready (all lanes same for now)
    self.packets: list = []  # SQTT packets
    self.wave_id, self.simd, self.cu = wave_id, simd, cu
  def stall_for_read(self, inst: Inst, st: 'WaveState') -> int:
    """Calculate stall cycles needed before executing this instruction.

    Collects the instruction's source registers, looks up when each becomes ready, and
    returns how many cycles past self.cycle the latest one is (0 if all are ready).
    Registers never written through schedule_write() count as ready at cycle 0.

    NOTE(review): `st` is not used by this method. VOP3P and VOPD sources are not
    collected here even though classify_inst() treats both as VALU - confirm intended.
    """
    max_ready = self.cycle
    # Get source registers
    srcs = []
    if isinstance(inst, (SOP1, SOP2, SOPC)):
      if hasattr(inst, 'ssrc0') and inst.ssrc0 < SGPR_COUNT: srcs.append(('s', inst.ssrc0))
      if hasattr(inst, 'ssrc1') and inst.ssrc1 < SGPR_COUNT: srcs.append(('s', inst.ssrc1))
    elif isinstance(inst, SOPK):
      # SOPK: the destination register is also recorded as a source
      if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT: srcs.append(('s', inst.sdst))
    elif isinstance(inst, SMEM):
      # sbase is doubled to get the first SGPR of a two-register base
      if hasattr(inst, 'sbase'): srcs.append(('s', inst.sbase * 2)); srcs.append(('s', inst.sbase * 2 + 1))
      if hasattr(inst, 'soffset') and inst.soffset < SGPR_COUNT: srcs.append(('s', inst.soffset))
    elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOPC)):
      # Operand encoding: < SGPR_COUNT is an SGPR, >= 256 is VGPR (encoding - 256);
      # values in between are not tracked.
      if hasattr(inst, 'src0'):
        if inst.src0 < SGPR_COUNT: srcs.append(('s', inst.src0))
        elif inst.src0 >= 256: srcs.append(('v', inst.src0 - 256))
      if hasattr(inst, 'src1'):
        if inst.src1 < SGPR_COUNT: srcs.append(('s', inst.src1))
        elif inst.src1 >= 256: srcs.append(('v', inst.src1 - 256))
      if hasattr(inst, 'src2'):
        if inst.src2 < SGPR_COUNT: srcs.append(('s', inst.src2))
        elif inst.src2 >= 256: srcs.append(('v', inst.src2 - 256))
      # vsrc1 is used directly as a VGPR index (no 256 offset)
      if hasattr(inst, 'vsrc1'): srcs.append(('v', inst.vsrc1))
    # Find max ready time across sources
    for typ, reg in srcs:
      if typ == 's':
        ready = self.sgpr_ready.get(reg, 0)
      else:
        ready = self.vgpr_ready.get(reg, 0)
      if ready > max_ready:
        max_ready = ready
    return max(0, max_ready - self.cycle)
  def schedule_write(self, inst: Inst, latency: int):
    """Record when destination registers will be ready.

    Every destination register of `inst` is marked ready at self.cycle + latency,
    overwriting any earlier entry for that register.
    """
    ready_cycle = self.cycle + latency
    # Get destination registers
    if isinstance(inst, (SOP1, SOP2, SOPK)):
      if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT:
        self.sgpr_ready[inst.sdst] = ready_cycle
        # dst_regs() == 2: the following register is written as well
        if inst.dst_regs() == 2:
          self.sgpr_ready[inst.sdst + 1] = ready_cycle
    elif isinstance(inst, SMEM):
      if hasattr(inst, 'sdata'):
        # SMEM_LOAD gives the number of registers per load op; ops not listed count as 1
        cnt = SMEM_LOAD.get(inst.op, 1)
        for i in range(cnt):
          self.sgpr_ready[inst.sdata + i] = ready_cycle
    elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOPC, VOPD)):
      if hasattr(inst, 'vdst'):
        self.vgpr_ready[inst.vdst] = ready_cycle
        if inst.dst_regs() == 2:
          self.vgpr_ready[inst.vdst + 1] = ready_cycle
    elif isinstance(inst, FLAT):
      # only FLAT ops whose name contains 'LOAD' write VGPRs
      if 'LOAD' in inst.op.name and hasattr(inst, 'vdst'):
        ndwords = _op_ndwords(inst.op.name)
        for i in range(ndwords):
          self.vgpr_ready[inst.vdst + i] = ready_cycle
  def emit_wavestart(self):
    """Emit WAVESTART packet.

    NOTE(review): `Packet` and `Op` are imported from extra.assembly.amd.sqtt, but the
    sqtt.py added alongside this code defines per-type PacketType subclasses (WAVESTART,
    INST, ...) and no `Packet`/`Op` names - confirm this import resolves.
    """
    from extra.assembly.amd.sqtt import Packet, Op
    # wave_id low 4 bits at bits [11:8], simd low 2 bits at bits [13:12]
    fields = (self.wave_id & 0xF) << 8 | (self.simd & 0x3) << 12
    self.packets.append(Packet(opcode=Op.WAVESTART, delta=0, fields=fields, time=self.cycle))
  def emit_inst(self, inst: Inst, pc: int, stall: int, dur: int):
    """Emit INST packet for instruction dispatch.

    The packet's delta is stall + dur and its time is the current cycle. `pc` is accepted
    but not used. See the NOTE(review) in emit_wavestart about the Packet/Op import.
    """
    from extra.assembly.amd.sqtt import Packet, Op
    inst_type = classify_inst(inst)
    # Fields: wave[3:0] at [11:8], inst_type at higher bits
    fields = (self.wave_id & 0xF) << 8 | (inst_type & 0xFF) << 12
    delta = stall + dur
    self.packets.append(Packet(opcode=Op.INST, delta=delta, fields=fields, time=self.cycle))
  def emit_waveend(self):
    """Emit WAVEEND packet.

    Uses the same field packing as emit_wavestart. See the NOTE(review) there about the
    Packet/Op import.
    """
    from extra.assembly.amd.sqtt import Packet, Op
    fields = (self.wave_id & 0xF) << 8 | (self.simd & 0x3) << 12
    self.packets.append(Packet(opcode=Op.WAVEEND, delta=0, fields=fields, time=self.cycle))
  def to_sqtt_blob(self) -> bytes:
    """Encode all packets to raw SQTT blob.

    A LAYOUT_HEADER packet is placed in front of the recorded packets before encoding;
    self.packets itself is not modified.
    """
    from extra.assembly.amd.sqtt import Packet, Op, encode
    # Prepend LAYOUT_HEADER
    header = Packet(opcode=Op.LAYOUT_HEADER, delta=0, fields=0x68900180, time=0)
    return encode([header] + self.packets)
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
_VOPD_TO_VOP = {
VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32,

394
extra/assembly/amd/sqtt.py Normal file
View File

@@ -0,0 +1,394 @@
"""SQTT (SQ Thread Trace) packet encoder and decoder for AMD GPUs.
This module provides encoding and decoding of raw SQTT byte streams.
The format is nibble-based with variable-width packets determined by a state machine.
Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding.
"""
from __future__ import annotations
from enum import IntEnum
from typing import get_type_hints
from extra.assembly.amd.dsl import BitField, bits
# ═══════════════════════════════════════════════════════════════════════════════
# FIELD ENUMS
# ═══════════════════════════════════════════════════════════════════════════════
class MemSrc(IntEnum):
  """Values of the 2-bit `src` field of VMEMEXEC packets."""
  LDS = 0
  LDS_ALT = 1
  VMEM = 2
  VMEM_ALT = 3
class AluSrc(IntEnum):
  """Values of the 2-bit `src` field of ALUEXEC packets."""
  NONE = 0
  SALU = 1
  VALU = 2
  VALU_ALT = 3
class InstOp(IntEnum):
  """Known values of the 7-bit `op` field of INST packets.

  The list is not exhaustive: PacketType.from_raw() keeps a value as a plain int when it
  has no member here.
  """
  SALU = 0x0
  SMEM = 0x1
  JUMP = 0x3
  NEXT = 0x4
  MESSAGE = 0x9
  VALU = 0xb
  VALU_64 = 0xd
  VALU_MAD64 = 0xe
  VMEM_LOAD = 0x21
  VMEM_LOAD_ALT = 0x22
  VMEM_STORE = 0x24
  VMEM_STORE_64 = 0x25
  VMEM_STORE_96 = 0x26
  VMEM_STORE_128 = 0x27
  VMEM_STORE_ALT = 0x28
  LDS_LOAD = 0x29
  LDS_STORE = 0x2b
  LDS_STORE_128 = 0x2e
  SIMD_LDS_LOAD = 0x50
  SIMD_LDS_LOAD_ALT = 0x51
  SIMD_LDS_STORE = 0x54
  SIMD_VMEM_LOAD = 0x5a
  SIMD_VMEM_LOAD_ALT = 0x5b
  SIMD_VMEM_STORE = 0x5c
  SIMD_VMEM_STORE_ALT = 0x5d
  SIMD_VMEM_STORE_96 = 0x5e
  SIMD_VMEM_STORE_128 = 0x5f
  SALU_OR = 0x72
  VALU_CMPX = 0x73
# ═══════════════════════════════════════════════════════════════════════════════
# PACKET TYPE BASE CLASS
# ═══════════════════════════════════════════════════════════════════════════════
class PacketType:
  """Base class for SQTT packet types.

  A subclass describes its layout with class attributes: `encoding` holds the
  (BitField, pattern) tuple that identifies the packet, and every other BitField
  attribute is a named field. A field annotated with an IntEnum subclass is converted to
  that enum by from_raw() when the raw value is a valid member.
  """
  _encoding: tuple[BitField, int] | None = None
  _field_types: dict[str, type] = {}
  _values: dict[str, int]
  _raw: int
  _time: int

  def __init_subclass__(cls, **kwargs):
    """Cache the subclass's encoding tuple and its IntEnum-annotated fields."""
    super().__init_subclass__(**kwargs)
    if 'encoding' in cls.__dict__ and isinstance(cls.__dict__['encoding'], tuple):
      cls._encoding = cls.__dict__['encoding']
    # Cache field type annotations for enum conversion. Resolving hints is best-effort:
    # if the annotations cannot be evaluated at all, no field gets enum conversion.
    try:
      hints = get_type_hints(cls)
    except Exception:
      hints = {}
    enum_fields: dict[str, type] = {}
    for name, hint in hints.items():
      # Check each hint on its own: some annotations (e.g. parameterized generics such as
      # dict[str, int] on Python 3.9/3.10) pass isinstance(x, type) yet make issubclass()
      # raise TypeError, which must not discard the enum hints of the other fields.
      try:
        if isinstance(hint, type) and issubclass(hint, IntEnum): enum_fields[name] = hint
      except TypeError:
        continue
    cls._field_types = enum_fields

  @classmethod
  def fields(cls) -> dict[str, BitField]:
    """Return the BitField attributes declared directly on this class, except `encoding`."""
    return {k: v for k, v in cls.__dict__.items() if isinstance(v, BitField) and k != 'encoding'}

  @classmethod
  def size_bits(cls) -> int:
    """Return the packet width in bits: highest field bit + 1, rounded up to a nibble."""
    max_bit = max((f.hi for f in cls.fields().values()), default=0)
    return ((max_bit + 4) // 4) * 4

  @classmethod
  def size_nibbles(cls) -> int:
    """Return the packet width in 4-bit nibbles."""
    return cls.size_bits() // 4

  @classmethod
  def from_raw(cls, raw: int, time: int = 0):
    """Build an instance by extracting every declared field from `raw`.

    `time` is stored unchanged as `_time`. __init__ is bypassed.
    """
    inst = object.__new__(cls)
    inst._raw = raw
    inst._time = time
    inst._values = {}
    for name, bf in cls.fields().items():
      val = (raw >> bf.lo) & bf.mask()
      # Convert to enum if annotated; values without a member stay plain ints
      enum_type = cls._field_types.get(name)
      if enum_type is not None:
        try: val = enum_type(val)
        except ValueError: pass
      inst._values[name] = val
    return inst

  def __getattr__(self, name: str):
    """Fallback lookup: serve field values from `_values`, 0 for unknown public names."""
    # Underscore names must fail normally, otherwise a missing `_values` would recurse.
    if name.startswith('_'): raise AttributeError(name)
    return self._values.get(name, 0)

  def __repr__(self) -> str:
    fields_str = ", ".join(f"{k}={v}" for k, v in self._values.items() if not k.startswith('_'))
    return f"{self.__class__.__name__}({fields_str})"
# ═══════════════════════════════════════════════════════════════════════════════
# PACKET TYPE DEFINITIONS
# ═══════════════════════════════════════════════════════════════════════════════
# Conventions used by every packet class below:
#  - `encoding = bits[hi:lo] == pattern` is the low-bit pattern that identifies the packet.
#  - `delta`, when it is a bit field, is what decode() adds to the running time;
#    `delta = None` means the packet carries no time delta.
#  - `_padding` fields are included in PacketType.fields(), so they extend size_bits(),
#    but PacketType.__repr__ hides them.
#  - A field annotated with an IntEnum is converted to that enum by from_raw().
class VALUINST(PacketType):
  encoding = bits[2:0] == 0b011
  delta = bits[5:3]
  flag = bits[6:6]
  wave = bits[11:7]
class VMEMEXEC(PacketType):
  encoding = bits[3:0] == 0b1111
  delta = bits[5:4]
  src: MemSrc = bits[7:6]
class ALUEXEC(PacketType):
  encoding = bits[3:0] == 0b1110
  delta = bits[5:4]
  src: AluSrc = bits[7:6]
class IMMEDIATE(PacketType):
  encoding = bits[3:0] == 0b1101
  delta = bits[6:4]
  wave = bits[11:7]
class IMMEDIATE_MASK(PacketType):
  encoding = bits[4:0] == 0b00100
  delta = bits[7:5]
  mask = bits[23:8]
class WAVERDY(PacketType):
  encoding = bits[4:0] == 0b10100
  delta = bits[7:5]
  mask = bits[23:8]
class TS_DELTA_S8_W3(PacketType):
  encoding = bits[6:0] == 0b0100001
  delta = bits[10:8]
  _padding = bits[63:11]
class WAVEEND(PacketType):
  encoding = bits[4:0] == 0b10101
  delta = bits[7:5]
  flag7 = bits[8:8]
  simd = bits[10:9]
  cu_lo = bits[13:11]
  wave = bits[19:15]
  @property
  def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)  # flag7 supplies bit 3 of cu
class WAVESTART(PacketType):
  encoding = bits[4:0] == 0b01100
  delta = bits[6:5]
  flag7 = bits[7:7]
  simd = bits[9:8]
  cu_lo = bits[12:10]
  wave = bits[17:13]
  id7 = bits[31:18]
  @property
  def cu(self) -> int: return self.cu_lo | (self.flag7 << 3)  # flag7 supplies bit 3 of cu
class TS_DELTA_S5_W2(PacketType):
  encoding = bits[4:0] == 0b11100
  delta = bits[6:5]
  _padding = bits[47:7]
class WAVEALLOC(PacketType):
  encoding = bits[4:0] == 0b00101
  delta = bits[7:5]
  _padding = bits[19:8]
class TS_DELTA_S5_W3(PacketType):
  encoding = bits[4:0] == 0b00110
  delta = bits[7:5]
  _padding = bits[51:8]
class PERF(PacketType):
  encoding = bits[4:0] == 0b10110
  delta = bits[7:5]
  arg = bits[27:8]
class TS_DELTA_SHORT(PacketType):
  encoding = bits[3:0] == 0b1000
  delta = bits[7:4]  # decode() adds 8 to this field before accumulating time
class NOP(PacketType):
  encoding = bits[3:0] == 0b0000
  delta = None # type: ignore
  _padding = bits[3:0]  # only field, so a NOP is exactly one nibble wide
class TS_WAVE_STATE(PacketType):
  encoding = bits[6:0] == 0b1010001
  delta = bits[15:7]
  coarse = bits[23:16]
  @property
  def wave_interest(self) -> bool: return bool(self.coarse & 1)
  @property
  def terminate_all(self) -> bool: return bool(self.coarse & 8)
class EVENT(PacketType):
  encoding = bits[7:0] == 0b01100001
  delta = bits[10:8]
  event = bits[23:11]
class EVENT_BIG(PacketType):
  encoding = bits[7:0] == 0b11100001
  delta = bits[10:8]
  event = bits[31:11]
class REG(PacketType):
  encoding = bits[3:0] == 0b1001
  delta = bits[6:4]
  # NOTE(review): slot [9:7] and hi_byte [15:8] overlap in bits 8-9 - confirm intended
  slot = bits[9:7]
  hi_byte = bits[15:8]
  subop = bits[31:16]
  val32 = bits[63:32]
  @property
  def is_config(self) -> bool: return bool(self.hi_byte & 0x80)
class SNAPSHOT(PacketType):
  encoding = bits[6:0] == 0b1110001
  delta = bits[9:7]
  snap = bits[63:10]
class TS_DELTA_OR_MARK(PacketType):
  encoding = bits[6:0] == 0b0000001
  delta = bits[47:12]
  bit8 = bits[8:8]
  bit9 = bits[9:9]
  # decode() applies the same bit9-and-not-bit8 test and then counts the delta as 0
  @property
  def is_marker(self) -> bool: return bool(self.bit9 and not self.bit8)
class LAYOUT_HEADER(PacketType):
  encoding = bits[6:0] == 0b0010001
  delta = None # type: ignore
  layout = bits[12:7]
  simd = bits[14:13]
  group = bits[17:15]
  sel_a = bits[31:28]
  sel_b = bits[36:33]
  flag4 = bits[59:59]
  _padding = bits[63:60]
class INST(PacketType):
  encoding = bits[2:0] == 0b010
  delta = bits[6:4]
  flag1 = bits[3:3]
  flag2 = bits[7:7]
  wave = bits[12:8]
  op: InstOp = bits[19:13]
class UTILCTR(PacketType):
  encoding = bits[6:0] == 0b0110001
  delta = bits[8:7]
  ctr = bits[47:9]
# All packet types in encoding priority order (more specific masks first, NOP last as fallback)
# The position of a class in this list is its opcode: _build_state_table() assigns each
# first-byte value to the first class here whose encoding matches it.
PACKET_TYPES: list[type[PacketType]] = [
  EVENT, EVENT_BIG,
  TS_DELTA_S8_W3, TS_WAVE_STATE, SNAPSHOT, TS_DELTA_OR_MARK, LAYOUT_HEADER, UTILCTR,
  IMMEDIATE_MASK, WAVERDY, WAVEEND, WAVESTART, TS_DELTA_S5_W2, WAVEALLOC, TS_DELTA_S5_W3, PERF,
  VMEMEXEC, ALUEXEC, IMMEDIATE, TS_DELTA_SHORT, REG,
  VALUINST, INST,
  NOP,
]
# Lookup of packet classes by class name
PACKET_BY_NAME: dict[str, type[PacketType]] = {cls.__name__: cls for cls in PACKET_TYPES}
def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
  """Build the first-byte -> opcode table and the opcode -> packet class map.

  An opcode is a packet class's index in PACKET_TYPES. Each of the 256 byte values is
  assigned to the first class whose encoding pattern it matches; bytes matching nothing
  fall back to the last entry (NOP).
  """
  fallback = len(PACKET_TYPES) - 1  # default to NOP
  opcode_to_class: dict[int, type[PacketType]] = dict(enumerate(PACKET_TYPES))
  table = []
  for byte_val in range(256):
    chosen = fallback
    for opcode, pkt_cls in enumerate(PACKET_TYPES):
      encoding = pkt_cls._encoding
      if encoding is None: continue
      mask_bf, pattern = encoding
      if (byte_val & mask_bf.mask()) == pattern:
        chosen = opcode
        break
    table.append(chosen)
  return bytes(table), opcode_to_class
# first byte -> opcode, and opcode -> packet class
STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()
# opcode -> every first-byte value that decodes to it, in ascending order
OPCODE_TO_BYTES: dict[int, list[int]] = {}
for _byte_val, _opcode in enumerate(STATE_TO_OPCODE):
  OPCODE_TO_BYTES.setdefault(_opcode, []).append(_byte_val)
# opcode -> packet length in nibbles
BUDGET = {opcode: OPCODE_TO_CLASS[opcode].size_nibbles() for opcode in OPCODE_TO_CLASS}
# ═══════════════════════════════════════════════════════════════════════════════
# DECODER
# ═══════════════════════════════════════════════════════════════════════════════
def decode(data: bytes) -> list[PacketType]:
  """Decode a raw SQTT blob into a list of packet instances.

  The stream is consumed a nibble at a time (low nibble of each byte first) through a
  64-bit window. 16 nibbles are read before the first packet; after each packet the
  window advances by that packet's length. Decoding stops when the input runs out before
  a full advance, so packets that only appear in the final look-ahead are not returned.
  Each packet's `_time` is the running sum of deltas up to and including that packet.
  """
  decoded: list[PacketType] = []
  end_bit = len(data) * 8
  window = 0    # 64-bit shift register; the newest nibble enters at bits [63:60]
  cursor = 0    # bit offset of the next unread nibble
  pending = 16  # nibbles to shift in before the next packet can be decoded
  clock = 0
  while cursor < end_bit:
    stop = cursor + pending * 4
    while cursor < stop and cursor < end_bit:
      nibble = (data[cursor >> 3] >> (cursor & 4)) & 0xF
      window = ((window >> 4) | (nibble << 60)) & ((1 << 64) - 1)
      cursor += 4
    if cursor < stop: break  # ran out of input mid-advance

    opcode = STATE_TO_OPCODE[window & 0xFF]
    pkt_cls = OPCODE_TO_CLASS[opcode]
    pending = BUDGET[opcode]

    delta_bf = getattr(pkt_cls, 'delta', None)
    if delta_bf is None:
      step = 0
    else:
      step = (window >> delta_bf.lo) & delta_bf.mask()
    if pkt_cls is TS_DELTA_OR_MARK:
      # bit9 set with bit8 clear: the delta is not counted as time
      low_flag = (window >> TS_DELTA_OR_MARK.bit8.lo) & TS_DELTA_OR_MARK.bit8.mask()
      high_flag = (window >> TS_DELTA_OR_MARK.bit9.lo) & TS_DELTA_OR_MARK.bit9.mask()
      if high_flag and not low_flag: step = 0
    elif pkt_cls is TS_DELTA_SHORT:
      step += 8

    clock += step
    decoded.append(pkt_cls.from_raw(window, clock))
  return decoded
# ═══════════════════════════════════════════════════════════════════════════════
# ENCODER
# ═══════════════════════════════════════════════════════════════════════════════
def encode(packets: list[PacketType]) -> bytes:
  """Encode a list of packet instances into a raw SQTT blob.

  Each packet starts at the nibble offset equal to the combined size of all packets
  before it, and the blob is 16 nibbles longer than the last packet's offset (padded to
  a whole byte) to match the look-ahead decode() reads. Only the identifying first byte
  is written for each packet, with the packet's `delta` value patched in when the delta
  field lies entirely within that byte; no other field is emitted.

  Raises:
    ValueError: if a packet's class has no opcode or no first-byte value.
  """
  if not packets: return b''

  # Nibble offset of each packet = total size of everything before it.
  starts = [0]
  for p in packets[:-1]: starts.append(starts[-1] + p.size_nibbles())

  # class -> opcode, keeping the first opcode if a class were registered twice
  opcode_of: dict[type, int] = {}
  for op, cls in OPCODE_TO_CLASS.items(): opcode_of.setdefault(cls, op)

  nibbles = [0] * (16 + starts[-1])
  for p, start in zip(packets, starts):
    pkt_cls = type(p)
    # An unregistered class gets the ValueError below instead of a bare StopIteration.
    opcode = opcode_of.get(pkt_cls)
    byte_vals = OPCODE_TO_BYTES.get(opcode) if opcode is not None else None
    if not byte_vals: raise ValueError(f"No encoding for {pkt_cls.__name__}")
    opcode_byte = byte_vals[0]

    delta_field = getattr(pkt_cls, 'delta', None)
    if delta_field is not None and delta_field.hi < 8:
      delta = p._values.get('delta', 0)
      if isinstance(delta, IntEnum): delta = delta.value
      # NOTE(review): from_raw() stores the raw field value, while this subtracts the +8
      # bias decode() applies to time - so encode(decode(x)) changes this field. Confirm.
      if pkt_cls is TS_DELTA_SHORT: delta = max(0, delta - 8)
      delta = delta & delta_field.mask()
      opcode_byte = (opcode_byte & ~(delta_field.mask() << delta_field.lo)) | (delta << delta_field.lo)

    # Low nibble first. A one-nibble packet's high nibble is overwritten by the next packet.
    nibbles[start] = opcode_byte & 0xF
    nibbles[start + 1] = (opcode_byte >> 4) & 0xF

  if len(nibbles) % 2: nibbles.append(0)
  return bytes(lo | (hi << 4) for lo, hi in zip(nibbles[::2], nibbles[1::2]))

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
"""SQTT InstOp discovery tool - finds instruction opcodes by running different instructions.
Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops.py
For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops.py
"""
import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_ITRACE_SE_MASK"] = "2" # Enable instruction tracing on SE1
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
from tinygrad.helpers import DEBUG, colored
from tinygrad.runtime.ops_amd import SQTT_SIMD_SEL
from extra.assembly.amd.autogen.rdna3.ins import (
# VALU - basic (these are safe, just register ops)
v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32,
v_and_b32_e32, v_or_b32_e32, v_xor_b32_e32,
v_lshlrev_b32_e32, v_lshrrev_b32_e32,
# VALU - transcendental
v_exp_f32_e32, v_log_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
# VALU - 64-bit
v_lshlrev_b64, v_lshrrev_b64,
# VALU - compare (writes to VCC, safe)
v_cmp_eq_u32_e32,
v_cmpx_eq_u32_e32,
# SALU - basic (safe, just register ops)
s_mov_b32, s_add_u32, s_and_b32, s_or_b32,
s_lshl_b32, s_lshr_b32,
s_nop, s_endpgm,
# SALU - branch (safe if offset is 0 = next instruction)
s_branch, s_cbranch_scc0, s_cbranch_execz,
# SALU - message
s_sendmsg,
)
from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.sqtt import InstOp, INST
from extra.assembly.amd.test.test_sqtt_hw import (
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs
)
# ═══════════════════════════════════════════════════════════════════════════════
# INSTRUCTION TEST CASES - only safe instructions that don't access memory
# ═══════════════════════════════════════════════════════════════════════════════
# Maps test name -> (human-readable instruction label, instruction list to run).
# The hex value in each section comment is the InstOp that section is meant to exercise.
INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
  # SALU (0x0) - scalar ALU, just register operations
  "SALU_mov": ("s_mov_b32", [s_mov_b32(s[0], 0), s_mov_b32(s[1], 1)]),
  "SALU_add": ("s_add_u32", [s_mov_b32(s[0], 1), s_mov_b32(s[1], 2), s_add_u32(s[2], s[0], s[1])]),
  "SALU_logic": ("s_and/or", [s_and_b32(s[2], s[0], s[1]), s_or_b32(s[3], s[0], s[1])]),
  "SALU_shift": ("s_lshl/lshr", [s_lshl_b32(s[2], s[0], 1), s_lshr_b32(s[3], s[0], 1)]),
  "SALU_nop": ("s_nop", [s_nop(0)]),
  # JUMP (0x3) - branch to next instruction (offset 0)
  "JUMP_branch": ("s_branch", [s_branch(0)]),
  # VALU (0xb) - vector ALU, just register operations
  "VALU_mov": ("v_mov_b32", [v_mov_b32_e32(v[0], 0), v_mov_b32_e32(v[1], 1.0)]),
  "VALU_add": ("v_add_f32", [v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0), v_add_f32_e32(v[2], v[0], v[1])]),
  "VALU_mul": ("v_mul_f32", [v_mul_f32_e32(v[2], v[0], v[1])]),
  "VALU_logic": ("v_and/or/xor", [v_and_b32_e32(v[2], v[0], v[1]), v_or_b32_e32(v[3], v[0], v[1]), v_xor_b32_e32(v[4], v[0], v[1])]),
  "VALU_shift": ("v_lshl/lshr", [v_lshlrev_b32_e32(v[2], 1, v[0]), v_lshrrev_b32_e32(v[3], 1, v[0])]),
  # VALU transcendental - still just register ops
  "VALU_exp": ("v_exp_f32", [v_mov_b32_e32(v[0], 1.0), v_exp_f32_e32(v[1], v[0])]),
  "VALU_log": ("v_log_f32", [v_mov_b32_e32(v[0], 1.0), v_log_f32_e32(v[1], v[0])]),
  "VALU_rcp": ("v_rcp_f32", [v_mov_b32_e32(v[0], 1.0), v_rcp_f32_e32(v[1], v[0])]),
  "VALU_sqrt": ("v_sqrt_f32", [v_mov_b32_e32(v[0], 1.0), v_sqrt_f32_e32(v[1], v[0])]),
  # VALU 64-bit (0xd)
  "VALU64_lshl": ("v_lshlrev_b64", [v_lshlrev_b64(v[0:1], 1, v[2:3])]),
  # VALU MAD64 (0xe) - commented out, needs proper clamp arg
  # "VALU_mad64": ("v_mad_u64_u32", [v_mad_u64_u32(v[0:1], None, v[2], v[3], v[4:5])]),
  # VALU compare - writes to VCC
  "VALU_cmp": ("v_cmp_eq_u32", [v_cmp_eq_u32_e32(v[0], v[1])]),
  # VALU CMPX (0x73) - modifies EXEC
  "VALU_cmpx": ("v_cmpx_eq_u32", [v_cmpx_eq_u32_e32(v[0], v[1])]),
}
def run_with_simd_retry(instructions: list, max_retries: int = 4) -> tuple[list[bytes], list, set]:
  """Run instructions, trying SIMD selections 0..max_retries-1 until INST ops are seen.

  For each attempt SQTT_SIMD_SEL.value is set to the SIMD index, the instructions are run
  and the captured blobs decoded. The first attempt that yields any ops is returned.

  Returns:
    (blobs, packets, ops) of the first attempt that produced ops, otherwise of the last
    attempt. With max_retries <= 0 nothing is run and ([], [], set()) is returned.
  """
  # Start empty so a zero-iteration loop returns a well-defined result instead of
  # failing on unbound locals.
  blobs: list[bytes] = []
  packets: list = []
  ops: set = set()
  for simd in range(max_retries):
    SQTT_SIMD_SEL.value = simd
    blobs = run_asm_sqtt(instructions)
    packets = decode_all_blobs(blobs)
    ops = get_inst_ops(packets)
    if ops: break
  # Falls through with the last attempt even if no ops were found
  return blobs, packets, ops
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
  """Run all instruction tests and collect InstOp values.

  Returns:
    (discovered, failures): `discovered` maps each op value seen to the names of the tests
    that produced it; `failures` maps the name of each test that raised to its exception.
    A failing test is recorded and does not stop the remaining tests.
  """
  discovered: dict[int, set[str]] = {}
  failures: dict[str, Exception] = {}
  for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
    try:
      blobs, packets, ops = run_with_simd_retry(instructions)
      for op in ops:
        discovered.setdefault(op, set()).add(test_name)
      if DEBUG >= 2:
        # The separator and status glyphs were empty string literals, which printed
        # nothing; visible markers restored.
        print(f"\n{'─'*60}")
        print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]} simd_sel={SQTT_SIMD_SEL.value}")
        print_blobs(blobs, filter_timing=True)
      if DEBUG >= 1:
        status = colored("✓", "green") if ops else colored("○", "yellow")
        ops_str = ", ".join(hex(op) for op in sorted(ops)) if ops else "none"
        print(f" {status} {test_name:25s} ops=[{ops_str}]")
    except Exception as e:
      # Deliberately broad: any failure is reported per test so the sweep continues.
      failures[test_name] = e
      if DEBUG >= 1:
        print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")
  return discovered, failures
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
  """Print the discovery report.

  Sections, in order: every discovered op with its enum name (or UNKNOWN) and source
  tests, enum values that were not observed, discovered values missing from the enum,
  and overall statistics. The two middle sections are omitted when empty.
  """
  rule = "=" * 60

  def heading(title: str) -> None:
    print("\n" + rule)
    print(title)
    print(rule)

  known_ops = {member.value for member in InstOp}
  discovered_ops = set(discovered.keys())

  heading("DISCOVERED INSTOP VALUES")
  for op in sorted(discovered_ops):
    try:
      name = InstOp(op).name
    except ValueError:
      name, status = "UNKNOWN", colored("NEW!", "yellow")
    else:
      status = colored("known", "green")
    sources = ", ".join(sorted(discovered[op]))
    print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")

  # Missing from enum
  missing = known_ops - discovered_ops
  if missing:
    heading("ENUM VALUES NOT DISCOVERED")
    print("(need memory ops: SMEM, VMEM, LDS)")
    for op in sorted(missing):
      print(f" 0x{op:02x} {InstOp(op).name}")

  # New values to add
  new_ops = discovered_ops - known_ops
  if new_ops:
    heading(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow"))
    for op in sorted(new_ops):
      sources = ", ".join(sorted(discovered[op]))
      print(f' {op:#04x}: "{sources}",')

  # Stats
  total_tests = len(INSTRUCTION_TESTS)
  heading("STATISTICS")
  print(f" Tests run: {total_tests}")
  print(f" Tests passed: {total_tests - len(failures)}")
  print(f" Tests failed: {len(failures)}")
  print(f" Known ops: {len(known_ops)}")
  print(f" Discovered: {len(discovered_ops)}")
  if known_ops:
    covered = len(discovered_ops & known_ops)
    print(f" Coverage: {covered}/{len(known_ops)} ({100*covered//len(known_ops)}%)")
  print(f" New ops found: {len(new_ops)}")
# Script entry point: print a banner, run every instruction test, then print the summary.
if __name__ == "__main__":
  print("=" * 60, "SQTT InstOp Discovery Tool", "=" * 60, sep="\n")
  print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")
  discovered, failures = discover_all_instops()
  print_summary(discovered, failures)

View File

@@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""Tests for SQTT packet codec (no hardware required)."""
import unittest
from extra.assembly.amd.sqtt import (
LAYOUT_HEADER, WAVESTART, WAVEEND, INST, NOP,
decode, encode, PACKET_TYPES, OPCODE_TO_CLASS
)
class TestSQTTCodec(unittest.TestCase):
  """Encoder/decoder checks that need no hardware."""

  def test_roundtrip_simple(self):
    """A short packet sequence keeps its packet types through encode() then decode()."""
    originals = [
      LAYOUT_HEADER.from_raw(0x100),
      WAVESTART.from_raw(0x0),
      INST.from_raw(0x10),  # delta=1
      INST.from_raw(0x10),  # delta=1
      WAVEEND.from_raw(0x40),  # delta=2
    ]
    roundtripped = decode(encode(originals))
    self.assertGreaterEqual(len(roundtripped), len(originals))
    for idx, pair in enumerate(zip(originals, roundtripped)):
      before, after = pair
      self.assertEqual(type(before), type(after), f"type mismatch at {idx}")

  def test_decode_empty(self):
    """Decoding zero bytes yields no packets."""
    self.assertEqual(decode(b''), [])

  def test_encode_empty(self):
    """Encoding no packets yields zero bytes."""
    self.assertEqual(encode([]), b'')

  def test_all_packet_types_have_encoding(self):
    """Every registered packet class carries an encoding tuple."""
    for registered in PACKET_TYPES:
      self.assertIsNotNone(registered._encoding, f"{registered.__name__} missing encoding")

  def test_packet_from_raw(self):
    """from_raw() extracts the individual INST fields from a packed value."""
    # INST with wave=5, op=0x21, delta=2
    wave, op, delta = 5, 0x21, 2
    pkt = INST.from_raw((op << 13) | (wave << 8) | (delta << 4) | 0b010)
    self.assertEqual(pkt.wave, 5)
    self.assertEqual(pkt.op, 0x21)
    self.assertEqual(pkt.delta, 2)
class TestDecodeRealBlob(unittest.TestCase):
  """Decode a captured SQTT blob shipped with the repository examples."""

  def test_decode_example_file(self):
    """The first SQTT event of the example profile decodes and starts with LAYOUT_HEADER."""
    import pickle
    from pathlib import Path
    # four directories up from this file, then into sqtt/examples
    base_dir = Path(__file__).parent.parent.parent.parent
    example_path = base_dir / "sqtt/examples/profile_plus_run_0.pkl"
    if not example_path.exists():
      self.skipTest(f"Example file not found: {example_path}")
    # imported only once the example is known to exist
    from tinygrad.runtime.ops_amd import ProfileSQTTEvent
    with open(example_path, "rb") as handle:
      # NOTE: pickle executes code on load; only acceptable for the repo's own example
      profile = pickle.load(handle)
    sqtt_events = [ev for ev in profile if isinstance(ev, ProfileSQTTEvent)]
    self.assertGreater(len(sqtt_events), 0, "No SQTT events in example")
    packets = decode(sqtt_events[0].blob)
    self.assertGreater(len(packets), 0, "No packets decoded")
    # First packet should be LAYOUT_HEADER
    self.assertIsInstance(packets[0], LAYOUT_HEADER)
# Allow running this test module directly as a script.
if __name__ == "__main__":
  unittest.main()

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""Tests comparing hardware SQTT traces to emulator SQTT output.
Run with: python -m pytest extra/assembly/amd/test/test_sqtt_compare.py -v
Requires AMD GPU with SQTT support.
"""
import os
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["AMD_LLVM"] = "0"
import unittest, sys, contextlib
from tinygrad import Tensor
from tinygrad.device import Device
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
from extra.assembly.amd.sqtt import decode, encode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, PacketType
# ═══════════════════════════════════════════════════════════════════════════════
# HARDWARE SQTT CAPTURE
# ═══════════════════════════════════════════════════════════════════════════════
dev = Device["AMD"]
def custom(arg: str, s: UOp | None = None) -> UOp:
  """Build an Ops.CUSTOM UOp carrying `arg`, with `s` as its only source when given."""
  sources = () if s is None else (s,)
  return UOp(Ops.CUSTOM, src=sources, arg=arg)
def asm_kernel(instrs: list[str], local_size: int = 1, global_size: int = 1) -> Tensor:
  """Create a kernel whose body is the given inline assembly instructions.

  The kernel is named after the function that calls asm_kernel(). The instructions are
  emitted as one `asm volatile (...)` statement built from a chain of CUSTOM UOps.
  """
  # Must be read here, directly in this frame, so that frame 1 is the caller.
  kernel_name = sys._getframe(1).f_code.co_name

  def fxn(_):
    local_idx = UOp.special(local_size, "lidx0")
    global_idx = UOp.special(global_size, "gidx0")
    chain = custom("asm volatile (")
    for text in instrs:
      chain = custom(f' "{text}\\n\\t"', chain)
    chain = custom(");", chain)
    return UOp.sink(chain, local_idx, global_idx, arg=KernelInfo(name=kernel_name))

  return Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
@contextlib.contextmanager
def capture_hw_sqtt():
  """Collect the raw SQTT blobs produced while the `with` body runs.

  Clears the device's profile events on entry and yields an initially empty dict. On
  normal exit the dict is filled with kernel name -> list of blobs, one per
  ProfileSQTTEvent; it is not filled if the body raises.
  """
  dev.profile_events.clear()
  captured: dict[str, list[bytes]] = {}
  yield captured
  sqtt_events = (ev for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent))
  for ev in sqtt_events:
    captured.setdefault(ev.kern, []).append(ev.blob)
# ═══════════════════════════════════════════════════════════════════════════════
# TESTS
# ═══════════════════════════════════════════════════════════════════════════════
@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required")
class TestHardwareSQTT(unittest.TestCase):
  """Capture SQTT from the device and check that the decoder can parse it."""

  def test_capture_and_parse(self):
    """A simple VALU kernel yields SQTT data that decodes and contains LAYOUT_HEADER."""
    program = ["v_add_f32 v10, v10, v11", "v_add_f32 v11, v11, v12", "v_add_f32 v12, v12, v13"]
    with capture_hw_sqtt() as sqtt:
      asm_kernel(program).realize()
    self.assertGreater(len(sqtt), 0, "No SQTT data captured")
    first_kernel = list(sqtt.keys())[0]
    packets = decode(sqtt[first_kernel][0])
    self.assertGreater(len(packets), 0, "No packets parsed")
    # LAYOUT_HEADER should always be present
    seen_types = {type(p) for p in packets}
    self.assertIn(LAYOUT_HEADER, seen_types, "Missing LAYOUT_HEADER packet")

  def test_print_raw_packets(self):
    """Debug aid: dump every decoded packet of a small kernel (no assertions)."""
    program = [
      "v_mov_b32 v1, 1.0",
      "v_add_f32 v2, v1, v1",
      "s_nop 0",
      "v_mul_f32 v3, v2, v2",
    ]
    with capture_hw_sqtt() as sqtt:
      asm_kernel(program).realize()
    kern = list(sqtt.keys())[0]
    blob = sqtt[kern][0]
    packets = decode(blob)
    print(f"\n=== Raw SQTT packets for {kern} ({len(blob)} bytes, {len(packets)} packets) ===")
    for idx, pkt in enumerate(packets):
      # packet types without a delta field report None
      delta = 0 if pkt.delta is None else pkt.delta
      print(f" [{idx:3d}] time={pkt._time:6d} delta={delta:4d} {type(pkt).__name__:18s} raw=0x{pkt._raw:x}")

  def test_valu_timing(self):
    """Debug aid: print the INST packets of a three-instruction VALU kernel."""
    program = [
      "v_add_f32 v10, v10, v11",
      "v_add_f32 v11, v11, v12",
      "v_add_f32 v12, v12, v13",
    ]
    with capture_hw_sqtt() as sqtt:
      asm_kernel(program).realize()
    kern = list(sqtt.keys())[0]
    inst_packets = [pkt for pkt in decode(sqtt[kern][0]) if isinstance(pkt, INST)]
    print("\n=== INST packets ===")
    for idx, pkt in enumerate(inst_packets):
      print(f" [{idx}] time={pkt._time} delta={pkt.delta} raw=0x{pkt._raw:x}")
class TestSQTTCodec(unittest.TestCase):
  """Encoder/decoder roundtrip checks plus decoding of a stored example blob."""

  def test_roundtrip_simple(self):
    """A short packet sequence keeps its packet types through encode() then decode()."""
    originals = [
      LAYOUT_HEADER.from_raw(0x100),
      WAVESTART.from_raw(0x0),
      INST.from_raw(0x10),  # delta=1
      INST.from_raw(0x10),  # delta=1
      WAVEEND.from_raw(0x40),  # delta=2
    ]
    roundtripped = decode(encode(originals))
    self.assertGreaterEqual(len(roundtripped), len(originals))
    for idx, pair in enumerate(zip(originals, roundtripped)):
      orig, dec = pair
      self.assertEqual(type(orig), type(dec), f"type mismatch at {idx}: {orig} vs {dec}")

  def test_decode_real_blob(self):
    """The first SQTT event of the example profile decodes and contains LAYOUT_HEADER."""
    import pickle
    from pathlib import Path
    # four directories up from this file, then into sqtt/examples
    base_dir = Path(__file__).parent.parent.parent.parent
    example_path = base_dir / "sqtt/examples/profile_plus_run_0.pkl"
    if not example_path.exists():
      self.skipTest(f"Example file not found: {example_path}")
    with open(example_path, "rb") as handle:
      # NOTE: pickle executes code on load; only acceptable for the repo's own example
      profile = pickle.load(handle)
    sqtt_events = [ev for ev in profile if isinstance(ev, ProfileSQTTEvent)]
    self.assertGreater(len(sqtt_events), 0, "No SQTT events in example")
    packets = decode(sqtt_events[0].blob)
    self.assertGreater(len(packets), 0, "No packets decoded")
    # Should see common packet types
    seen_types = {type(p) for p in packets}
    self.assertIn(LAYOUT_HEADER, seen_types)
@unittest.skip("Emulator SQTT not yet implemented")
class TestEmulatorSQTT(unittest.TestCase):
  """Placeholder suite for comparing emulator SQTT output against hardware SQTT."""

  def test_simple_valu_match(self):
    """Simple VALU chain should produce matching SQTT packets (body not written yet)."""
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""Hardware tests for SQTT decoder - validates decoding of real SQTT streams.
Run with: python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
Requires AMD GPU with SQTT support.
For pretty trace output: DEBUG=2 python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
"""
import os
# SQTT/profiling switches are placed in the environment ahead of the tinygrad imports below.
# NOTE(review): presumably tinygrad reads these at import time -- confirm before reordering.
os.environ["SQTT"] = "1"
os.environ["PROFILE"] = "1"
os.environ["SQTT_ITRACE_SE_MASK"] = "2"  # Enable instruction tracing on SE1
os.environ["SQTT_LIMIT_SE"] = "1"  # Limit execution
import unittest
from tinygrad.helpers import DEBUG, colored
from tinygrad.device import Device
from tinygrad.runtime.ops_amd import AMDProgram, ProfileSQTTEvent
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32, s_mov_b32, s_add_u32, s_nop, s_endpgm
from extra.assembly.amd.dsl import v, s
from extra.assembly.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, ALUEXEC, VMEMEXEC, InstOp, AluSrc, MemSrc
dev = Device["AMD"]
# ═══════════════════════════════════════════════════════════════════════════════
# PRETTY PRINTING
# ═══════════════════════════════════════════════════════════════════════════════
# Maps a packet class name (type(p).__name__) to the color name that format_packet hands to
# `colored`; class names missing from this table fall back to "white" there.
# NOTE(review): upper-case vs lower-case color names presumably select different variants in
# `colored` -- confirm against tinygrad.helpers.
PACKET_COLORS = {
  "INST": "WHITE", "VALUINST": "WHITE",
  "ALUEXEC": "yellow", "VMEMEXEC": "yellow",
  "WAVESTART": "blue", "WAVEEND": "blue", "WAVEALLOC": "cyan", "WAVERDY": "cyan",
  "LAYOUT_HEADER": "magenta",
  "NOP": "BLACK", "TS_DELTA_SHORT": "BLACK", "TS_WAVE_STATE": "BLACK",
  "TS_DELTA_OR_MARK": "BLACK", "TS_DELTA_S5_W2": "BLACK", "TS_DELTA_S5_W3": "BLACK", "TS_DELTA_S8_W3": "BLACK",
  "REG": "green", "EVENT": "red", "EVENT_BIG": "red", "SNAPSHOT": "green", "UTILCTR": "green", "PERF": "green",
  "IMMEDIATE": "cyan", "IMMEDIATE_MASK": "cyan",
}
def format_packet(p, last_time: int = 0) -> str:
  """Render one decoded packet as a single line of trace text.

  The line holds the packet's `_time`, its offset from `last_time`, the packet class name
  wrapped by `colored`, and a comma-separated list of type-specific fields.
  """
  kind = type(p).__name__
  tint = PACKET_COLORS.get(kind, "white")
  elapsed = p._time - last_time
  if isinstance(p, INST):
    opcode = p.op
    # values that are not InstOp members are shown as two-digit hex
    op_text = opcode.name if isinstance(opcode, InstOp) else f"0x{opcode:02x}"
    details = [f"wave={p.wave}", f"op={op_text}"]
    for flag_name in ("flag1", "flag2"):
      if getattr(p, flag_name): details.append(flag_name)
  elif isinstance(p, VALUINST):
    details = [f"wave={p.wave}"]
    if p.flag: details.append("flag")
  elif isinstance(p, ALUEXEC):
    source = p.src
    details = ["src=" + (source.name if isinstance(source, AluSrc) else f"{source}")]
  elif isinstance(p, VMEMEXEC):
    source = p.src
    details = ["src=" + (source.name if isinstance(source, MemSrc) else f"{source}")]
  elif isinstance(p, (WAVESTART, WAVEEND)):
    details = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"]
  elif hasattr(p, '_values'):
    # generic fallback: every stored value except private keys and the delta
    details = [f"{key}={val}" for key, val in p._values.items() if not (key.startswith('_') or key == 'delta')]
  else:
    details = []
  prefix = f"{p._time:8d} +{elapsed:6d} : "
  return prefix + colored(f"{kind:18s}", tint) + f" {', '.join(details)}"
def print_trace(packets: list, filter_timing: bool = True) -> None:
  """Print one formatted line per packet of a single decoded blob.

  With `filter_timing` set, the packet classes named below are not printed, but their
  timestamps still advance the reference time used for the next printed delta.
  """
  hidden = frozenset()
  if filter_timing:
    hidden = frozenset({"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2",
                        "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "UTILCTR"})
  previous = 0
  for pkt in packets:
    if type(pkt).__name__ not in hidden:
      print(format_packet(pkt, previous))
    previous = pkt._time
def print_blobs(blobs: list[bytes], filter_timing: bool = True) -> None:
  """Decode each blob and print a header line followed by its packet trace."""
  for index, raw in enumerate(blobs):
    decoded = decode(raw)
    header = f"\n--- Blob {index}: {len(raw)} bytes, {len(decoded)} packets ---"
    print(header)
    print_trace(decoded, filter_timing)
# ═══════════════════════════════════════════════════════════════════════════════
# ASSEMBLY HELPERS
# ═══════════════════════════════════════════════════════════════════════════════
def assemble(instructions: list) -> bytes:
  """Concatenate the `to_bytes()` encoding of every instruction, in order."""
  encoded = [inst.to_bytes() for inst in instructions]
  return b''.join(encoded)
def run_asm_sqtt(instructions: list, n_lanes: int = 1) -> list[bytes]:
  """Run instructions on AMD hardware and return SQTT blobs.

  `s_endpgm()` is appended to `instructions`, the encoded bytes are embedded as a `.byte`
  directive under the `test` symbol of an assembly source, and the compiled program is
  launched once with global_size=(1, 1, 1) and local_size=(n_lanes, 1, 1).

  Returns the `blob` of every ProfileSQTTEvent present in `dev.profile_events` after the launch.
  """
  compiler = HIPCompiler(dev.arch)
  # list concatenation builds a new list, so the caller's list is left untouched
  instructions = instructions + [s_endpgm()]
  code = assemble(instructions)
  # comma-separated hex bytes, suitable as the operand list of a `.byte` directive
  byte_str = ', '.join(f'0x{b:02x}' for b in code)
  # NOTE(review): in this view the metadata YAML below carries no nesting indentation;
  # confirm the entries under `amdhsa.kernels` are nested as intended in the real file.
  asm_src = f""".text
.globl test
.p2align 8
.type test,@function
test:
.byte {byte_str}
.rodata
.p2align 6
.amdhsa_kernel test
.amdhsa_next_free_vgpr 256
.amdhsa_next_free_sgpr 96
.amdhsa_wavefront_size32 1
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_kernarg_size 8
.amdhsa_group_segment_fixed_size 65536
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: test
.symbol: test.kd
.kernarg_segment_size: 8
.group_segment_fixed_size: 65536
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.wavefront_size: 32
.sgpr_count: 96
.vgpr_count: 256
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""
  lib = compiler.compile(asm_src)
  prg = AMDProgram(dev, "test", lib)
  # single buffer passed as the kernel's only argument
  # NOTE(review): this allocation is not freed anywhere in this function
  out_gpu = dev.allocator.alloc(2048)
  # drop events from earlier launches so only this run's events are returned
  dev.profile_events.clear()
  prg(out_gpu, global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
  return [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)]
def decode_all_blobs(blobs: list[bytes]) -> list:
  """Decode every blob and return all packets as one flat list, in blob order."""
  return [packet for raw in blobs for packet in decode(raw)]
def get_inst_ops(packets: list) -> set:
  """Collect the distinct `op` values carried by the INST packets in `packets`.

  An `op` that is already an int is stored as-is; anything else contributes its `.value`.
  """
  return {pkt.op if isinstance(pkt.op, int) else pkt.op.value
          for pkt in packets if isinstance(pkt, INST)}
# ═══════════════════════════════════════════════════════════════════════════════
# TESTS
# ═══════════════════════════════════════════════════════════════════════════════
@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required")
class TestSQTTDecode(unittest.TestCase):
  """Decoder checks against SQTT streams captured from the AMD device."""

  def _capture(self, instructions):
    """Run `instructions` through run_asm_sqtt and return the blobs, failing if none came back."""
    blobs = run_asm_sqtt(instructions)
    self.assertGreater(len(blobs), 0, "No SQTT data captured")
    return blobs

  def _debug_trace(self, header, packets, filter_timing=True):
    """Print `header` and a trace of `packets`, but only when DEBUG >= 2."""
    if DEBUG >= 2:
      print(header)
      print_trace(packets, filter_timing=filter_timing)

  def test_basic_structure(self):
    """Verify basic SQTT stream structure: LAYOUT_HEADER, WAVESTART, instructions, WAVEEND."""
    packets = decode_all_blobs(self._capture([v_mov_b32_e32(v[0], 0)]))
    self.assertGreater(len(packets), 0, "No packets decoded")
    required = (
      (LAYOUT_HEADER, "No LAYOUT_HEADER packets"),
      (WAVESTART, "No WAVESTART packets"),
      (WAVEEND, "No WAVEEND packets"),
    )
    for pkt_cls, message in required:
      matching = [p for p in packets if isinstance(p, pkt_cls)]
      self.assertGreater(len(matching), 0, message)
    self._debug_trace("\n=== Basic structure trace ===", packets)

  def test_valu_instructions(self):
    """Verify VALU instructions produce INST or VALUINST packets."""
    program = [
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], 2.0),
      v_add_f32_e32(v[2], v[0], v[1]),
      v_add_f32_e32(v[3], v[2], v[1]),
      v_mul_f32_e32(v[4], v[2], v[3]),
    ]
    packets = decode_all_blobs(self._capture(program))
    issued = [p for p in packets if isinstance(p, (INST, VALUINST))]
    self.assertGreater(len(issued), 0, "No INST/VALUINST packets for VALU instructions")
    self._debug_trace("\n=== VALU instructions trace ===", packets)

  def test_salu_instructions(self):
    """Verify SALU instructions produce appropriate packets."""
    program = [
      s_mov_b32(s[0], 0),
      s_mov_b32(s[1], 1),
      s_add_u32(s[2], s[0], s[1]),
      s_add_u32(s[3], s[2], s[1]),
      s_nop(0),
    ]
    # decoding happens even without DEBUG, so decoder exceptions still surface
    packets = decode_all_blobs(self._capture(program))
    self._debug_trace("\n=== SALU instructions trace ===", packets)

  def test_timing_increases(self):
    """Verify time increases monotonically through packets within each blob."""
    program = [
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], 2.0),
      v_add_f32_e32(v[2], v[0], v[1]),
      v_mul_f32_e32(v[3], v[2], v[1]),
    ]
    for blob in self._capture(program):
      # the reference time restarts at 0 for every blob
      earlier = 0
      for pkt in decode(blob):
        self.assertGreaterEqual(pkt._time, earlier, f"Time decreased: {earlier} -> {pkt._time}")
        earlier = pkt._time

  def test_wave_id_consistency(self):
    """Verify wave IDs are consistent between WAVESTART/WAVEEND."""
    packets = decode_all_blobs(self._capture([v_mov_b32_e32(v[0], 0)]))
    started = [p for p in packets if isinstance(p, WAVESTART)]
    ended = [p for p in packets if isinstance(p, WAVEEND)]
    # nothing is checked unless both packet kinds are present
    if not (started and ended): return
    started_ids = {p.wave for p in started}
    ended_ids = {p.wave for p in ended}
    self.assertTrue(started_ids & ended_ids, "No matching wave IDs between WAVESTART and WAVEEND")

  def test_nop_sequence(self):
    """Test a sequence of NOP instructions."""
    packets = decode_all_blobs(self._capture([s_nop(0), s_nop(0), s_nop(0)]))
    self.assertGreater(len(packets), 0, "No packets decoded")
    self._debug_trace("\n=== NOP sequence trace ===", packets, filter_timing=False)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""Tests validating SQTT packet definitions against the reference implementation.
Verifies that:
1. Encoding patterns produce the correct STATE_TO_OPCODE table
2. Packet sizes (derived from fields) match expected budget values
3. Field extractions match attempt_sqtt_parse.py
"""
import unittest
from extra.assembly.amd.sqtt import (
VALUINST, VMEMEXEC, ALUEXEC, IMMEDIATE, IMMEDIATE_MASK, WAVERDY,
WAVEEND, WAVESTART, PERF, TS_WAVE_STATE, EVENT, EVENT_BIG, REG, SNAPSHOT,
TS_DELTA_OR_MARK, LAYOUT_HEADER, INST, UTILCTR, TS_DELTA_SHORT, NOP,
TS_DELTA_S8_W3, TS_DELTA_S5_W2, TS_DELTA_S5_W3, WAVEALLOC,
decode, encode, OPCODE_TO_CLASS, STATE_TO_OPCODE, PACKET_TYPES, BUDGET,
AluSrc, MemSrc, InstOp
)
# Reference table from rocprof trace decoder (attempt_sqtt_parse.py)
# 256 entries (16 rows of 16), indexed by byte value; each entry is an opcode that
# OLD_OPCODE_TO_NAME translates to a packet name.
REFERENCE_STATE_TABLE = bytes([
  0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x12, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
  0x10, 0x13, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
  0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
])
# Reference opcode -> name mapping (old opcode values from rocprof)
# 0x00 and 0x10 both name NOP, and 0x0E has no entry here.
OLD_OPCODE_TO_NAME = {
  0x01: 'VALUINST', 0x02: 'VMEMEXEC', 0x03: 'ALUEXEC', 0x04: 'IMMEDIATE',
  0x05: 'IMMEDIATE_MASK', 0x06: 'WAVERDY', 0x07: 'TS_DELTA_S8_W3',
  0x08: 'WAVEEND', 0x09: 'WAVESTART', 0x0A: 'TS_DELTA_S5_W2',
  0x0B: 'WAVEALLOC', 0x0C: 'TS_DELTA_S5_W3', 0x0D: 'PERF',
  0x0F: 'TS_DELTA_SHORT', 0x10: 'NOP', 0x11: 'TS_WAVE_STATE',
  0x12: 'EVENT', 0x13: 'EVENT_BIG', 0x14: 'REG', 0x15: 'SNAPSHOT',
  0x16: 'TS_DELTA_OR_MARK', 0x17: 'LAYOUT_HEADER', 0x18: 'INST',
  0x19: 'UTILCTR', 0x00: 'NOP',
}
# Reference budget values (nibbles for NEXT packet) from rocprof
# Keyed by packet class name; each value is the size in nibbles that the class's
# size_nibbles() is expected to report.
REFERENCE_BUDGET_NIBBLES = {
  'VALUINST': 3, 'VMEMEXEC': 2, 'ALUEXEC': 2, 'IMMEDIATE': 3,
  'IMMEDIATE_MASK': 6, 'WAVERDY': 6, 'TS_DELTA_S8_W3': 16,
  'WAVEEND': 5, 'WAVESTART': 8, 'TS_DELTA_S5_W2': 12,
  'WAVEALLOC': 5, 'TS_DELTA_S5_W3': 13, 'PERF': 7,
  'TS_DELTA_SHORT': 2, 'NOP': 1, 'TS_WAVE_STATE': 6,
  'EVENT': 6, 'EVENT_BIG': 8, 'REG': 16, 'SNAPSHOT': 16,
  'TS_DELTA_OR_MARK': 12, 'LAYOUT_HEADER': 16, 'INST': 5,
  'UTILCTR': 12,
}
class TestEncodingsMatchStateTable(unittest.TestCase):
  """Checks that STATE_TO_OPCODE classifies every byte value like the reference table does."""

  def test_all_256_bytes_decode_correctly(self):
    """Each byte value should decode to the same packet type as reference."""
    def reference_name(byte_val):
      opcode = REFERENCE_STATE_TABLE[byte_val]
      return OLD_OPCODE_TO_NAME.get(opcode, f"UNK_{opcode:02x}")

    def decoded_name(byte_val):
      return OPCODE_TO_CLASS[STATE_TO_OPCODE[byte_val]].__name__

    candidates = ((b, reference_name(b), decoded_name(b)) for b in range(256))
    mismatches = [entry for entry in candidates if entry[1] != entry[2]]
    if not mismatches: return
    # only the first ten mismatches are listed; the total count goes in the headline
    listing = "\n".join(f"  0x{b:02x}: expected {r}, got {o}" for b, r, o in mismatches[:10])
    self.fail(f"State table mismatches ({len(mismatches)} total):\n{listing}")
class TestPacketSizesMatchBudget(unittest.TestCase):
  """Checks that the size derived from each packet's fields equals the reference budget."""

  def test_all_packet_sizes(self):
    """Each packet type's size should match the reference budget."""
    for pkt_cls in PACKET_TYPES:
      cls_name = pkt_cls.__name__
      budget = REFERENCE_BUDGET_NIBBLES.get(cls_name)
      # packet classes without a reference budget are not checked
      if budget is not None:
        measured = pkt_cls.size_nibbles()
        detail = f"{cls_name}: expected {budget} nibbles, got {measured} (size_bits={pkt_cls.size_bits()})"
        self.assertEqual(budget, measured, detail)
class TestFieldExtraction(unittest.TestCase):
  """Test that field values are extracted correctly."""
  def test_valuinst(self):
    """VALUINST.from_raw exposes delta, flag and wave from a packed word."""
    reg = 0b11110_1_001_011  # wave=0x1E, flag=1, delta=1
    pkt = VALUINST.from_raw(reg)
    self.assertEqual(pkt.delta, 1)
    self.assertEqual(pkt.flag, 1)
    self.assertEqual(pkt.wave, 0x1E)
  def test_vmemexec_enum(self):
    """VMEMEXEC.src compares equal to the MemSrc member."""
    reg = 0b11_00_1111  # src=3 (VMEM_ALT), delta=0
    pkt = VMEMEXEC.from_raw(reg)
    self.assertEqual(pkt.src, MemSrc.VMEM_ALT)
  def test_aluexec_enum(self):
    """ALUEXEC.src compares equal to the AluSrc member."""
    reg = 0b10_01_1110  # src=2 (VALU), delta=1
    pkt = ALUEXEC.from_raw(reg)
    self.assertEqual(pkt.src, AluSrc.VALU)
  def test_waveend(self):
    """WAVEEND exposes flag7, simd, cu_lo, wave, and a cu value combining cu_lo with flag7."""
    reg = (0x15 << 15) | (0x7 << 11) | (0x3 << 9) | (1 << 8) | 0b10101
    pkt = WAVEEND.from_raw(reg)
    self.assertEqual(pkt.flag7, 1)
    self.assertEqual(pkt.simd, 3)
    self.assertEqual(pkt.cu_lo, 7)
    self.assertEqual(pkt.wave, 0x15)
    self.assertEqual(pkt.cu, 0xF)  # cu_lo | (flag7 << 3) = 7 | 8 = 15
  def test_wavestart(self):
    """WAVESTART exposes flag7, simd, cu_lo, wave, id7 and the combined cu value."""
    reg = (0x7F << 18) | (0x15 << 13) | (0x7 << 10) | (0x3 << 8) | (1 << 7) | 0b01100
    pkt = WAVESTART.from_raw(reg)
    self.assertEqual(pkt.flag7, 1)
    self.assertEqual(pkt.simd, 3)
    self.assertEqual(pkt.cu_lo, 7)
    self.assertEqual(pkt.wave, 0x15)
    self.assertEqual(pkt.id7, 0x7F)
    self.assertEqual(pkt.cu, 0xF)
  def test_inst_enum(self):
    """INST exposes flag1, flag2 and wave, and op compares equal to the InstOp member."""
    reg = (0x21 << 13) | (0x15 << 8) | (1 << 7) | (1 << 3) | 0b010
    pkt = INST.from_raw(reg)
    self.assertEqual(pkt.flag1, 1)
    self.assertEqual(pkt.flag2, 1)
    self.assertEqual(pkt.wave, 0x15)
    self.assertEqual(pkt.op, InstOp.VMEM_LOAD)
  def test_layout_header(self):
    """LAYOUT_HEADER exposes layout, simd, group, sel_a and sel_b."""
    reg = (0b101 << 33) | (0b1010 << 28) | (0b111 << 15) | (0b11 << 13) | (0b101010 << 7) | 0b0010001
    pkt = LAYOUT_HEADER.from_raw(reg)
    self.assertEqual(pkt.layout, 0b101010)
    self.assertEqual(pkt.simd, 0b11)
    self.assertEqual(pkt.group, 0b111)
    self.assertEqual(pkt.sel_a, 0b1010)
    self.assertEqual(pkt.sel_b, 0b101)
  def test_ts_delta_or_mark_modes(self):
    """is_marker is true only for bit9=1, bit8=0 among the three combinations tried."""
    # delta mode: bit9=0, bit8=0
    pkt_delta = TS_DELTA_OR_MARK.from_raw(0b0000001)  # just the encoding pattern
    self.assertFalse(pkt_delta.is_marker)
    # marker mode: bit9=1, bit8=0
    pkt_marker = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 9))  # bit9=1, bit8=0
    self.assertTrue(pkt_marker.is_marker)
    # other mode: bit9=1, bit8=1 (not marker)
    pkt_other = TS_DELTA_OR_MARK.from_raw(0b0000001 | (1 << 8) | (1 << 9))
    self.assertFalse(pkt_other.is_marker)
  def test_reg(self):
    """REG exposes slot, hi_byte, subop and val32, with slot overlapping hi_byte's low bits."""
    # REG fields: slot=bits[9:7], hi_byte=bits[15:8], subop=bits[31:16], val32=bits[63:32]
    # Note: slot[2:1] overlaps with hi_byte[1:0], so we need to set them consistently
    # hi_byte=0x55 means bits 8-15 = 0b01010101, so slot bits 8-9 = 0b01
    # slot bit 7 = 1, so slot = 0b011 = 3
    reg = (0xDEADBEEF << 32) | (0xCAFE << 16) | (0x55 << 8) | (1 << 7) | 0b1001
    pkt = REG.from_raw(reg)
    self.assertEqual(pkt.slot, 0b011)  # bit7=1, bits 8-9 from hi_byte low 2 bits = 01
    self.assertEqual(pkt.hi_byte, 0x55)
    self.assertEqual(pkt.subop, 0xCAFE)
    self.assertEqual(pkt.val32, 0xDEADBEEF)
class TestRoundtrip(unittest.TestCase):
  """Encode-then-decode checks on a short hand-built packet list."""

  def test_simple_roundtrip(self):
    """Test encode/decode roundtrip preserves packet types."""
    raw_words = ((LAYOUT_HEADER, 0x100), (WAVESTART, 0x0), (INST, 0x10), (INST, 0x10), (WAVEEND, 0x40))
    originals = [pkt_cls.from_raw(word) for pkt_cls, word in raw_words]
    roundtripped = decode(encode(originals))
    # only a lower bound on the decoded packet count is checked
    self.assertGreaterEqual(len(roundtripped), len(originals))
    for idx, (before, after) in enumerate(zip(originals, roundtripped)):
      self.assertEqual(type(before), type(after), f"type mismatch at {idx}")
if __name__ == "__main__":
unittest.main()

View File

@@ -21,7 +21,7 @@ from tinygrad.runtime.support.memory import AddrSpace
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0)
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
@@ -251,14 +251,15 @@ class AMDComputeQueue(HWQueue):
else:
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12)
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo)
# NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa.
# NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa.
# For RGP to display instruction trace it has to see it on first SE. However ACE/MEC/whatever does the dispatching starting with second se,
# and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but
# sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and
# be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the
# CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE<x> and trace even kernels that only have one wavefront.
# Use SQTT_SIMD_SEL (0-3) to select which SIMD to trace within the WGP.
cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0)
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0)
reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \
self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT
token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0