mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
assembly/amd: start work on SQTT parsing/emulation
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
import ctypes
|
||||
from enum import IntEnum
|
||||
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
|
||||
from extra.assembly.amd.pcode import Reg
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
@@ -90,6 +91,164 @@ class LDSMem:
|
||||
|
||||
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TIMING MODEL - Instruction latencies and register dependency tracking
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class InstType(IntEnum):
  """Coarse instruction category used by the timing model (keys of INST_LATENCY)."""
  # scalar ALU
  SALU = 0
  # vector ALU, simple ops
  VALU = 1
  # transcendental vector ops (v_rcp, v_rsq, v_sqrt, v_log, v_exp, v_sin, v_cos)
  TRANS32 = 2
  # scalar memory
  SMEM = 3
  # vector memory (global/scratch)
  VMEM = 4
  # LDS operations
  LDS = 5
  # branch / control flow
  BRANCH = 6
  # no operation
  NOP = 7
  # anything not covered above
  OTHER = 8
|
||||
|
||||
# Latencies from s_delay_alu encoding (VALU_DEP_1-4, TRANS32_DEP_1-3, SALU_CYCLE_1-3)
|
||||
# Every category costs one cycle unless overridden below:
#   SALU -> SALU_CYCLE_1, VALU -> VALU_DEP_1 (with forwarding, no stall),
#   VMEM / LDS are variable in practice and are simply issued immediately for now.
INST_LATENCY = {kind: 1 for kind in InstType}
INST_LATENCY.update({
  InstType.TRANS32: 4,  # TRANS32_DEP_3 (conservative)
  InstType.SMEM: 4,     # variable in practice, use a small fixed value
})
|
||||
|
||||
# Transcendental ops - higher latency
|
||||
_TRANS_OPS = {'V_RCP_F32', 'V_RCP_F64', 'V_RSQ_F32', 'V_RSQ_F64', 'V_SQRT_F32', 'V_SQRT_F64',
|
||||
'V_LOG_F32', 'V_EXP_F32', 'V_SIN_F32', 'V_COS_F32', 'V_RCP_F16', 'V_RSQ_F16', 'V_SQRT_F16'}
|
||||
|
||||
def classify_inst(inst: Inst) -> InstType:
  """Map a decoded instruction onto the InstType category used for latency lookup.

  The decision is made on the instruction's encoding class; SOPP and the VOP
  encodings are refined further by opcode / opcode name.
  """
  if isinstance(inst, (SOP1, SOP2, SOPC, SOPK)): return InstType.SALU

  if isinstance(inst, SOPP):
    # program end and barriers are neither branches nor plain no-ops
    if inst.op in (SOPPOp.S_ENDPGM, SOPPOp.S_BARRIER): return InstType.OTHER
    is_branch = 'BRANCH' in getattr(inst.op, 'name', '')
    return InstType.BRANCH if is_branch else InstType.NOP

  if isinstance(inst, SMEM): return InstType.SMEM

  if isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC)):
    mnemonic = getattr(inst, 'op_name', '')
    # substring match, so encoding-suffixed variants of a name in _TRANS_OPS also count
    is_trans = any(t in mnemonic for t in _TRANS_OPS)
    return InstType.TRANS32 if is_trans else InstType.VALU

  # remaining encodings map one-to-one onto a category
  for enc_cls, kind in ((VOPD, InstType.VALU), (FLAT, InstType.VMEM), (DS, InstType.LDS)):
    if isinstance(inst, enc_cls): return kind
  return InstType.OTHER
|
||||
|
||||
class TimingState:
  """Per-wave timing bookkeeping that also records SQTT packets.

  Holds a cycle counter and, for every SGPR/VGPR written so far, the cycle at which
  that register becomes readable. Packets emitted along the way are collected in
  `packets` and serialised by `to_sqtt_blob`.
  """
  __slots__ = ('cycle', 'sgpr_ready', 'vgpr_ready', 'packets', 'wave_id', 'simd', 'cu')

  def __init__(self, wave_id: int = 0, simd: int = 0, cu: int = 0):
    self.cycle = 0                         # current cycle
    self.sgpr_ready: dict[int, int] = {}   # SGPR index -> cycle when ready
    self.vgpr_ready: dict[int, int] = {}   # VGPR index -> cycle when ready (all lanes share one entry for now)
    self.packets: list = []                # sqtt.PacketType instances, in emission order
    self.wave_id, self.simd, self.cu = wave_id, simd, cu

  def stall_for_read(self, inst: Inst, st: 'WaveState') -> int:
    """Return the number of cycles `inst` has to wait until all tracked sources are ready.

    `st` is accepted for interface compatibility; it is not consulted.
    """
    srcs: list[tuple[str, int]] = []
    if isinstance(inst, (SOP1, SOP2, SOPC)):
      for name in ('ssrc0', 'ssrc1'):
        if hasattr(inst, name) and getattr(inst, name) < SGPR_COUNT: srcs.append(('s', getattr(inst, name)))
    elif isinstance(inst, SOPK):
      # for SOPK the destination register is tracked as a source
      if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT: srcs.append(('s', inst.sdst))
    elif isinstance(inst, SMEM):
      # sbase names a register pair: 2*sbase and 2*sbase+1
      if hasattr(inst, 'sbase'): srcs += [('s', inst.sbase * 2), ('s', inst.sbase * 2 + 1)]
      if hasattr(inst, 'soffset') and inst.soffset < SGPR_COUNT: srcs.append(('s', inst.soffset))
    elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOPC)):
      for name in ('src0', 'src1', 'src2'):
        if not hasattr(inst, name): continue
        enc = getattr(inst, name)
        # operand encodings below SGPR_COUNT are SGPRs, 256 and up are VGPRs; the rest are not tracked
        if enc < SGPR_COUNT: srcs.append(('s', enc))
        elif enc >= 256: srcs.append(('v', enc - 256))
      if hasattr(inst, 'vsrc1'): srcs.append(('v', inst.vsrc1))

    ready_at = self.cycle
    for bank, reg in srcs:
      scoreboard = self.sgpr_ready if bank == 's' else self.vgpr_ready
      ready_at = max(ready_at, scoreboard.get(reg, 0))
    return ready_at - self.cycle  # never negative: ready_at starts at self.cycle

  def schedule_write(self, inst: Inst, latency: int):
    """Mark the destination registers of `inst` as ready `latency` cycles from now."""
    ready_cycle = self.cycle + latency

    if isinstance(inst, (SOP1, SOP2, SOPK)):
      if hasattr(inst, 'sdst') and inst.sdst < SGPR_COUNT:
        self.sgpr_ready[inst.sdst] = ready_cycle
        if inst.dst_regs() == 2: self.sgpr_ready[inst.sdst + 1] = ready_cycle
    elif isinstance(inst, SMEM):
      if hasattr(inst, 'sdata'):
        # ops missing from SMEM_LOAD are treated as writing a single dword
        for i in range(SMEM_LOAD.get(inst.op, 1)): self.sgpr_ready[inst.sdata + i] = ready_cycle
    elif isinstance(inst, (VOP1, VOP2, VOP3, VOP3SD, VOPC, VOPD)):
      if hasattr(inst, 'vdst'):
        self.vgpr_ready[inst.vdst] = ready_cycle
        if inst.dst_regs() == 2: self.vgpr_ready[inst.vdst + 1] = ready_cycle
    elif isinstance(inst, FLAT):
      if 'LOAD' in inst.op.name and hasattr(inst, 'vdst'):
        for i in range(_op_ndwords(inst.op.name)): self.vgpr_ready[inst.vdst + i] = ready_cycle

  def _emit(self, pkt_cls, **field_values):
    """Append a `pkt_cls` packet built from named field values, timestamped with the current cycle.

    The raw word starts from the class's encoding pattern; every value is masked to the
    width of its bit field and shifted into place as declared on the packet class.
    """
    _, pattern = pkt_cls._encoding
    layout = pkt_cls.fields()
    raw = pattern
    for name, value in field_values.items():
      bf = layout[name]
      raw |= (int(value) & bf.mask()) << bf.lo
    self.packets.append(pkt_cls.from_raw(raw, self.cycle))

  def emit_wavestart(self):
    """Emit a WAVESTART packet for this wave."""
    # sqtt.py has no generic Packet/Op types; packets are instances of the per-type classes
    from extra.assembly.amd.sqtt import WAVESTART
    # WAVESTART.cu is cu_lo | (flag7 << 3), so the CU index is split the same way here
    self._emit(WAVESTART, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=(self.cu >> 3) & 1)

  def emit_inst(self, inst: Inst, pc: int, stall: int, dur: int):
    """Emit an INST packet for the dispatch of `inst` (`pc` is currently unused)."""
    from extra.assembly.amd.sqtt import INST
    # the delta field is narrow: saturate instead of letting a long gap wrap to a small value
    delta = min(stall + dur, INST.fields()['delta'].mask())
    # NOTE(review): `op` carries the emulator's InstType value, which is not the sqtt.InstOp
    # numbering — an InstType -> InstOp mapping still needs to be decided.
    self._emit(INST, delta=delta, wave=self.wave_id, op=classify_inst(inst))

  def emit_waveend(self):
    """Emit a WAVEEND packet for this wave."""
    from extra.assembly.amd.sqtt import WAVEEND
    self._emit(WAVEEND, wave=self.wave_id, simd=self.simd, cu_lo=self.cu & 0x7, flag7=(self.cu >> 3) & 1)

  def to_sqtt_blob(self) -> bytes:
    """Serialise the recorded packets, preceded by a LAYOUT_HEADER, into a raw SQTT blob."""
    from extra.assembly.amd.sqtt import LAYOUT_HEADER, encode
    # NOTE(review): 0x68900180 is the header payload used so far; its low selector bits are
    # zero, so the LAYOUT_HEADER encoding pattern is OR-ed in — confirm against a hardware trace.
    header = LAYOUT_HEADER.from_raw(0x68900180 | LAYOUT_HEADER._encoding[1], 0)
    return encode([header] + self.packets)
|
||||
|
||||
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
|
||||
_VOPD_TO_VOP = {
|
||||
VOPDOp.V_DUAL_FMAC_F32: VOP3Op.V_FMAC_F32, VOPDOp.V_DUAL_FMAAK_F32: VOP2Op.V_FMAAK_F32, VOPDOp.V_DUAL_FMAMK_F32: VOP2Op.V_FMAMK_F32,
|
||||
|
||||
394
extra/assembly/amd/sqtt.py
Normal file
394
extra/assembly/amd/sqtt.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""SQTT (SQ Thread Trace) packet encoder and decoder for AMD GPUs.
|
||||
|
||||
This module provides encoding and decoding of raw SQTT byte streams.
|
||||
The format is nibble-based with variable-width packets determined by a state machine.
|
||||
Uses BitField infrastructure from dsl.py, similar to GPU instruction encoding.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from typing import get_type_hints
|
||||
from extra.assembly.amd.dsl import BitField, bits
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# FIELD ENUMS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class MemSrc(IntEnum):
  """Values of the 2-bit `src` field of VMEMEXEC packets."""
  LDS      = 0
  LDS_ALT  = 1
  VMEM     = 2
  VMEM_ALT = 3
|
||||
|
||||
class AluSrc(IntEnum):
  """Values of the 2-bit `src` field of ALUEXEC packets."""
  NONE     = 0
  SALU     = 1
  VALU     = 2
  VALU_ALT = 3
|
||||
|
||||
class InstOp(IntEnum):
  """Known values of the `op` field of INST packets."""
  # scalar / control
  SALU = 0x00
  SMEM = 0x01
  JUMP = 0x03
  NEXT = 0x04
  MESSAGE = 0x09

  # vector ALU
  VALU = 0x0b
  VALU_64 = 0x0d
  VALU_MAD64 = 0x0e

  # vector memory
  VMEM_LOAD = 0x21
  VMEM_LOAD_ALT = 0x22
  VMEM_STORE = 0x24
  VMEM_STORE_64 = 0x25
  VMEM_STORE_96 = 0x26
  VMEM_STORE_128 = 0x27
  VMEM_STORE_ALT = 0x28

  # LDS
  LDS_LOAD = 0x29
  LDS_STORE = 0x2b
  LDS_STORE_128 = 0x2e

  # SIMD_* variants
  SIMD_LDS_LOAD = 0x50
  SIMD_LDS_LOAD_ALT = 0x51
  SIMD_LDS_STORE = 0x54
  SIMD_VMEM_LOAD = 0x5a
  SIMD_VMEM_LOAD_ALT = 0x5b
  SIMD_VMEM_STORE = 0x5c
  SIMD_VMEM_STORE_ALT = 0x5d
  SIMD_VMEM_STORE_96 = 0x5e
  SIMD_VMEM_STORE_128 = 0x5f

  # miscellaneous
  SALU_OR = 0x72
  VALU_CMPX = 0x73
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE BASE CLASS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class PacketType:
  """Common machinery for SQTT packet layouts.

  A subclass declares `encoding = bits[hi:lo] == pattern` plus one BitField per field.
  Instances are created with `from_raw` and keep the extracted values in `_values`.
  """
  _encoding: tuple[BitField, int] | None = None   # (selector bit field, pattern), set per subclass
  _field_types: dict[str, type] = {}              # field name -> IntEnum type taken from annotations
  _values: dict[str, int]
  _raw: int
  _time: int

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # only an `encoding` declared directly on this subclass is picked up
    own_encoding = cls.__dict__.get('encoding')
    if isinstance(own_encoding, tuple): cls._encoding = own_encoding
    # remember which fields are annotated with an IntEnum so from_raw can convert them
    try:
      hints = get_type_hints(cls)
    except Exception:
      hints = {}  # best effort: unresolved annotations just disable enum conversion
    cls._field_types = {name: hint for name, hint in hints.items() if isinstance(hint, type) and issubclass(hint, IntEnum)}

  @classmethod
  def fields(cls) -> dict[str, BitField]:
    """Return the BitFields declared directly on this class, excluding `encoding`."""
    declared: dict[str, BitField] = {}
    for name, attr in cls.__dict__.items():
      if name != 'encoding' and isinstance(attr, BitField): declared[name] = attr
    return declared

  @classmethod
  def size_bits(cls) -> int:
    """Packet width in bits: highest declared field bit + 1, rounded up to a whole nibble."""
    top_bit = max((bf.hi for bf in cls.fields().values()), default=0)
    return (top_bit // 4 + 1) * 4

  @classmethod
  def size_nibbles(cls) -> int:
    """Packet width in 4-bit nibbles."""
    return cls.size_bits() // 4

  @classmethod
  def from_raw(cls, raw: int, time: int = 0):
    """Create an instance from a raw register value, extracting every declared field."""
    pkt = object.__new__(cls)
    pkt._raw, pkt._time, pkt._values = raw, time, {}
    for name, bf in cls.fields().items():
      value = (raw >> bf.lo) & bf.mask()
      enum_type = cls._field_types.get(name)
      if enum_type is not None:
        # values outside the enum stay as plain ints
        try: value = enum_type(value)
        except ValueError: pass
      pkt._values[name] = value
    return pkt

  def __getattr__(self, name: str):
    # only reached when normal lookup fails; private names never fall back to _values
    if name.startswith('_'): raise AttributeError(name)
    return self._values.get(name, 0)

  def __repr__(self) -> str:
    shown = [f"{k}={v}" for k, v in self._values.items() if not k.startswith('_')]
    return f"{self.__class__.__name__}({', '.join(shown)})"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE DEFINITIONS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# Each class below declares its selector (`encoding`) and bit fields.
# The trailing size comment is size_nibbles() for that layout.

class VALUINST(PacketType):  # 3 nibbles
  encoding = bits[2:0] == 0b011
  delta = bits[5:3]
  flag = bits[6:6]
  wave = bits[11:7]


class VMEMEXEC(PacketType):  # 2 nibbles
  encoding = bits[3:0] == 0b1111
  delta = bits[5:4]
  src: MemSrc = bits[7:6]


class ALUEXEC(PacketType):  # 2 nibbles
  encoding = bits[3:0] == 0b1110
  delta = bits[5:4]
  src: AluSrc = bits[7:6]


class IMMEDIATE(PacketType):  # 3 nibbles
  encoding = bits[3:0] == 0b1101
  delta = bits[6:4]
  wave = bits[11:7]


class IMMEDIATE_MASK(PacketType):  # 6 nibbles
  encoding = bits[4:0] == 0b00100
  delta = bits[7:5]
  mask = bits[23:8]


class WAVERDY(PacketType):  # 6 nibbles
  encoding = bits[4:0] == 0b10100
  delta = bits[7:5]
  mask = bits[23:8]


class TS_DELTA_S8_W3(PacketType):  # 16 nibbles
  encoding = bits[6:0] == 0b0100001
  delta = bits[10:8]
  _padding = bits[63:11]


class WAVEEND(PacketType):  # 5 nibbles
  encoding = bits[4:0] == 0b10101
  delta = bits[7:5]
  flag7 = bits[8:8]
  simd = bits[10:9]
  cu_lo = bits[13:11]
  wave = bits[19:15]

  @property
  def cu(self) -> int:
    # flag7 supplies bit 3 of the CU index
    return self.cu_lo | (self.flag7 << 3)


class WAVESTART(PacketType):  # 8 nibbles
  encoding = bits[4:0] == 0b01100
  delta = bits[6:5]
  flag7 = bits[7:7]
  simd = bits[9:8]
  cu_lo = bits[12:10]
  wave = bits[17:13]
  id7 = bits[31:18]

  @property
  def cu(self) -> int:
    # flag7 supplies bit 3 of the CU index
    return self.cu_lo | (self.flag7 << 3)


class TS_DELTA_S5_W2(PacketType):  # 12 nibbles
  encoding = bits[4:0] == 0b11100
  delta = bits[6:5]
  _padding = bits[47:7]


class WAVEALLOC(PacketType):  # 5 nibbles
  encoding = bits[4:0] == 0b00101
  delta = bits[7:5]
  _padding = bits[19:8]


class TS_DELTA_S5_W3(PacketType):  # 13 nibbles
  encoding = bits[4:0] == 0b00110
  delta = bits[7:5]
  _padding = bits[51:8]


class PERF(PacketType):  # 7 nibbles
  encoding = bits[4:0] == 0b10110
  delta = bits[7:5]
  arg = bits[27:8]


class TS_DELTA_SHORT(PacketType):  # 2 nibbles
  encoding = bits[3:0] == 0b1000
  delta = bits[7:4]


class NOP(PacketType):  # 1 nibble
  encoding = bits[3:0] == 0b0000
  delta = None # type: ignore
  _padding = bits[3:0]


class TS_WAVE_STATE(PacketType):  # 6 nibbles
  encoding = bits[6:0] == 0b1010001
  delta = bits[15:7]
  coarse = bits[23:16]

  @property
  def wave_interest(self) -> bool:
    return bool(self.coarse & 1)

  @property
  def terminate_all(self) -> bool:
    return bool(self.coarse & 8)


class EVENT(PacketType):  # 6 nibbles
  encoding = bits[7:0] == 0b01100001
  delta = bits[10:8]
  event = bits[23:11]


class EVENT_BIG(PacketType):  # 8 nibbles
  encoding = bits[7:0] == 0b11100001
  delta = bits[10:8]
  event = bits[31:11]


class REG(PacketType):  # 16 nibbles
  encoding = bits[3:0] == 0b1001
  delta = bits[6:4]
  # NOTE(review): slot [9:7] and hi_byte [15:8] share bits 8-9 — confirm the overlap is intended
  slot = bits[9:7]
  hi_byte = bits[15:8]
  subop = bits[31:16]
  val32 = bits[63:32]

  @property
  def is_config(self) -> bool:
    return bool(self.hi_byte & 0x80)


class SNAPSHOT(PacketType):  # 16 nibbles
  encoding = bits[6:0] == 0b1110001
  delta = bits[9:7]
  snap = bits[63:10]


class TS_DELTA_OR_MARK(PacketType):  # 12 nibbles
  encoding = bits[6:0] == 0b0000001
  delta = bits[47:12]
  bit8 = bits[8:8]
  bit9 = bits[9:9]

  @property
  def is_marker(self) -> bool:
    # marker form: bit9 set while bit8 is clear
    return bool(self.bit9 and not self.bit8)


class LAYOUT_HEADER(PacketType):  # 16 nibbles
  encoding = bits[6:0] == 0b0010001
  delta = None # type: ignore
  layout = bits[12:7]
  simd = bits[14:13]
  group = bits[17:15]
  sel_a = bits[31:28]
  sel_b = bits[36:33]
  flag4 = bits[59:59]
  _padding = bits[63:60]


class INST(PacketType):  # 5 nibbles
  encoding = bits[2:0] == 0b010
  delta = bits[6:4]
  flag1 = bits[3:3]
  flag2 = bits[7:7]
  wave = bits[12:8]
  op: InstOp = bits[19:13]


class UTILCTR(PacketType):  # 12 nibbles
  encoding = bits[6:0] == 0b0110001
  delta = bits[8:7]
  ctr = bits[47:9]
|
||||
|
||||
# All packet types in encoding priority order (more specific masks first, NOP last as fallback)
|
||||
PACKET_TYPES: list[type[PacketType]] = [
  # 8-bit selectors
  EVENT, EVENT_BIG,
  # 7-bit selectors
  TS_DELTA_S8_W3, TS_WAVE_STATE, SNAPSHOT, TS_DELTA_OR_MARK, LAYOUT_HEADER, UTILCTR,
  # 5-bit selectors
  IMMEDIATE_MASK, WAVERDY, WAVEEND, WAVESTART, TS_DELTA_S5_W2, WAVEALLOC, TS_DELTA_S5_W3, PERF,
  # 4-bit selectors
  VMEMEXEC, ALUEXEC, IMMEDIATE, TS_DELTA_SHORT, REG,
  # 3-bit selectors
  VALUINST, INST,
  # fallback: must stay last, its index is the default entry of the state table
  NOP,
]

# class name -> class
PACKET_BY_NAME: dict[str, type[PacketType]] = {}
for _pkt_cls in PACKET_TYPES: PACKET_BY_NAME[_pkt_cls.__name__] = _pkt_cls
del _pkt_cls
|
||||
|
||||
def _build_state_table() -> tuple[bytes, dict[int, type[PacketType]]]:
  """Build the byte -> opcode lookup table and the opcode -> class map.

  An opcode is the index of a class in PACKET_TYPES. For every possible low byte the
  first class (in list order) whose selector matches wins; bytes matching nothing map
  to the last entry of PACKET_TYPES.
  """
  fallback = len(PACKET_TYPES) - 1  # NOP

  def opcode_for(byte_val: int) -> int:
    for opcode, pkt_cls in enumerate(PACKET_TYPES):
      enc = pkt_cls._encoding
      if enc is None: continue
      selector, pattern = enc
      if (byte_val & selector.mask()) == pattern: return opcode
    return fallback

  return bytes(opcode_for(b) for b in range(256)), dict(enumerate(PACKET_TYPES))
|
||||
|
||||
STATE_TO_OPCODE, OPCODE_TO_CLASS = _build_state_table()

# opcode -> every low-byte value that selects it, in ascending order
OPCODE_TO_BYTES: dict[int, list[int]] = {}
for _byte_val, _opcode in enumerate(STATE_TO_OPCODE):
  OPCODE_TO_BYTES.setdefault(_opcode, []).append(_byte_val)

# opcode -> packet size in nibbles (how far the decoder advances after that packet)
BUDGET = {}
for _op, _cls in OPCODE_TO_CLASS.items(): BUDGET[_op] = _cls.size_nibbles()
del _op, _cls
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def decode(data: bytes) -> list[PacketType]:
  """Decode a raw SQTT blob into packet instances.

  Nibbles are consumed low-nibble-first from each byte and shifted into the top of a
  64-bit window. After an initial fill of 16 nibbles, the low byte of the window selects
  the packet class; the window then advances by that packet's size before the next
  packet is read. A trailing packet that cannot be fully shifted in is dropped.
  """
  window_mask = (1 << 64) - 1
  total_nibbles = len(data) * 2
  decoded: list[PacketType] = []
  window = 0    # sliding 64-bit view of the stream
  cursor = 0    # index of the next unread nibble
  pending = 16  # nibbles to shift in before the next packet is parsed
  now = 0       # accumulated timestamp

  while cursor < total_nibbles:
    stop = cursor + pending
    if stop > total_nibbles: break  # truncated tail
    for idx in range(cursor, stop):
      nibble = (data[idx >> 1] >> ((idx & 1) * 4)) & 0xF
      window = ((window >> 4) | (nibble << 60)) & window_mask
    cursor = stop

    opcode = STATE_TO_OPCODE[window & 0xFF]
    pkt_cls = OPCODE_TO_CLASS[opcode]
    pending = BUDGET[opcode]

    delta_bf = getattr(pkt_cls, 'delta', None)
    step = 0 if delta_bf is None else (window >> delta_bf.lo) & delta_bf.mask()
    if pkt_cls is TS_DELTA_OR_MARK:
      b8 = (window >> TS_DELTA_OR_MARK.bit8.lo) & TS_DELTA_OR_MARK.bit8.mask()
      b9 = (window >> TS_DELTA_OR_MARK.bit9.lo) & TS_DELTA_OR_MARK.bit9.mask()
      if b9 and not b8: step = 0  # marker form does not advance time
    elif pkt_cls is TS_DELTA_SHORT:
      step += 8  # the short delta is stored with a bias of 8

    now += step
    decoded.append(pkt_cls.from_raw(window, now))

  return decoded
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ENCODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def encode(packets: list[PacketType]) -> bytes:
  """Encode packet instances into a raw SQTT blob that `decode` maps back to the same packet types.

  Only the selector byte of each packet is written, together with `delta` when that
  field lies entirely inside the first byte; all other field values are not serialised.

  Raises:
    ValueError: if a packet's class is not registered in OPCODE_TO_CLASS or has no selector byte.
  """
  if not packets: return b''

  class_to_opcode = {cls: op for op, cls in OPCODE_TO_CLASS.items()}

  # The decoder shifts in 16 nibbles before the first packet and then advances by the size
  # of each packet, so the last packet still needs a full 16-nibble window behind it.
  total_nibbles = 16 + sum(p.size_nibbles() for p in packets[:-1])
  nibbles = [0] * (total_nibbles + (total_nibbles & 1))  # pad to a whole number of bytes

  pos = 0  # nibble index where the current packet starts
  for p in packets:
    pkt_cls = type(p)
    opcode = class_to_opcode.get(pkt_cls)
    byte_vals = OPCODE_TO_BYTES.get(opcode) if opcode is not None else None
    if not byte_vals: raise ValueError(f"No encoding for {pkt_cls.__name__}")
    opcode_byte = byte_vals[0]

    delta_field = getattr(pkt_cls, 'delta', None)
    if delta_field is not None and delta_field.hi < 8:
      delta = p._values.get('delta', 0)
      if isinstance(delta, IntEnum): delta = delta.value
      # NOTE(review): from_raw stores the unbiased field value, while this removes the +8 bias
      # decode() adds for TS_DELTA_SHORT — confirm which convention `delta` should follow.
      if pkt_cls is TS_DELTA_SHORT: delta = max(0, delta - 8)
      delta &= delta_field.mask()
      opcode_byte = (opcode_byte & ~(delta_field.mask() << delta_field.lo)) | (delta << delta_field.lo)

    # A one-nibble packet's high nibble lands on the next packet's first nibble; the next
    # iteration overwrites it, so later packets win.
    nibbles[pos] = opcode_byte & 0xF
    nibbles[pos + 1] = (opcode_byte >> 4) & 0xF
    pos += p.size_nibbles()

  return bytes(lo | (hi << 4) for lo, hi in zip(nibbles[0::2], nibbles[1::2]))
|
||||
190
extra/assembly/amd/test/discover_instops.py
Normal file
190
extra/assembly/amd/test/discover_instops.py
Normal file
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python3
|
||||
"""SQTT InstOp discovery tool - finds instruction opcodes by running different instructions.
|
||||
|
||||
Run with: DEBUG=1 python extra/assembly/amd/test/discover_instops.py
|
||||
For full traces: DEBUG=2 python extra/assembly/amd/test/discover_instops.py
|
||||
"""
|
||||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_ITRACE_SE_MASK"] = "2" # Enable instruction tracing on SE1
|
||||
os.environ["SQTT_LIMIT_SE"] = "2" # Force work to traced SE only
|
||||
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.runtime.ops_amd import SQTT_SIMD_SEL
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import (
|
||||
# VALU - basic (these are safe, just register ops)
|
||||
v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32,
|
||||
v_and_b32_e32, v_or_b32_e32, v_xor_b32_e32,
|
||||
v_lshlrev_b32_e32, v_lshrrev_b32_e32,
|
||||
# VALU - transcendental
|
||||
v_exp_f32_e32, v_log_f32_e32, v_rcp_f32_e32, v_sqrt_f32_e32,
|
||||
# VALU - 64-bit
|
||||
v_lshlrev_b64, v_lshrrev_b64,
|
||||
# VALU - compare (writes to VCC, safe)
|
||||
v_cmp_eq_u32_e32,
|
||||
v_cmpx_eq_u32_e32,
|
||||
# SALU - basic (safe, just register ops)
|
||||
s_mov_b32, s_add_u32, s_and_b32, s_or_b32,
|
||||
s_lshl_b32, s_lshr_b32,
|
||||
s_nop, s_endpgm,
|
||||
# SALU - branch (safe if offset is 0 = next instruction)
|
||||
s_branch, s_cbranch_scc0, s_cbranch_execz,
|
||||
# SALU - message
|
||||
s_sendmsg,
|
||||
)
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.sqtt import InstOp, INST
|
||||
|
||||
from extra.assembly.amd.test.test_sqtt_hw import (
|
||||
run_asm_sqtt, decode_all_blobs, get_inst_ops, print_blobs
|
||||
)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# INSTRUCTION TEST CASES - only safe instructions that don't access memory
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
# test name -> (human-readable instruction label, instruction sequence to run)
INSTRUCTION_TESTS: dict[str, tuple[str, list]] = {
  # SALU (0x0) - scalar ALU, just register operations
  "SALU_mov": ("s_mov_b32", [
    s_mov_b32(s[0], 0), s_mov_b32(s[1], 1),
  ]),
  "SALU_add": ("s_add_u32", [
    s_mov_b32(s[0], 1), s_mov_b32(s[1], 2), s_add_u32(s[2], s[0], s[1]),
  ]),
  "SALU_logic": ("s_and/or", [
    s_and_b32(s[2], s[0], s[1]), s_or_b32(s[3], s[0], s[1]),
  ]),
  "SALU_shift": ("s_lshl/lshr", [
    s_lshl_b32(s[2], s[0], 1), s_lshr_b32(s[3], s[0], 1),
  ]),
  "SALU_nop": ("s_nop", [s_nop(0)]),

  # JUMP (0x3) - branch to next instruction (offset 0)
  "JUMP_branch": ("s_branch", [s_branch(0)]),

  # VALU (0xb) - vector ALU, just register operations
  "VALU_mov": ("v_mov_b32", [
    v_mov_b32_e32(v[0], 0), v_mov_b32_e32(v[1], 1.0),
  ]),
  "VALU_add": ("v_add_f32", [
    v_mov_b32_e32(v[0], 1.0), v_mov_b32_e32(v[1], 2.0), v_add_f32_e32(v[2], v[0], v[1]),
  ]),
  "VALU_mul": ("v_mul_f32", [v_mul_f32_e32(v[2], v[0], v[1])]),
  "VALU_logic": ("v_and/or/xor", [
    v_and_b32_e32(v[2], v[0], v[1]), v_or_b32_e32(v[3], v[0], v[1]), v_xor_b32_e32(v[4], v[0], v[1]),
  ]),
  "VALU_shift": ("v_lshl/lshr", [
    v_lshlrev_b32_e32(v[2], 1, v[0]), v_lshrrev_b32_e32(v[3], 1, v[0]),
  ]),

  # VALU transcendental - still just register ops
  "VALU_exp": ("v_exp_f32", [v_mov_b32_e32(v[0], 1.0), v_exp_f32_e32(v[1], v[0])]),
  "VALU_log": ("v_log_f32", [v_mov_b32_e32(v[0], 1.0), v_log_f32_e32(v[1], v[0])]),
  "VALU_rcp": ("v_rcp_f32", [v_mov_b32_e32(v[0], 1.0), v_rcp_f32_e32(v[1], v[0])]),
  "VALU_sqrt": ("v_sqrt_f32", [v_mov_b32_e32(v[0], 1.0), v_sqrt_f32_e32(v[1], v[0])]),

  # VALU 64-bit (0xd)
  "VALU64_lshl": ("v_lshlrev_b64", [v_lshlrev_b64(v[0:1], 1, v[2:3])]),

  # VALU MAD64 (0xe) is not covered yet: v_mad_u64_u32 needs a proper clamp argument

  # VALU compare - writes to VCC
  "VALU_cmp": ("v_cmp_eq_u32", [v_cmp_eq_u32_e32(v[0], v[1])]),

  # VALU CMPX (0x73) - modifies EXEC
  "VALU_cmpx": ("v_cmpx_eq_u32", [v_cmpx_eq_u32_e32(v[0], v[1])]),
}
|
||||
|
||||
|
||||
def run_with_simd_retry(instructions: list, max_retries: int = 4) -> tuple[list[bytes], list, set]:
  """Run `instructions` under successive SIMD selections until INST packets show up.

  SQTT_SIMD_SEL is set to 0, 1, ... max_retries-1 in turn; the first attempt whose trace
  yields any instruction ops is returned. If none does, the last attempt is returned.

  Returns:
    (blobs, packets, ops) of the successful, or else the final, attempt.

  Raises:
    ValueError: if max_retries is below 1, since there would be no attempt to return.
  """
  if max_retries < 1: raise ValueError(f"max_retries must be >= 1, got {max_retries}")
  for simd in range(max_retries):
    SQTT_SIMD_SEL.value = simd
    blobs = run_asm_sqtt(instructions)
    packets = decode_all_blobs(blobs)
    ops = get_inst_ops(packets)
    if ops: break
  # either the first attempt that produced ops, or the last attempt made
  return blobs, packets, ops
|
||||
|
||||
def discover_all_instops() -> tuple[dict[int, set[str]], dict[str, Exception]]:
  """Run every entry of INSTRUCTION_TESTS and collect the InstOp values seen in its trace.

  Returns:
    (discovered, failures): `discovered` maps an op value to the names of the tests that
    produced it; `failures` maps a test name to the exception that aborted it.
  """
  discovered: dict[int, set[str]] = {}
  failures: dict[str, Exception] = {}

  for test_name, (instr_name, instructions) in INSTRUCTION_TESTS.items():
    try:
      blobs, packets, ops = run_with_simd_retry(instructions)
      for op in ops: discovered.setdefault(op, set()).add(test_name)

      if DEBUG >= 2:
        print(f"\n{'─'*60}")
        print(f"{test_name} ({instr_name}): ops={[hex(op) for op in sorted(ops)]} simd_sel={SQTT_SIMD_SEL.value}")
        print_blobs(blobs, filter_timing=True)
      if DEBUG >= 1:
        if ops:
          status, ops_str = colored("✓", "green"), ", ".join(hex(op) for op in sorted(ops))
        else:
          status, ops_str = colored("∅", "yellow"), "none"
        print(f" {status} {test_name:25s} ops=[{ops_str}]")
    except Exception as e:
      # one failing test must not stop the sweep; it is recorded and reported in the summary
      failures[test_name] = e
      if DEBUG >= 1: print(f" {colored('✗', 'red')} {test_name:25s} FAILED: {e}")

  return discovered, failures
|
||||
|
||||
|
||||
def print_summary(discovered: dict[int, set[str]], failures: dict[str, Exception]) -> None:
  """Print which op values were discovered, which enum values were not seen, and run statistics."""
  rule = "=" * 60

  def section(title: str) -> None:
    print("\n" + rule)
    print(title)
    print(rule)

  known_ops = {member.value for member in InstOp}
  discovered_ops = set(discovered.keys())
  missing = known_ops - discovered_ops   # in the enum, never observed
  new_ops = discovered_ops - known_ops   # observed, not in the enum yet

  section("DISCOVERED INSTOP VALUES")
  for op in sorted(discovered_ops):
    if op in known_ops:
      name, status = InstOp(op).name, colored("known", "green")
    else:
      name, status = "UNKNOWN", colored("NEW!", "yellow")
    sources = ", ".join(sorted(discovered[op]))
    print(f" 0x{op:02x} {name:20s} ({status}) <- {sources}")

  if missing:
    section("ENUM VALUES NOT DISCOVERED")
    print("(need memory ops: SMEM, VMEM, LDS)")
    for op in sorted(missing): print(f" 0x{op:02x} {InstOp(op).name}")

  if new_ops:
    section(colored("NEW INSTOP VALUES TO ADD TO ENUM", "yellow"))
    for op in sorted(new_ops):
      sources = ", ".join(sorted(discovered[op]))
      print(f" {op:#04x}: \"{sources}\",")

  section("STATISTICS")
  print(f" Tests run: {len(INSTRUCTION_TESTS)}")
  print(f" Tests passed: {len(INSTRUCTION_TESTS) - len(failures)}")
  print(f" Tests failed: {len(failures)}")
  print(f" Known ops: {len(known_ops)}")
  print(f" Discovered: {len(discovered_ops)}")
  if known_ops:
    covered = len(discovered_ops & known_ops)
    print(f" Coverage: {covered}/{len(known_ops)} ({100*covered//len(known_ops)}%)")
  print(f" New ops found: {len(new_ops)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # banner, then run every test and report
  for line in ("=" * 60, "SQTT InstOp Discovery Tool", "=" * 60):
    print(line)
  print(f"Testing {len(INSTRUCTION_TESTS)} instruction categories...\n")

  discovered, failures = discover_all_instops()
  print_summary(discovered, failures)
|
||||
79
extra/assembly/amd/test/test_sqtt.py
Normal file
79
extra/assembly/amd/test/test_sqtt.py
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT packet codec (no hardware required)."""
|
||||
import unittest
|
||||
from extra.assembly.amd.sqtt import (
|
||||
LAYOUT_HEADER, WAVESTART, WAVEEND, INST, NOP,
|
||||
decode, encode, PACKET_TYPES, OPCODE_TO_CLASS
|
||||
)
|
||||
|
||||
|
||||
class TestSQTTCodec(unittest.TestCase):
  """Encoder/decoder checks that need no hardware."""

  def test_roundtrip_simple(self):
    """Packet types survive an encode -> decode round trip."""
    originals = [
      LAYOUT_HEADER.from_raw(0x100),
      WAVESTART.from_raw(0x0),
      INST.from_raw(0x10),     # delta=1
      INST.from_raw(0x10),     # delta=1
      WAVEEND.from_raw(0x40),  # delta=2
    ]
    decoded = decode(encode(originals))

    # the decoder may yield extra trailing packets, so only the leading ones are compared
    self.assertGreaterEqual(len(decoded), len(originals))
    for i, (orig, dec) in enumerate(zip(originals, decoded)):
      self.assertEqual(type(orig), type(dec), f"type mismatch at {i}")

  def test_decode_empty(self):
    """Decoding no bytes yields no packets."""
    self.assertEqual(decode(b''), [])

  def test_encode_empty(self):
    """Encoding no packets yields no bytes."""
    self.assertEqual(encode([]), b'')

  def test_all_packet_types_have_encoding(self):
    """Every registered packet type declares an encoding."""
    for pkt_cls in PACKET_TYPES:
      self.assertIsNotNone(pkt_cls._encoding, f"{pkt_cls.__name__} missing encoding")

  def test_packet_from_raw(self):
    """from_raw extracts INST fields from their bit positions."""
    wave, op, delta = 5, 0x21, 2
    raw = (op << 13) | (wave << 8) | (delta << 4) | 0b010
    pkt = INST.from_raw(raw)
    self.assertEqual(pkt.wave, 5)
    self.assertEqual(pkt.op, 0x21)
    self.assertEqual(pkt.delta, 2)
|
||||
|
||||
|
||||
class TestDecodeRealBlob(unittest.TestCase):
  """Decode a captured SQTT blob stored under sqtt/examples."""

  def test_decode_example_file(self):
    """The first SQTT event of the example profile decodes and starts with LAYOUT_HEADER."""
    import pickle
    from pathlib import Path

    # Four directory levels above this file, then into sqtt/examples.
    example_path = Path(__file__).parent.parent.parent.parent / "sqtt/examples/profile_plus_run_0.pkl"
    if not example_path.exists():
      self.skipTest(f"Example file not found: {example_path}")

    # Imported only once the example is known to exist.
    from tinygrad.runtime.ops_amd import ProfileSQTTEvent
    with example_path.open("rb") as fh:
      events = pickle.load(fh)

    sqtt_events = [ev for ev in events if isinstance(ev, ProfileSQTTEvent)]
    self.assertGreater(len(sqtt_events), 0, "No SQTT events in example")

    packets = decode(sqtt_events[0].blob)
    self.assertGreater(len(packets), 0, "No packets decoded")
    # First packet should be LAYOUT_HEADER
    self.assertIsInstance(packets[0], LAYOUT_HEADER)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
156
extra/assembly/amd/test/test_sqtt_compare.py
Normal file
156
extra/assembly/amd/test/test_sqtt_compare.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests comparing hardware SQTT traces to emulator SQTT output.
|
||||
|
||||
Run with: python -m pytest extra/assembly/amd/test/test_sqtt_compare.py -v
|
||||
Requires AMD GPU with SQTT support.
|
||||
"""
|
||||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["AMD_LLVM"] = "0"
|
||||
|
||||
import unittest, sys, contextlib
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
|
||||
from extra.assembly.amd.sqtt import decode, encode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, PacketType
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# HARDWARE SQTT CAPTURE
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
dev = Device["AMD"]
|
||||
|
||||
def custom(arg: str, s: UOp | None = None) -> UOp:
  """Create a CUSTOM UOp carrying `arg`, with `s` as its only source when one is given."""
  src = () if s is None else (s,)
  return UOp(Ops.CUSTOM, src=src, arg=arg)
|
||||
|
||||
def asm_kernel(instrs: list[str], local_size: int = 1, global_size: int = 1) -> Tensor:
  """Build a custom kernel whose body is one inline `asm volatile` statement.

  The kernel is named after the function that calls asm_kernel (read from the
  caller's stack frame), so call it directly from the function that should own it.
  """
  kernel_name = sys._getframe(1).f_code.co_name

  def fxn(_):
    lidx = UOp.special(local_size, "lidx0")
    gidx = UOp.special(global_size, "gidx0")
    # Chain CUSTOM ops: opening line, one quoted line per instruction, closing line.
    pieces = ["asm volatile (", *(f' "{inst}\\n\\t"' for inst in instrs), ");"]
    op = None
    for piece in pieces:
      op = custom(piece, op)
    return UOp.sink(op, lidx, gidx, arg=KernelInfo(name=kernel_name))

  return Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
|
||||
@contextlib.contextmanager
def capture_hw_sqtt():
  """Yield a dict that is filled with raw SQTT blobs once the with-block exits.

  The device's profile events are cleared on entry. After the with-body
  finishes normally, the blob of every ProfileSQTTEvent recorded meanwhile is
  appended under the event's `kern` key, so the dict is empty while inside the block.
  """
  dev.profile_events.clear()
  blobs_by_kernel: dict[str, list[bytes]] = {}
  yield blobs_by_kernel
  sqtt_events = (ev for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent))
  for ev in sqtt_events:
    blobs_by_kernel.setdefault(ev.kern, []).append(ev.blob)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TESTS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required")
class TestHardwareSQTT(unittest.TestCase):
  """Capture SQTT traces from real hardware and check that they decode."""

  def _first_capture(self, sqtt: dict[str, list[bytes]]) -> tuple[str, bytes]:
    """Return (kernel name, first blob) for the first captured kernel.

    Fails the test with a clear message when nothing was captured, instead of
    letting the lookup below raise IndexError/StopIteration.
    """
    self.assertGreater(len(sqtt), 0, "No SQTT data captured")
    kern = next(iter(sqtt))
    return kern, sqtt[kern][0]

  # NOTE: asm_kernel names the kernel after its caller, so every test calls it
  # directly rather than through a shared helper.

  def test_capture_and_parse(self):
    """Verify we can capture and parse SQTT for simple VALU instructions."""
    with capture_hw_sqtt() as sqtt:
      asm_kernel(["v_add_f32 v10, v10, v11", "v_add_f32 v11, v11, v12", "v_add_f32 v12, v12, v13"]).realize()
    _, blob = self._first_capture(sqtt)
    packets = decode(blob)
    self.assertGreater(len(packets), 0, "No packets parsed")
    pkt_types = {type(p) for p in packets}
    # LAYOUT_HEADER should always be present
    self.assertIn(LAYOUT_HEADER, pkt_types, "Missing LAYOUT_HEADER packet")

  def test_print_raw_packets(self):
    """Debug test to print raw SQTT packets."""
    with capture_hw_sqtt() as sqtt:
      asm_kernel([
        "v_mov_b32 v1, 1.0",
        "v_add_f32 v2, v1, v1",
        "s_nop 0",
        "v_mul_f32 v3, v2, v2",
      ]).realize()
    kern, blob = self._first_capture(sqtt)
    packets = decode(blob)
    print(f"\n=== Raw SQTT packets for {kern} ({len(blob)} bytes, {len(packets)} packets) ===")
    for i, p in enumerate(packets):
      # A delta of None is printed as 0.
      delta = p.delta if p.delta is not None else 0
      print(f" [{i:3d}] time={p._time:6d} delta={delta:4d} {type(p).__name__:18s} raw=0x{p._raw:x}")

  def test_valu_timing(self):
    """Print the INST packets recorded for a short chain of VALU adds (no timing assertions yet)."""
    with capture_hw_sqtt() as sqtt:
      asm_kernel([
        "v_add_f32 v10, v10, v11",
        "v_add_f32 v11, v11, v12",
        "v_add_f32 v12, v12, v13",
      ]).realize()
    _, blob = self._first_capture(sqtt)
    inst_packets = [p for p in decode(blob) if isinstance(p, INST)]
    print("\n=== INST packets ===")
    for i, p in enumerate(inst_packets):
      print(f" [{i}] time={p._time} delta={p.delta} raw=0x{p._raw:x}")
|
||||
|
||||
|
||||
class TestSQTTCodec(unittest.TestCase):
  """Encode/decode checks for SQTT packets that do not launch a kernel."""

  def test_roundtrip_simple(self):
    """Encoding then decoding a short stream gives back the same packet types, in order."""
    originals = [
      LAYOUT_HEADER.from_raw(0x100),
      WAVESTART.from_raw(0x0),
      INST.from_raw(0x10),  # delta=1
      INST.from_raw(0x10),  # delta=1
      WAVEEND.from_raw(0x40),  # delta=2
    ]
    decoded = decode(encode(originals))

    # Only the first len(originals) decoded packets are compared; extras are tolerated.
    self.assertGreaterEqual(len(decoded), len(originals))
    for idx, (orig, dec) in enumerate(zip(originals, decoded)):
      self.assertEqual(type(orig), type(dec), f"type mismatch at {idx}: {orig} vs {dec}")

  def test_decode_real_blob(self):
    """The first SQTT event of the example profile decodes and contains a LAYOUT_HEADER."""
    import pickle
    from pathlib import Path

    # Four directory levels above this file, then into sqtt/examples.
    example_path = Path(__file__).parent.parent.parent.parent / "sqtt/examples/profile_plus_run_0.pkl"
    if not example_path.exists():
      self.skipTest(f"Example file not found: {example_path}")

    with example_path.open("rb") as fh:
      events = pickle.load(fh)

    sqtt_events = [ev for ev in events if isinstance(ev, ProfileSQTTEvent)]
    self.assertGreater(len(sqtt_events), 0, "No SQTT events in example")

    packets = decode(sqtt_events[0].blob)
    self.assertGreater(len(packets), 0, "No packets decoded")
    # Should see common packet types
    self.assertIn(LAYOUT_HEADER, {type(p) for p in packets})
|
||||
|
||||
|
||||
@unittest.skip("Emulator SQTT not yet implemented")
class TestEmulatorSQTT(unittest.TestCase):
  """Placeholder suite for comparing emulator SQTT to hardware SQTT; skipped as a whole."""

  def test_simple_valu_match(self):
    """Simple VALU chain should produce matching SQTT packets (body not written yet)."""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
276
extra/assembly/amd/test/test_sqtt_hw.py
Normal file
276
extra/assembly/amd/test/test_sqtt_hw.py
Normal file
@@ -0,0 +1,276 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Hardware tests for SQTT decoder - validates decoding of real SQTT streams.
|
||||
|
||||
Run with: python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
|
||||
Requires AMD GPU with SQTT support.
|
||||
|
||||
For pretty trace output: DEBUG=2 python -m pytest extra/assembly/amd/test/test_sqtt_hw.py -v -s
|
||||
"""
|
||||
import os
|
||||
os.environ["SQTT"] = "1"
|
||||
os.environ["PROFILE"] = "1"
|
||||
os.environ["SQTT_ITRACE_SE_MASK"] = "2" # Enable instruction tracing on SE1
|
||||
os.environ["SQTT_LIMIT_SE"] = "1" # Limit execution
|
||||
|
||||
import unittest
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.ops_amd import AMDProgram, ProfileSQTTEvent
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import v_mov_b32_e32, v_add_f32_e32, v_mul_f32_e32, s_mov_b32, s_add_u32, s_nop, s_endpgm
|
||||
from extra.assembly.amd.dsl import v, s
|
||||
from extra.assembly.amd.sqtt import decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, ALUEXEC, VMEMEXEC, InstOp, AluSrc, MemSrc
|
||||
|
||||
dev = Device["AMD"]
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PRETTY PRINTING
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
PACKET_COLORS = {
|
||||
"INST": "WHITE", "VALUINST": "WHITE",
|
||||
"ALUEXEC": "yellow", "VMEMEXEC": "yellow",
|
||||
"WAVESTART": "blue", "WAVEEND": "blue", "WAVEALLOC": "cyan", "WAVERDY": "cyan",
|
||||
"LAYOUT_HEADER": "magenta",
|
||||
"NOP": "BLACK", "TS_DELTA_SHORT": "BLACK", "TS_WAVE_STATE": "BLACK",
|
||||
"TS_DELTA_OR_MARK": "BLACK", "TS_DELTA_S5_W2": "BLACK", "TS_DELTA_S5_W3": "BLACK", "TS_DELTA_S8_W3": "BLACK",
|
||||
"REG": "green", "EVENT": "red", "EVENT_BIG": "red", "SNAPSHOT": "green", "UTILCTR": "green", "PERF": "green",
|
||||
"IMMEDIATE": "cyan", "IMMEDIATE_MASK": "cyan",
|
||||
}
|
||||
|
||||
def format_packet(p, last_time: int = 0) -> str:
  """Render one decoded packet as a single coloured trace line.

  The line holds the packet's `_time`, its offset from `last_time`, the packet
  class name coloured through PACKET_COLORS, and a comma-separated list of the
  fields picked for that packet type.
  """
  name = type(p).__name__
  elapsed = p._time - last_time

  if isinstance(p, INST):
    op = p.op
    # Known ops print by enum name, anything else as two-digit hex.
    op_label = op.name if isinstance(op, InstOp) else f"0x{op:02x}"
    fields = [f"wave={p.wave}", f"op={op_label}"]
    fields += [flag for flag in ("flag1", "flag2") if getattr(p, flag)]
  elif isinstance(p, VALUINST):
    fields = [f"wave={p.wave}"] + (["flag"] if p.flag else [])
  elif isinstance(p, ALUEXEC):
    src = p.src
    fields = [f"src={src.name if isinstance(src, AluSrc) else src}"]
  elif isinstance(p, VMEMEXEC):
    src = p.src
    fields = [f"src={src.name if isinstance(src, MemSrc) else src}"]
  elif isinstance(p, (WAVESTART, WAVEEND)):
    fields = [f"wave={p.wave}", f"simd={p.simd}", f"cu={p.cu}"]
  elif hasattr(p, '_values'):
    # Generic fallback: every stored value except private keys and the delta.
    fields = [f"{key}={val}" for key, val in p._values.items() if not key.startswith('_') and key != 'delta']
  else:
    fields = []

  label = colored(f"{name:18s}", PACKET_COLORS.get(name, "white"))
  return f"{p._time:8d} +{elapsed:6d} : " + label + f" {', '.join(fields)}"
|
||||
|
||||
def print_trace(packets: list, filter_timing: bool = True) -> None:
  """Print one formatted line per packet of a single decoded blob.

  With filter_timing set, the packet types named below are not printed, but
  they still advance the reference time used for the next printed offset.
  """
  hidden = {"NOP", "TS_DELTA_SHORT", "TS_WAVE_STATE", "TS_DELTA_OR_MARK", "TS_DELTA_S5_W2", "TS_DELTA_S5_W3", "TS_DELTA_S8_W3", "REG", "UTILCTR"}
  prev_time = 0
  for pkt in packets:
    if not (filter_timing and type(pkt).__name__ in hidden):
      print(format_packet(pkt, prev_time))
    prev_time = pkt._time
|
||||
|
||||
def print_blobs(blobs: list[bytes], filter_timing: bool = True) -> None:
  """Decode each blob and print a header line followed by its packet trace."""
  for idx, raw in enumerate(blobs):
    decoded = decode(raw)
    header = f"\n--- Blob {idx}: {len(raw)} bytes, {len(decoded)} packets ---"
    print(header)
    print_trace(decoded, filter_timing)
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ASSEMBLY HELPERS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def assemble(instructions: list) -> bytes:
  """Concatenate the `to_bytes()` encoding of every instruction, in order."""
  chunks = [inst.to_bytes() for inst in instructions]
  return b''.join(chunks)
|
||||
|
||||
def run_asm_sqtt(instructions: list, n_lanes: int = 1) -> list[bytes]:
  """Run instructions on AMD hardware and return the SQTT blobs recorded for them.

  s_endpgm is appended to the given instructions, the encoded bytes are
  embedded as a `.byte` directive in a minimal kernel named "test", and that
  kernel is launched once with a single workgroup of n_lanes threads.

  Args:
    instructions: instruction objects exposing to_bytes(); the list itself is not modified.
    n_lanes: local size in x for the launch.

  Returns:
    The blob of every ProfileSQTTEvent the device recorded during the launch.
  """
  compiler = HIPCompiler(dev.arch)
  code = assemble(instructions + [s_endpgm()])

  byte_str = ', '.join(f'0x{b:02x}' for b in code)
  asm_src = f""".text
.globl test
.p2align 8
.type test,@function
test:
.byte {byte_str}

.rodata
.p2align 6
.amdhsa_kernel test
  .amdhsa_next_free_vgpr 256
  .amdhsa_next_free_sgpr 96
  .amdhsa_wavefront_size32 1
  .amdhsa_user_sgpr_kernarg_segment_ptr 1
  .amdhsa_kernarg_size 8
  .amdhsa_group_segment_fixed_size 65536
.end_amdhsa_kernel

.amdgpu_metadata
---
amdhsa.version:
  - 1
  - 0
amdhsa.kernels:
  - .name: test
    .symbol: test.kd
    .kernarg_segment_size: 8
    .group_segment_fixed_size: 65536
    .private_segment_fixed_size: 0
    .kernarg_segment_align: 8
    .wavefront_size: 32
    .sgpr_count: 96
    .vgpr_count: 256
    .max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
"""

  lib = compiler.compile(asm_src)
  prg = AMDProgram(dev, "test", lib)
  buf_size = 2048
  out_gpu = dev.allocator.alloc(buf_size)
  try:
    # Start from an empty event list so only this launch's events are returned.
    dev.profile_events.clear()
    prg(out_gpu, global_size=(1, 1, 1), local_size=(n_lanes, 1, 1), wait=True)
    return [ev.blob for ev in dev.profile_events if isinstance(ev, ProfileSQTTEvent)]
  finally:
    # Give the buffer back even when the launch fails; it used to be leaked on every call.
    dev.allocator.free(out_gpu, buf_size)
|
||||
|
||||
def decode_all_blobs(blobs: list[bytes]) -> list:
  """Decode every blob and return all packets as one flat list, in blob order."""
  return [pkt for blob in blobs for pkt in decode(blob)]
|
||||
|
||||
def get_inst_ops(packets: list) -> set:
  """Collect the distinct `op` values found on INST packets.

  An op that is an int is kept as-is; anything else contributes its `.value`.
  """
  return {p.op if isinstance(p.op, int) else p.op.value for p in packets if isinstance(p, INST)}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# TESTS
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@unittest.skipIf(not hasattr(dev, 'profile_events'), "AMD device required")
class TestSQTTDecode(unittest.TestCase):
  """Decode SQTT streams captured from small hand-assembled kernels."""

  def _capture(self, instructions: list) -> list[bytes]:
    """Run the instructions on hardware and fail unless at least one blob came back."""
    blobs = run_asm_sqtt(instructions)
    self.assertGreater(len(blobs), 0, "No SQTT data captured")
    return blobs

  @staticmethod
  def _dump(title: str, packets: list, **trace_kwargs) -> None:
    """Print the packets under a heading, but only when DEBUG >= 2."""
    if DEBUG >= 2:
      print(f"\n=== {title} trace ===")
      print_trace(packets, **trace_kwargs)

  def test_basic_structure(self):
    """Verify basic SQTT stream structure: LAYOUT_HEADER, WAVESTART, instructions, WAVEEND."""
    packets = decode_all_blobs(self._capture([v_mov_b32_e32(v[0], 0)]))

    self.assertGreater(len(packets), 0, "No packets decoded")
    required = (
      (LAYOUT_HEADER, "No LAYOUT_HEADER packets"),
      (WAVESTART, "No WAVESTART packets"),
      (WAVEEND, "No WAVEEND packets"),
    )
    for pkt_cls, msg in required:
      self.assertGreater(sum(1 for p in packets if isinstance(p, pkt_cls)), 0, msg)

    self._dump("Basic structure", packets)

  def test_valu_instructions(self):
    """Verify VALU instructions produce INST or VALUINST packets."""
    program = [
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], 2.0),
      v_add_f32_e32(v[2], v[0], v[1]),
      v_add_f32_e32(v[3], v[2], v[1]),
      v_mul_f32_e32(v[4], v[2], v[3]),
    ]
    packets = decode_all_blobs(self._capture(program))

    inst_packets = [p for p in packets if isinstance(p, (INST, VALUINST))]
    self.assertGreater(len(inst_packets), 0, "No INST/VALUINST packets for VALU instructions")

    self._dump("VALU instructions", packets)

  def test_salu_instructions(self):
    """Run a short SALU sequence; only checks that a trace was captured and decodes."""
    program = [
      s_mov_b32(s[0], 0),
      s_mov_b32(s[1], 1),
      s_add_u32(s[2], s[0], s[1]),
      s_add_u32(s[3], s[2], s[1]),
      s_nop(0),
    ]
    packets = decode_all_blobs(self._capture(program))

    self._dump("SALU instructions", packets)

  def test_timing_increases(self):
    """Verify time increases monotonically through packets within each blob."""
    program = [
      v_mov_b32_e32(v[0], 1.0),
      v_mov_b32_e32(v[1], 2.0),
      v_add_f32_e32(v[2], v[0], v[1]),
      v_mul_f32_e32(v[3], v[2], v[1]),
    ]
    for blob in self._capture(program):
      # Time restarts from 0 for every blob.
      prev_time = 0
      for p in decode(blob):
        self.assertGreaterEqual(p._time, prev_time, f"Time decreased: {prev_time} -> {p._time}")
        prev_time = p._time

  def test_wave_id_consistency(self):
    """Verify wave IDs are consistent between WAVESTART/WAVEEND."""
    packets = decode_all_blobs(self._capture([v_mov_b32_e32(v[0], 0)]))

    start_waves = {p.wave for p in packets if isinstance(p, WAVESTART)}
    end_waves = {p.wave for p in packets if isinstance(p, WAVEEND)}

    # Only checked when both kinds of packet are present.
    if start_waves and end_waves:
      self.assertTrue(start_waves & end_waves, "No matching wave IDs between WAVESTART and WAVEEND")

  def test_nop_sequence(self):
    """Test a sequence of NOP instructions."""
    packets = decode_all_blobs(self._capture([s_nop(0), s_nop(0), s_nop(0)]))
    self.assertGreater(len(packets), 0, "No packets decoded")

    self._dump("NOP sequence", packets, filter_timing=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
204
extra/assembly/amd/test/test_sqtt_ops.py
Normal file
204
extra/assembly/amd/test/test_sqtt_ops.py
Normal file
@@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests validating SQTT packet definitions against the reference implementation.
|
||||
|
||||
Verifies that:
|
||||
1. Encoding patterns produce the correct STATE_TO_OPCODE table
|
||||
2. Packet sizes (derived from fields) match expected budget values
|
||||
3. Field extractions match attempt_sqtt_parse.py
|
||||
"""
|
||||
import unittest
|
||||
from extra.assembly.amd.sqtt import (
|
||||
VALUINST, VMEMEXEC, ALUEXEC, IMMEDIATE, IMMEDIATE_MASK, WAVERDY,
|
||||
WAVEEND, WAVESTART, PERF, TS_WAVE_STATE, EVENT, EVENT_BIG, REG, SNAPSHOT,
|
||||
TS_DELTA_OR_MARK, LAYOUT_HEADER, INST, UTILCTR, TS_DELTA_SHORT, NOP,
|
||||
TS_DELTA_S8_W3, TS_DELTA_S5_W2, TS_DELTA_S5_W3, WAVEALLOC,
|
||||
decode, encode, OPCODE_TO_CLASS, STATE_TO_OPCODE, PACKET_TYPES, BUDGET,
|
||||
AluSrc, MemSrc, InstOp
|
||||
)
|
||||
|
||||
# Reference table from rocprof trace decoder (attempt_sqtt_parse.py)
|
||||
REFERENCE_STATE_TABLE = bytes([
|
||||
0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x12, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
0x10, 0x13, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02,
|
||||
0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02,
|
||||
])
|
||||
|
||||
# Reference opcode -> name mapping (old opcode values from rocprof)
|
||||
OLD_OPCODE_TO_NAME = {
|
||||
0x01: 'VALUINST', 0x02: 'VMEMEXEC', 0x03: 'ALUEXEC', 0x04: 'IMMEDIATE',
|
||||
0x05: 'IMMEDIATE_MASK', 0x06: 'WAVERDY', 0x07: 'TS_DELTA_S8_W3',
|
||||
0x08: 'WAVEEND', 0x09: 'WAVESTART', 0x0A: 'TS_DELTA_S5_W2',
|
||||
0x0B: 'WAVEALLOC', 0x0C: 'TS_DELTA_S5_W3', 0x0D: 'PERF',
|
||||
0x0F: 'TS_DELTA_SHORT', 0x10: 'NOP', 0x11: 'TS_WAVE_STATE',
|
||||
0x12: 'EVENT', 0x13: 'EVENT_BIG', 0x14: 'REG', 0x15: 'SNAPSHOT',
|
||||
0x16: 'TS_DELTA_OR_MARK', 0x17: 'LAYOUT_HEADER', 0x18: 'INST',
|
||||
0x19: 'UTILCTR', 0x00: 'NOP',
|
||||
}
|
||||
|
||||
# Reference budget values (nibbles for NEXT packet) from rocprof
|
||||
REFERENCE_BUDGET_NIBBLES = {
|
||||
'VALUINST': 3, 'VMEMEXEC': 2, 'ALUEXEC': 2, 'IMMEDIATE': 3,
|
||||
'IMMEDIATE_MASK': 6, 'WAVERDY': 6, 'TS_DELTA_S8_W3': 16,
|
||||
'WAVEEND': 5, 'WAVESTART': 8, 'TS_DELTA_S5_W2': 12,
|
||||
'WAVEALLOC': 5, 'TS_DELTA_S5_W3': 13, 'PERF': 7,
|
||||
'TS_DELTA_SHORT': 2, 'NOP': 1, 'TS_WAVE_STATE': 6,
|
||||
'EVENT': 6, 'EVENT_BIG': 8, 'REG': 16, 'SNAPSHOT': 16,
|
||||
'TS_DELTA_OR_MARK': 12, 'LAYOUT_HEADER': 16, 'INST': 5,
|
||||
'UTILCTR': 12,
|
||||
}
|
||||
|
||||
|
||||
class TestEncodingsMatchStateTable(unittest.TestCase):
  """Check STATE_TO_OPCODE against the reference state table, byte by byte."""

  @staticmethod
  def _names_for(byte_val: int) -> tuple[str, str]:
    """Return (reference packet name, our packet name) for one state byte."""
    ref_opcode = REFERENCE_STATE_TABLE[byte_val]
    ref_name = OLD_OPCODE_TO_NAME.get(ref_opcode, f"UNK_{ref_opcode:02x}")
    our_name = OPCODE_TO_CLASS[STATE_TO_OPCODE[byte_val]].__name__
    return ref_name, our_name

  def test_all_256_bytes_decode_correctly(self):
    """Each byte value should decode to the same packet type as reference."""
    checked = ((byte_val, *self._names_for(byte_val)) for byte_val in range(256))
    mismatches = [row for row in checked if row[1] != row[2]]

    if mismatches:
      # Report at most the first ten mismatches, plus the total count.
      msg = "\n".join(f" 0x{b:02x}: expected {r}, got {o}" for b, r, o in mismatches[:10])
      self.fail(f"State table mismatches ({len(mismatches)} total):\n{msg}")
|
||||
|
||||
|
||||
class TestPacketSizesMatchBudget(unittest.TestCase):
  """Check each packet class's size_nibbles() against REFERENCE_BUDGET_NIBBLES."""

  def test_all_packet_sizes(self):
    """Each packet type's size should match the reference budget."""
    # Packet types with no reference entry are not checked.
    budgeted = (cls for cls in PACKET_TYPES if REFERENCE_BUDGET_NIBBLES.get(cls.__name__) is not None)
    for pkt_cls in budgeted:
      name = pkt_cls.__name__
      expected = REFERENCE_BUDGET_NIBBLES[name]
      actual = pkt_cls.size_nibbles()
      detail = f"{name}: expected {expected} nibbles, got {actual} (size_bits={pkt_cls.size_bits()})"
      self.assertEqual(expected, actual, detail)
|
||||
|
||||
|
||||
class TestFieldExtraction(unittest.TestCase):
  """Build packets from hand-packed raw values and check the fields read back."""

  def _check(self, pkt, **expected):
    """Assert, in keyword order, that each named attribute of pkt equals the given value."""
    for field, want in expected.items():
      self.assertEqual(getattr(pkt, field), want)

  def test_valuinst(self):
    raw = 0b11110_1_001_011  # wave=0x1E, flag=1, delta=1
    self._check(VALUINST.from_raw(raw), delta=1, flag=1, wave=0x1E)

  def test_vmemexec_enum(self):
    raw = 0b11_00_1111  # src=3 (VMEM_ALT), delta=0
    self._check(VMEMEXEC.from_raw(raw), src=MemSrc.VMEM_ALT)

  def test_aluexec_enum(self):
    raw = 0b10_01_1110  # src=2 (VALU), delta=1
    self._check(ALUEXEC.from_raw(raw), src=AluSrc.VALU)

  def test_waveend(self):
    raw = (0x15 << 15) | (0x7 << 11) | (0x3 << 9) | (1 << 8) | 0b10101
    # cu = cu_lo | (flag7 << 3) = 7 | 8 = 15
    self._check(WAVEEND.from_raw(raw), flag7=1, simd=3, cu_lo=7, wave=0x15, cu=0xF)

  def test_wavestart(self):
    raw = (0x7F << 18) | (0x15 << 13) | (0x7 << 10) | (0x3 << 8) | (1 << 7) | 0b01100
    self._check(WAVESTART.from_raw(raw), flag7=1, simd=3, cu_lo=7, wave=0x15, id7=0x7F, cu=0xF)

  def test_inst_enum(self):
    raw = (0x21 << 13) | (0x15 << 8) | (1 << 7) | (1 << 3) | 0b010
    self._check(INST.from_raw(raw), flag1=1, flag2=1, wave=0x15, op=InstOp.VMEM_LOAD)

  def test_layout_header(self):
    raw = (0b101 << 33) | (0b1010 << 28) | (0b111 << 15) | (0b11 << 13) | (0b101010 << 7) | 0b0010001
    self._check(LAYOUT_HEADER.from_raw(raw), layout=0b101010, simd=0b11, group=0b111, sel_a=0b1010, sel_b=0b101)

  def test_ts_delta_or_mark_modes(self):
    encoding = 0b0000001  # just the encoding pattern

    # delta mode: bit9=0, bit8=0
    self.assertFalse(TS_DELTA_OR_MARK.from_raw(encoding).is_marker)

    # marker mode: bit9=1, bit8=0
    self.assertTrue(TS_DELTA_OR_MARK.from_raw(encoding | (1 << 9)).is_marker)

    # other mode: bit9=1, bit8=1 (not marker)
    self.assertFalse(TS_DELTA_OR_MARK.from_raw(encoding | (1 << 8) | (1 << 9)).is_marker)

  def test_reg(self):
    # REG fields: slot=bits[9:7], hi_byte=bits[15:8], subop=bits[31:16], val32=bits[63:32]
    # slot[2:1] overlaps hi_byte[1:0], so the two must be set consistently:
    # hi_byte=0x55 puts 0b01 in bits 8-9, and bit 7 is set, giving slot = 0b011 = 3.
    raw = (0xDEADBEEF << 32) | (0xCAFE << 16) | (0x55 << 8) | (1 << 7) | 0b1001
    self._check(REG.from_raw(raw), slot=0b011, hi_byte=0x55, subop=0xCAFE, val32=0xDEADBEEF)
|
||||
|
||||
|
||||
class TestRoundtrip(unittest.TestCase):
  """Encode a packet list, decode it again and compare packet types."""

  def test_simple_roundtrip(self):
    """Test encode/decode roundtrip preserves packet types."""
    originals = [
      LAYOUT_HEADER.from_raw(0x100),
      WAVESTART.from_raw(0x0),
      INST.from_raw(0x10),
      INST.from_raw(0x10),
      WAVEEND.from_raw(0x40),
    ]
    decoded = decode(encode(originals))

    # Only the first len(originals) decoded packets are compared; extras are tolerated.
    self.assertGreaterEqual(len(decoded), len(originals))
    for idx, (orig, dec) in enumerate(zip(originals, decoded)):
      self.assertEqual(type(orig), type(dec), f"type mismatch at {idx}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -21,7 +21,7 @@ from tinygrad.runtime.support.memory import AddrSpace
|
||||
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
SQTT = ContextVar("SQTT", abs(VIZ.value)>=2)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0)
|
||||
SQTT_ITRACE_SE_MASK, SQTT_LIMIT_SE, SQTT_SIMD_SEL = ContextVar("SQTT_ITRACE_SE_MASK", 0b11), ContextVar("SQTT_LIMIT_SE", 0), ContextVar("SQTT_SIMD_SEL", 0)
|
||||
PMC = ContextVar("PMC", abs(VIZ.value)>=2)
|
||||
EVENT_INDEX_PARTIAL_FLUSH = 4 # based on a comment in nvd.h
|
||||
WAIT_REG_MEM_FUNCTION_EQ = 3 # ==
|
||||
@@ -251,14 +251,15 @@ class AMDComputeQueue(HWQueue):
|
||||
else:
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_SIZE, base_hi=buf0_hi, size=buf0s[se].size >> 12)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_BUF0_BASE, base_lo=buf0_lo)
|
||||
# NOTE: SQTT can only trace instructions on one simd per se, this selects first simd in first wgp in first sa.
|
||||
# NOTE: SQTT can only trace instructions on one simd per se, this selects the simd in first wgp in first sa.
|
||||
# For RGP to display instruction trace it has to see it on first SE. However, ACE/MEC/whatever does the dispatching starting with second se,
|
||||
# and on amdgpu/non-AM it also does weird things with dispatch order inside se: around 7 times out of 10 it starts from the last cu, but
|
||||
# sometimes not, especially if the kernel has more than one wavefront which means that kernels with small global size might get unlucky and
|
||||
# be dispatched on something else and not be seen in instruction tracing tab. You can force the wavefronts of a kernel to be dispatched on the
|
||||
# CUs you want to by disabling other CUs via bits in regCOMPUTE_STATIC_THREAD_MGMT_SE<x> and trace even kernels that only have one wavefront.
|
||||
# Use SQTT_SIMD_SEL (0-3) to select which SIMD to trace within the WGP.
|
||||
cs_wtype = (1 << 6) if self.dev.target >= (12,0,0) else self.soc.SQ_TT_WTYPE_INCLUDE_CS_BIT
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=0, wgp_sel=0, sa_sel=0)
|
||||
self.wreg(self.gc.regSQ_THREAD_TRACE_MASK, wtype_include=cs_wtype, simd_sel=SQTT_SIMD_SEL.value, wgp_sel=0, sa_sel=0)
|
||||
reg_include = self.soc.SQ_TT_TOKEN_MASK_SQDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_SHDEC_BIT | self.soc.SQ_TT_TOKEN_MASK_GFXUDEC_BIT | \
|
||||
self.soc.SQ_TT_TOKEN_MASK_COMP_BIT | self.soc.SQ_TT_TOKEN_MASK_CONTEXT_BIT
|
||||
token_exclude = (1 << self.soc.SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT) if self.dev.target < (12,0,0) else 0
|
||||
|
||||
Reference in New Issue
Block a user